From 397d552e3add75df763894df2da4c57301ad494f Mon Sep 17 00:00:00 2001 From: Billy Olsen Date: Mon, 14 Aug 2023 18:45:20 -0700 Subject: [PATCH] Update the schema to rename groups to ldapgroups Update the schema to rename groups to ldapgroups. This allows for the table name and means to access it to be consistent across all databases. Signed-off-by: Billy Olsen --- postgres.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/postgres.go b/postgres.go index d3d2e29..edcb781 100644 --- a/postgres.go +++ b/postgres.go @@ -2,6 +2,7 @@ package main import ( "database/sql" + "fmt" _ "github.com/lib/pq" @@ -50,9 +51,9 @@ CREATE TABLE IF NOT EXISTS users ( statement.Exec() statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_user_name on users(name)") statement.Exec() - statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS groups (id SERIAL PRIMARY KEY, name TEXT NOT NULL, gidnumber INTEGER NOT NULL)") + statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS ldapgroups (id SERIAL PRIMARY KEY, name TEXT NOT NULL, gidnumber INTEGER NOT NULL)") statement.Exec() - statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_group_name on groups(name)") + statement, _ = db.Prepare("CREATE UNIQUE INDEX IF NOT EXISTS idx_group_name on ldapgroups(name)") statement.Exec() statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS includegroups (id SERIAL PRIMARY KEY, parentgroupid INTEGER NOT NULL, includegroupid INTEGER NOT NULL)") statement.Exec() @@ -66,4 +67,24 @@ func (b PostgresBackend) MigrateSchema(db *sql.DB, checker func(*sql.DB, string) statement, _ := db.Prepare("ALTER TABLE users ADD COLUMN sshkeys TEXT DEFAULT ''") statement.Exec() } + + if TableExists(db, "groups") { + // Drop the table created during schema creation + statement, _ := db.Prepare("DROP TABLE ldapgroups") + statement.Exec() + + statement, _ = db.Prepare("ALTER TABLE groups RENAME TO ldapgroups") + statement.Exec() + } +} + +// Indicates whether the table exists or not +func TableExists(db *sql.DB, tableName string) bool { + var found string + err := db.QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s", tableName)).Scan( + &found) + if err != nil { + return false + } + return true }