From f88ba56c09801ef246a26023e4ea59f838f9a175 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Mon, 9 Jun 2025 11:31:19 +0200
Subject: [PATCH 01/25] firewalldb: ensure that test SQL store is closed

We add a helper function to the functions that create the test SQL
stores, in order to ensure that the store is properly closed when the
test is cleaned up.
---
 firewalldb/actions_test.go | 9 ---------
 firewalldb/test_postgres.go | 4 ++--
 firewalldb/test_sql.go | 16 ++++++++++++++--
 firewalldb/test_sqlite.go | 6 +++---
 4 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go
index c27e53e96..69990c1da 100644
--- a/firewalldb/actions_test.go
+++ b/firewalldb/actions_test.go
@@ -28,9 +28,6 @@ func TestActionStorage(t *testing.T) {
 	sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB)
 	db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock)
-	t.Cleanup(func() {
-		_ = db.Close()
-	})

 	// Assert that attempting to add an action for a session that does not
 	// exist returns an error.
@@ -198,9 +195,6 @@ func TestListActions(t *testing.T) {
 	sessDB := session.NewTestDB(t, clock)
 	db := NewTestDBWithSessions(t, sessDB, clock)
-	t.Cleanup(func() {
-		_ = db.Close()
-	})

 	// Add 2 sessions that we can reference.
 	sess1, err := sessDB.NewSession(
@@ -466,9 +460,6 @@ func TestListGroupActions(t *testing.T) {
 	}

 	db := NewTestDBWithSessions(t, sessDB, clock)
-	t.Cleanup(func() {
-		_ = db.Close()
-	})

 	// There should not be any actions in group 1 yet.
 	al, _, _, err := db.ListActions(ctx, nil, WithActionGroupID(group1))
diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go
index f5777e4cb..324aea2c4 100644
--- a/firewalldb/test_postgres.go
+++ b/firewalldb/test_postgres.go
@@ -11,11 +11,11 @@ import (

 // NewTestDB is a helper function that creates an BBolt database for testing.
 func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB {
-	return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock)
+	return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock)
 }

 // NewTestDBFromPath is a helper function that creates a new BoltStore with a
 // connection to an existing BBolt database for testing.
 func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB {
-	return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock)
+	return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock)
 }
diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go
index 03dcfbebf..2f6c6e62e 100644
--- a/firewalldb/test_sql.go
+++ b/firewalldb/test_sql.go
@@ -7,6 +7,7 @@ import (
 	"time"

 	"github.com/lightninglabs/lightning-terminal/accounts"
+	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightninglabs/lightning-terminal/session"
 	"github.com/lightningnetwork/lnd/clock"
 	"github.com/stretchr/testify/require"
@@ -20,7 +21,7 @@ func NewTestDBWithSessions(t *testing.T, sessionStore session.Store,
 	sessions, ok := sessionStore.(*session.SQLStore)
 	require.True(t, ok)

-	return NewSQLDB(sessions.BaseDB, clock)
+	return createStore(t, sessions.BaseDB, clock)
 }

 // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access
 // to an existing sessions DB and accounts DB.
@@ -36,7 +37,7 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB,

 	require.Equal(t, accounts.BaseDB, sessions.BaseDB)

-	return NewSQLDB(sessions.BaseDB, clock)
+	return createStore(t, sessions.BaseDB, clock)
 }

 func assertEqualActions(t *testing.T, expected, got *Action) {
@@ -52,3 +53,14 @@ func assertEqualActions(t *testing.T, expected, got *Action) {
 	expected.AttemptedAt = expectedAttemptedAt
 	got.AttemptedAt = actualAttemptedAt
 }
+
+// createStore is a helper function that creates a new SQLDB and ensures that
+// it is closed during the test cleanup.
+func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB {
+	store := NewSQLDB(sqlDB, clock)
+	t.Cleanup(func() {
+		require.NoError(t, store.Close())
+	})
+
+	return store
+}
diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go
index 5496cb205..506b49bcd 100644
--- a/firewalldb/test_sqlite.go
+++ b/firewalldb/test_sqlite.go
@@ -11,13 +11,13 @@ import (

 // NewTestDB is a helper function that creates an BBolt database for testing.
 func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB {
-	return NewSQLDB(db.NewTestSqliteDB(t).BaseDB, clock)
+	return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock)
 }

 // NewTestDBFromPath is a helper function that creates a new BoltStore with a
 // connection to an existing BBolt database for testing.
 func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB {
-	return NewSQLDB(
-		db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock,
+	return createStore(
+		t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock,
 	)
 }

From ddd984c24caf4b4aca2aae79ed5f5e79a72efd78 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Tue, 27 May 2025 16:52:50 +0200
Subject: [PATCH 02/25] firewalldb: export FirewallDBs interface

In the upcoming migration of the firewall database to SQL, the helper
functions that create the test databases of different types need to
return a unified interface, so that the migration tests file does not
have to be gated by build tags.
Therefore, we export the unified interface FirewallDBs, so that it can
be returned by the public test DB creation functions.
---
 firewalldb/db.go | 14 +++-----------
 firewalldb/interface.go | 8 ++++++++
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/firewalldb/db.go b/firewalldb/db.go
index b8d9ed06f..a8349a538 100644
--- a/firewalldb/db.go
+++ b/firewalldb/db.go
@@ -14,29 +14,21 @@ var (
 	ErrNoSuchKeyFound = fmt.Errorf("no such key found")
 )

-// firewallDBs is an interface that groups the RulesDB and PrivacyMapper
-// interfaces.
-type firewallDBs interface {
-	RulesDB
-	PrivacyMapper
-	ActionDB
-}
-
 // DB manages the firewall rules database.
 type DB struct {
 	started sync.Once
 	stopped sync.Once

-	firewallDBs
+	FirewallDBs

 	cancel fn.Option[context.CancelFunc]
 }

 // NewDB creates a new firewall database. For now, it only contains the
 // underlying rules' and privacy mapper databases.
-func NewDB(dbs firewallDBs) *DB {
+func NewDB(dbs FirewallDBs) *DB {
 	return &DB{
-		firewallDBs: dbs,
+		FirewallDBs: dbs,
 	}
 }

diff --git a/firewalldb/interface.go b/firewalldb/interface.go
index 5ee729e91..c2955bdc6 100644
--- a/firewalldb/interface.go
+++ b/firewalldb/interface.go
@@ -134,3 +134,11 @@ type ActionDB interface {
 	// and feature name.
 	GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB
 }
+
+// FirewallDBs is an interface that groups the RulesDB, PrivacyMapper and
+// ActionDB interfaces.
+type FirewallDBs interface {
+	RulesDB
+	PrivacyMapper
+	ActionDB
+}

From 8a152a456f67efe32e342496df84d62f177b5fdf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Mon, 19 May 2025 13:58:38 +0200
Subject: [PATCH 03/25] firewalldb: update NewTestDB funcs to return
 FirewallDBs

In the upcoming migration of the firewall database to SQL, the helper
functions that create the test databases of different types need to
return a unified interface, so that the migration tests file does not
have to be gated by build tags.

Therefore, we update the `NewTestDB` functions to return the
`FirewallDBs` interface instead of the specific store implementation
type.
---
 firewalldb/test_kvdb.go | 18 +++++++++++-------
 firewalldb/test_postgres.go | 4 ++--
 firewalldb/test_sql.go | 5 ++---
 firewalldb/test_sqlite.go | 4 ++--
 4 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go
index 6f7a49aa3..c3cd4533a 100644
--- a/firewalldb/test_kvdb.go
+++ b/firewalldb/test_kvdb.go
@@ -6,34 +6,37 @@ import (
 	"testing"

 	"github.com/lightninglabs/lightning-terminal/accounts"
+	"github.com/lightninglabs/lightning-terminal/session"
 	"github.com/lightningnetwork/lnd/clock"
 	"github.com/lightningnetwork/lnd/fn"
 	"github.com/stretchr/testify/require"
 )

 // NewTestDB is a helper function that creates an BBolt database for testing.
-func NewTestDB(t *testing.T, clock clock.Clock) *BoltDB {
+func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs {
 	return NewTestDBFromPath(t, t.TempDir(), clock)
 }

 // NewTestDBFromPath is a helper function that creates a new BoltStore with a
 // connection to an existing BBolt database for testing.
-func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *BoltDB {
+func NewTestDBFromPath(t *testing.T, dbPath string,
+	clock clock.Clock) FirewallDBs {
+
 	return newDBFromPathWithSessions(t, dbPath, nil, nil, clock)
 }

 // NewTestDBWithSessions creates a new test BoltDB Store with access to an
 // existing sessions DB.
-func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, - clock clock.Clock) *BoltDB { +func NewTestDBWithSessions(t *testing.T, sessStore session.Store, + clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } // NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access // to an existing sessions DB and accounts DB. -func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *BoltDB { +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore session.Store, + acctStore AccountsDB, clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions( t, t.TempDir(), sessStore, acctStore, clock, @@ -41,7 +44,8 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, } func newDBFromPathWithSessions(t *testing.T, dbPath string, - sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { + sessStore session.Store, acctStore AccountsDB, + clock clock.Clock) FirewallDBs { store, err := NewBoltDB(dbPath, DBFilename, sessStore, acctStore, clock) require.NoError(t, err) diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go index 324aea2c4..732b19b4a 100644 --- a/firewalldb/test_postgres.go +++ b/firewalldb/test_postgres.go @@ -10,12 +10,12 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB { +func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) FirewallDBs { return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 2f6c6e62e..a412441f8 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -16,8 +16,7 @@ import ( // NewTestDBWithSessions creates a new test SQLDB Store with access to an // existing sessions DB. func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, - clock clock.Clock) *SQLDB { - + clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) @@ -27,7 +26,7 @@ func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access // to an existing sessions DB and accounts DB. func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *SQLDB { + acctStore AccountsDB, clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 506b49bcd..49b956d7d 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -10,13 +10,13 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. 
-func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB {
+func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs {
 	return createStore(
 		t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock,
 	)
 }

From ebf1e450d8cf93bda3da5ed144c19e5a85f9c598 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Tue, 20 May 2025 10:39:39 +0200
Subject: [PATCH 04/25] db: add List All Kv Records query

During the upcoming migration of the firewall database to SQL, we need
to be able to check all kvstores records in the SQL database, to
validate that the migration is successful in tests.

This commit adds a query to list all kvstores records, which enables
that functionality.
---
 db/sqlc/kvstores.sql.go | 36 ++++++++++++++++++++++++++++++++++++
 db/sqlc/querier.go | 1 +
 db/sqlc/queries/kvstores.sql | 4 ++++
 3 files changed, 41 insertions(+)

diff --git a/db/sqlc/kvstores.sql.go b/db/sqlc/kvstores.sql.go
index b2e6632f4..c0949d173 100644
--- a/db/sqlc/kvstores.sql.go
+++ b/db/sqlc/kvstores.sql.go
@@ -257,6 +257,42 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco
 	return err
 }

+const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many
+SELECT id, perm, rule_id, session_id, feature_id, entry_key, value
+FROM kvstores
+`
+
+func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) {
+	rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Kvstore
+	for rows.Next() {
+		var i Kvstore
+		if err := rows.Scan(
+			&i.ID,
+			&i.Perm,
+			&i.RuleID,
+			&i.SessionID,
+			&i.FeatureID,
+			&i.EntryKey,
+			&i.Value,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
 const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec
 UPDATE kvstores
 SET value = $1
diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go
index df89d0898..117a1fbc5 100644
--- a/db/sqlc/querier.go
+++ b/db/sqlc/querier.go
@@ -57,6 +57,7 @@ type Querier interface {
 	ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error)
 	ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error)
 	ListAllAccounts(ctx context.Context) ([]Account, error)
+	ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error)
 	ListSessions(ctx context.Context) ([]Session, error)
 	ListSessionsByState(ctx context.Context, state int16) ([]Session, error)
 	ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error)
diff --git a/db/sqlc/queries/kvstores.sql b/db/sqlc/queries/kvstores.sql
index 7963e46a4..1ebfe3b0d 100644
--- a/db/sqlc/queries/kvstores.sql
+++ b/db/sqlc/queries/kvstores.sql
@@ -28,6 +28,10 @@ VALUES ($1, $2, $3, $4, $5, $6);
 DELETE FROM kvstores
 WHERE perm = false;

+-- name: ListAllKVStoresRecords :many
+SELECT *
+FROM kvstores;
+
 -- name: GetGlobalKVStoreRecord :one
 SELECT value
 FROM kvstores

From 325f5a663f9a9c28932bc151f7060b085c413883 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Mon, 30 Jun 2025 23:29:58 +0200
Subject: [PATCH 05/25] multi: rename sql kvstores session_id to group_id

Rename the session_id column to group_id in the kvstores table in the
SQL store, to better represent how the field is actually used.

Note that this is a breaking change, and would normally require a new
migration.
But as the SQL store is only enabled under the dev build flag and
therefore has no users in production, we can rename the column without
a new migration.
---
 db/sqlc/kvstores.sql.go | 134 +++++++++++-----------
 db/sqlc/migrations/000003_kvstores.up.sql | 14 +--
 db/sqlc/models.go | 2 +-
 db/sqlc/querier.go | 6 +-
 db/sqlc/queries/kvstores.sql | 26 ++---
 firewalldb/kvstores_sql.go | 88 +++++++-------
 6 files changed, 135 insertions(+), 135 deletions(-)

diff --git a/db/sqlc/kvstores.sql.go b/db/sqlc/kvstores.sql.go
index c0949d173..b46719eec 100644
--- a/db/sqlc/kvstores.sql.go
+++ b/db/sqlc/kvstores.sql.go
@@ -25,7 +25,7 @@ DELETE FROM kvstores
 WHERE entry_key = $1
   AND rule_id = $2
   AND perm = $3
-  AND session_id = $4
+  AND group_id = $4
   AND feature_id = $5
 `

 type DeleteFeatureKVStoreRecordParams struct {
 	Key       string
 	RuleID    int64
 	Perm      bool
-	SessionID sql.NullInt64
+	GroupID   sql.NullInt64
 	FeatureID sql.NullInt64
 }

 func (q *Queries) DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeat
 		arg.Key,
 		arg.RuleID,
 		arg.Perm,
-		arg.SessionID,
+		arg.GroupID,
 		arg.FeatureID,
 	)
 	return err
 }

@@ -53,7 +53,7 @@ DELETE FROM kvstores
 WHERE entry_key = $1
   AND rule_id = $2
   AND perm = $3
-  AND session_id IS NULL
+  AND group_id IS NULL
   AND feature_id IS NULL
 `

@@ -68,28 +68,28 @@ func (q *Queries) DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGloba
 	return err
 }

-const deleteSessionKVStoreRecord = `-- name: DeleteSessionKVStoreRecord :exec
+const deleteGroupKVStoreRecord = `-- name: DeleteGroupKVStoreRecord :exec
 DELETE FROM kvstores
 WHERE entry_key = $1
   AND rule_id = $2
   AND perm = $3
-  AND session_id = $4
+  AND group_id = $4
   AND feature_id IS NULL
 `

-type DeleteSessionKVStoreRecordParams struct {
-	Key       string
-	RuleID    int64
-	Perm      bool
-	SessionID sql.NullInt64
+type DeleteGroupKVStoreRecordParams struct {
+	Key     string
+	RuleID  int64
+	Perm    bool
+	GroupID sql.NullInt64
 }

-func (q *Queries) DeleteSessionKVStoreRecord(ctx context.Context, arg DeleteSessionKVStoreRecordParams) error {
-	_, err := q.db.ExecContext(ctx, deleteSessionKVStoreRecord,
+func (q *Queries) DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error {
+	_, err := q.db.ExecContext(ctx, deleteGroupKVStoreRecord,
 		arg.Key,
 		arg.RuleID,
 		arg.Perm,
-		arg.SessionID,
+		arg.GroupID,
 	)
 	return err
 }

@@ -113,7 +113,7 @@ FROM kvstores
 WHERE entry_key = $1
   AND rule_id = $2
   AND perm = $3
-  AND session_id = $4
+  AND group_id = $4
   AND feature_id = $5
 `

 type GetFeatureKVStoreRecordParams struct {
 	Key       string
 	RuleID    int64
 	Perm      bool
-	SessionID sql.NullInt64
+	GroupID   sql.NullInt64
 	FeatureID sql.NullInt64
 }

@@ -130,7 +130,7 @@ func (q *Queries) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVS
 		arg.Key,
 		arg.RuleID,
 		arg.Perm,
-		arg.SessionID,
+		arg.GroupID,
 		arg.FeatureID,
 	)
 	var value []byte
@@ -144,7 +144,7 @@ FROM kvstores
 WHERE entry_key = $1
   AND rule_id = $2
   AND perm = $3
-  AND session_id IS NULL
+  AND group_id IS NULL
   AND feature_id IS NULL
 `

+const getGroupKVStoreRecord = `-- name: GetGroupKVStoreRecord :one
+SELECT value
+FROM kvstores
+WHERE entry_key = $1
+  AND rule_id = $2
+  AND perm = $3
+  AND group_id = $4
+  AND feature_id IS NULL
+`
+
+type GetGroupKVStoreRecordParams struct {
+	Key     string
+	RuleID  int64
+	Perm    bool
+	GroupID sql.NullInt64
+}
+
+func (q *Queries) GetGroupKVStoreRecord(ctx
context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + const getOrInsertFeatureID = `-- name: GetOrInsertFeatureID :one INSERT INTO features (name) VALUES ($1) @@ -202,44 +231,15 @@ func (q *Queries) GetRuleID(ctx context.Context, name string) (int64, error) { return id, err } -const getSessionKVStoreRecord = `-- name: GetSessionKVStoreRecord :one -SELECT value -FROM kvstores -WHERE entry_key = $1 - AND rule_id = $2 - AND perm = $3 - AND session_id = $4 - AND feature_id IS NULL -` - -type GetSessionKVStoreRecordParams struct { - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 -} - -func (q *Queries) GetSessionKVStoreRecord(ctx context.Context, arg GetSessionKVStoreRecordParams) ([]byte, error) { - row := q.db.QueryRowContext(ctx, getSessionKVStoreRecord, - arg.Key, - arg.RuleID, - arg.Perm, - arg.SessionID, - ) - var value []byte - err := row.Scan(&value) - return value, err -} - const insertKVStoreRecord = `-- name: InsertKVStoreRecord :exec -INSERT INTO kvstores (perm, rule_id, session_id, feature_id, entry_key, value) +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) VALUES ($1, $2, $3, $4, $5, $6) ` type InsertKVStoreRecordParams struct { Perm bool RuleID int64 - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 EntryKey string Value []byte @@ -249,7 +249,7 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco _, err := q.db.ExecContext(ctx, insertKVStoreRecord, arg.Perm, arg.RuleID, - arg.SessionID, + arg.GroupID, arg.FeatureID, arg.EntryKey, arg.Value, @@ -258,7 +258,7 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco } const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many -SELECT id, perm, rule_id, session_id, feature_id, entry_key, value +SELECT id, perm, rule_id, group_id, feature_id, entry_key, value FROM kvstores ` @@ -275,7 +275,7 @@ func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) &i.ID, &i.Perm, &i.RuleID, - &i.SessionID, + &i.GroupID, &i.FeatureID, &i.EntryKey, &i.Value, @@ -299,7 +299,7 @@ SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id = $5 + AND group_id = $5 AND feature_id = $6 ` @@ -308,7 +308,7 @@ type UpdateFeatureKVStoreRecordParams struct { Key string RuleID int64 Perm bool - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 } @@ -318,7 +318,7 @@ func (q *Queries) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeat arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, arg.FeatureID, ) return err @@ -330,7 +330,7 @@ SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL ` @@ -351,31 +351,31 @@ func (q *Queries) UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGloba return err } -const updateSessionKVStoreRecord = `-- name: UpdateSessionKVStoreRecord :exec +const updateGroupKVStoreRecord = `-- name: UpdateGroupKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id = $5 + AND group_id = $5 AND feature_id IS NULL ` -type UpdateSessionKVStoreRecordParams struct { - Value []byte - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 +type 
UpdateGroupKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 } -func (q *Queries) UpdateSessionKVStoreRecord(ctx context.Context, arg UpdateSessionKVStoreRecordParams) error { - _, err := q.db.ExecContext(ctx, updateSessionKVStoreRecord, +func (q *Queries) UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGroupKVStoreRecord, arg.Value, arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, ) return err } diff --git a/db/sqlc/migrations/000003_kvstores.up.sql b/db/sqlc/migrations/000003_kvstores.up.sql index d2f0653a5..e49ed9622 100644 --- a/db/sqlc/migrations/000003_kvstores.up.sql +++ b/db/sqlc/migrations/000003_kvstores.up.sql @@ -21,7 +21,7 @@ CREATE TABLE IF NOT EXISTS features ( CREATE UNIQUE INDEX IF NOT EXISTS features_name_idx ON features (name); -- kvstores houses key-value pairs under various namespaces determined --- by the rule name, session ID, and feature name. +-- by the rule name, group ID, and feature name. CREATE TABLE IF NOT EXISTS kvstores ( -- The auto incrementing primary key. id INTEGER PRIMARY KEY, @@ -35,15 +35,15 @@ CREATE TABLE IF NOT EXISTS kvstores ( -- kv_store. rule_id BIGINT REFERENCES rules(id) NOT NULL, - -- The session ID that this kv_store belongs to. - -- If this is set, then this kv_store is a session-specific + -- The group ID that this kv_store belongs to. + -- If this is set, then this kv_store is a session-group specific -- kv_store for the given rule. - session_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, + group_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, -- The feature name that this kv_store belongs to. -- If this is set, then this kv_store is a feature-specific - -- kvstore under the given session ID and rule name. - -- If this is set, then session_id must also be set. + -- kvstore under the given group ID and rule name. + -- If this is set, then group_id must also be set. feature_id BIGINT REFERENCES features(id), -- The key of the entry. 
@@ -54,4 +54,4 @@ CREATE TABLE IF NOT EXISTS kvstores ( ); CREATE UNIQUE INDEX IF NOT EXISTS kvstores_lookup_idx - ON kvstores (entry_key, rule_id, perm, session_id, feature_id); + ON kvstores (entry_key, rule_id, perm, group_id, feature_id); diff --git a/db/sqlc/models.go b/db/sqlc/models.go index 357360c9e..d19e66e10 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -63,7 +63,7 @@ type Kvstore struct { ID int64 Perm bool RuleID int64 - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 EntryKey string Value []byte diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 117a1fbc5..d76d5e6e3 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -16,7 +16,7 @@ type Querier interface { DeleteAllTempKVStores(ctx context.Context) error DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error - DeleteSessionKVStoreRecord(ctx context.Context, arg DeleteSessionKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error DeleteSessionsWithState(ctx context.Context, state int16) error GetAccount(ctx context.Context, id int64) (Account, error) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) @@ -29,6 +29,7 @@ type Querier interface { GetFeatureID(ctx context.Context, name string) (int64, error) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) @@ -40,7 +41,6 @@ type Querier interface { GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) - GetSessionKVStoreRecord(ctx context.Context, arg GetSessionKVStoreRecordParams) ([]byte, error) GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) @@ -71,7 +71,7 @@ type Querier interface { UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error - UpdateSessionKVStoreRecord(ctx context.Context, arg UpdateSessionKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error } diff --git a/db/sqlc/queries/kvstores.sql b/db/sqlc/queries/kvstores.sql index 1ebfe3b0d..6acc27468 100644 --- a/db/sqlc/queries/kvstores.sql +++ b/db/sqlc/queries/kvstores.sql @@ -21,7 +21,7 @@ FROM features WHERE name = sqlc.arg('name'); -- name: 
InsertKVStoreRecord :exec -INSERT INTO kvstores (perm, rule_id, session_id, feature_id, entry_key, value) +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) VALUES ($1, $2, $3, $4, $5, $6); -- name: DeleteAllTempKVStores :exec @@ -38,16 +38,16 @@ FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: GetSessionKVStoreRecord :one +-- name: GetGroupKVStoreRecord :one SELECT value FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: GetFeatureKVStoreRecord :one @@ -56,7 +56,7 @@ FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); -- name: DeleteGlobalKVStoreRecord :exec @@ -64,15 +64,15 @@ DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: DeleteSessionKVStoreRecord :exec +-- name: DeleteGroupKVStoreRecord :exec DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: DeleteFeatureKVStoreRecord :exec @@ -80,7 +80,7 @@ DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); -- name: UpdateGlobalKVStoreRecord :exec @@ -89,16 +89,16 @@ SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: UpdateSessionKVStoreRecord :exec +-- name: UpdateGroupKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: UpdateFeatureKVStoreRecord :exec @@ -107,5 +107,5 @@ SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 0c3df2ddb..248892130 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -22,13 +22,13 @@ type SQLKVStoreQueries interface { DeleteFeatureKVStoreRecord(ctx context.Context, arg sqlc.DeleteFeatureKVStoreRecordParams) error DeleteGlobalKVStoreRecord(ctx context.Context, arg sqlc.DeleteGlobalKVStoreRecordParams) error - DeleteSessionKVStoreRecord(ctx context.Context, arg sqlc.DeleteSessionKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg sqlc.DeleteGroupKVStoreRecordParams) error GetFeatureKVStoreRecord(ctx context.Context, arg sqlc.GetFeatureKVStoreRecordParams) ([]byte, error) GetGlobalKVStoreRecord(ctx context.Context, arg sqlc.GetGlobalKVStoreRecordParams) ([]byte, 
error) - GetSessionKVStoreRecord(ctx context.Context, arg sqlc.GetSessionKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg sqlc.GetGroupKVStoreRecordParams) ([]byte, error) UpdateFeatureKVStoreRecord(ctx context.Context, arg sqlc.UpdateFeatureKVStoreRecordParams) error UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlc.UpdateGlobalKVStoreRecordParams) error - UpdateSessionKVStoreRecord(ctx context.Context, arg sqlc.UpdateSessionKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg sqlc.UpdateGroupKVStoreRecordParams) error InsertKVStoreRecord(ctx context.Context, arg sqlc.InsertKVStoreRecordParams) error DeleteAllTempKVStores(ctx context.Context) error GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) @@ -198,7 +198,7 @@ func (s *sqlKVStore) Get(ctx context.Context, key string) ([]byte, error) { // // NOTE: part of the KVStore interface. func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, false) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, false) if err != nil { return err } @@ -219,7 +219,7 @@ func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { Value: value, Perm: s.params.perm, RuleID: ruleID, - SessionID: sessionID, + GroupID: groupID, FeatureID: featureID, }, ) @@ -233,26 +233,26 @@ func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { // Otherwise, the key exists but the value needs to be updated. switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.UpdateFeatureKVStoreRecord( ctx, sqlc.UpdateFeatureKVStoreRecordParams{ Key: key, Value: value, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.UpdateSessionKVStoreRecord( - ctx, sqlc.UpdateSessionKVStoreRecordParams{ - Key: key, - Value: value, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.UpdateGroupKVStoreRecord( + ctx, sqlc.UpdateGroupKVStoreRecordParams{ + Key: key, + Value: value, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -278,7 +278,7 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { // Note: we pass in true here for "read-only" since because this is a // Delete, if the record does not exist, we don't need to create one. // But no need to error out if it doesn't exist. 
- ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, true) if errors.Is(err, sql.ErrNoRows) || errors.Is(err, session.ErrUnknownGroup) { @@ -288,24 +288,24 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { } switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.DeleteFeatureKVStoreRecord( ctx, sqlc.DeleteFeatureKVStoreRecordParams{ Key: key, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.DeleteSessionKVStoreRecord( - ctx, sqlc.DeleteSessionKVStoreRecordParams{ - Key: key, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.DeleteGroupKVStoreRecord( + ctx, sqlc.DeleteGroupKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -326,30 +326,30 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { // get fetches the value under the given key from the underlying kv store given // the namespace fields. func (s *sqlKVStore) get(ctx context.Context, key string) ([]byte, error) { - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, true) if err != nil { return nil, err } switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.GetFeatureKVStoreRecord( ctx, sqlc.GetFeatureKVStoreRecordParams{ Key: key, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.GetSessionKVStoreRecord( - ctx, sqlc.GetSessionKVStoreRecordParams{ - Key: key, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.GetGroupKVStoreRecord( + ctx, sqlc.GetGroupKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -373,7 +373,7 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, readOnly bool) (int64, sql.NullInt64, sql.NullInt64, error) { var ( - sessionID sql.NullInt64 + groupID sql.NullInt64 featureID sql.NullInt64 ruleID int64 err error @@ -382,8 +382,8 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, // If a group ID is specified, then we first check that this group ID // is a known session alias. s.params.groupID.WhenSome(func(id session.ID) { - var groupID int64 - groupID, err = s.queries.GetSessionIDByAlias(ctx, id[:]) + var dbGroupID int64 + dbGroupID, err = s.queries.GetSessionIDByAlias(ctx, id[:]) if errors.Is(err, sql.ErrNoRows) { err = session.ErrUnknownGroup @@ -392,20 +392,20 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, return } - sessionID = sql.NullInt64{ - Int64: groupID, + groupID = sql.NullInt64{ + Int64: dbGroupID, Valid: true, } }) if err != nil { - return ruleID, sessionID, featureID, err + return ruleID, groupID, featureID, err } // We only insert a new rule name into the DB if this is a write call. 
 	if readOnly {
 		ruleID, err = s.queries.GetRuleID(ctx, s.params.ruleName)
 		if err != nil {
-			return 0, sessionID, featureID,
+			return 0, groupID, featureID,
 				fmt.Errorf("unable to get rule ID: %w", err)
 		}
 	} else {
@@ -413,7 +413,7 @@
 			ctx, s.params.ruleName,
 		)
 		if err != nil {
-			return 0, sessionID, featureID,
+			return 0, groupID, featureID,
 				fmt.Errorf("unable to get or insert rule "+
 					"ID: %w", err)
 		}
@@ -441,5 +441,5 @@
 		}
 	})

-	return ruleID, sessionID, featureID, err
+	return ruleID, groupID, featureID, err
 }

From 0c53b7b91b8aca752658ffbc08ddb0c779e1f0ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Tue, 6 May 2025 19:42:08 +0200
Subject: [PATCH 06/25] firewalldb: clarify bbolt kvstores illustration

During the migration of the kvstores to SQL, we'll iterate over the
buckets in the bbolt database, which holds all kvstores records.

In order to understand why the migration iterates over the buckets in
the specific order, we need to clarify the bbolt kvstores illustration
docs, so that they correctly reflect how the records are actually
stored in the bbolt database.
---
 firewalldb/kvstores_kvdb.go | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/firewalldb/kvstores_kvdb.go b/firewalldb/kvstores_kvdb.go
index 51721d475..78676e3ed 100644
--- a/firewalldb/kvstores_kvdb.go
+++ b/firewalldb/kvstores_kvdb.go
@@ -16,13 +16,13 @@ the temporary store changes instead of just keeping an in-memory store is that
 we can then guarantee atomicity if changes are made to both the permanent and
 temporary stores.

-rules -> perm -> rule-name -> global -> {k:v}
-              -> sessions -> group ID -> session-kv-store -> {k:v}
-                                      -> feature-kv-stores -> feature-name -> {k:v}
+"rules" -> "perm" -> <rule-name> -> "global" -> {k:v}
+                                 -> "session-kv-store" -> <group-id> -> {k:v}
+                                                                     -> "feature-kv-stores" -> <feature-name> -> {k:v}

-      -> temp -> rule-name -> global -> {k:v}
-             -> sessions -> group ID -> session-kv-store -> {k:v}
-                                     -> feature-kv-stores -> feature-name -> {k:v}
+        -> "temp" -> <rule-name> -> "global" -> {k:v}
+                                 -> "session-kv-store" -> <group-id> -> {k:v}
+                                                                     -> "feature-kv-stores" -> <feature-name> -> {k:v}
 */

From c3d8ecf4785aad20e06b55f3c35ce2c4d4138ec6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Tue, 6 May 2025 19:44:31 +0200
Subject: [PATCH 07/25] firewalldb: add kvstores kvdb to SQL migration

This commit introduces the migration logic for transitioning the
kvstores store from kvdb to SQL.

Note that as of this commit, the migration is not yet triggered by any
production code, i.e. only tests execute the migration logic.
---
 firewalldb/sql_migration.go | 492 +++++++++++++++++++++++
 firewalldb/sql_migration_test.go | 664 +++++++++++++++++++++++++++++++
 2 files changed, 1156 insertions(+)
 create mode 100644 firewalldb/sql_migration.go
 create mode 100644 firewalldb/sql_migration_test.go

diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go
new file mode 100644
index 000000000..1e114c12c
--- /dev/null
+++ b/firewalldb/sql_migration.go
@@ -0,0 +1,492 @@
+package firewalldb
+
+import (
+	"bytes"
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+
+	"github.com/lightninglabs/lightning-terminal/db/sqlc"
+	"github.com/lightningnetwork/lnd/fn"
+	"github.com/lightningnetwork/lnd/sqldb"
+	"go.etcd.io/bbolt"
+)
+
+// kvEntry represents a single KV entry inserted into the BoltDB.
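+// Each entry is namespaced by its rule name, an optional session group
+// alias and an optional feature name, mirroring the bucket hierarchy
+// illustrated in kvstores_kvdb.go.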
+type kvEntry struct {
+	perm     bool
+	ruleName string
+	key      string
+	value    []byte
+
+	// groupAlias is the legacy session group alias that the entry is
+	// associated with. For global entries, this will be fn.None[[]byte].
+	groupAlias fn.Option[[]byte]
+
+	// featureName is the name of the feature that the entry is associated
+	// with. If the entry is not feature specific, this will be
+	// fn.None[string].
+	featureName fn.Option[string]
+}
+
+// sqlKvEntry represents a single KV entry inserted into the SQL DB,
+// containing the same fields as the kvEntry, but with additional fields that
+// represent the SQL IDs of the rule, session group, and feature.
+type sqlKvEntry struct {
+	*kvEntry
+
+	ruleID int64
+
+	// groupID is the sql session group ID that the entry is associated
+	// with. For global entries, this will be Valid=false.
+	groupID sql.NullInt64
+
+	// featureID is the sql feature ID that the entry is associated with.
+	// This is only set if the entry is feature specific, and will be
+	// Valid=false for other types of entries. If this is set, then
+	// groupID will also be set.
+	featureID sql.NullInt64
+}
+
+// namespacedKey returns a string representation of the kvEntry purely used
+// for logging purposes.
+func (e *kvEntry) namespacedKey() string {
+	ns := fmt.Sprintf("perm: %t, rule: %s", e.perm, e.ruleName)
+
+	e.groupAlias.WhenSome(func(alias []byte) {
+		ns += fmt.Sprintf(", group: %s", alias)
+	})
+
+	e.featureName.WhenSome(func(feature string) {
+		ns += fmt.Sprintf(", feature: %s", feature)
+	})
+
+	ns += fmt.Sprintf(", key: %s", e.key)
+
+	return ns
+}
+
+// MigrateFirewallDBToSQL runs the migration of the firewalldb stores from the
+// bbolt database to a SQL database. The migration is done in a single
+// transaction to ensure that all rows in the stores are migrated or none at
+// all.
+//
+// Note that this migration currently only migrates the kvstores, but will be
+// extended in the future to also migrate the privacy mapper and action
+// stores.
+//
+// NOTE: As sessions may contain linked sessions and accounts, the sessions
+// and accounts sql migration MUST be run prior to this migration.
+func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB,
+	sqlTx SQLQueries) error {
+
+	log.Infof("Starting migration of the rules DB to SQL")
+
+	err := migrateKVStoresDBToSQL(ctx, kvStore, sqlTx)
+	if err != nil {
+		return err
+	}
+
+	log.Infof("The rules DB has been migrated from KV to SQL.")
+
+	// TODO(viktor): Add migration for the privacy mapper and the action
+	// stores.
+
+	return nil
+}
+
+// migrateKVStoresDBToSQL runs the migration of all KV stores from the KV
+// database to the SQL database. The function also asserts that the
+// migrated values match the original values in the KV store.
+func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB,
+	sqlTx SQLQueries) error {
+
+	log.Infof("Starting migration of the KV stores to SQL")
+
+	var pairs []*kvEntry
+
+	// 1) Collect all key-value pairs from the KV store.
+	err := kvStore.View(func(tx *bbolt.Tx) error {
+		var err error
+		pairs, err = collectAllPairs(tx)
+		return err
+	})
+	if err != nil {
+		return fmt.Errorf("collecting all kv pairs failed: %w", err)
+	}
+
+	var insertedPairs []*sqlKvEntry
+
+	// 2) Insert all collected key-value pairs into the SQL database.
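+	// Note that insertPair, defined below, also resolves each entry's
+	// rule name, group alias and feature name to the corresponding SQL
+	// IDs, which the validation step (3) below then uses to fetch the
+	// migrated records again.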
+	for _, entry := range pairs {
+		insertedPair, err := insertPair(ctx, sqlTx, entry)
+		if err != nil {
+			return fmt.Errorf("inserting kv pair %v failed: %w",
+				entry.key, err)
+		}
+
+		insertedPairs = append(insertedPairs, insertedPair)
+	}
+
+	// 3) Validate the migrated values against the original values.
+	for _, insertedPair := range insertedPairs {
+		// Fetch the appropriate SQL entry's value.
+		migratedValue, err := getSQLValue(ctx, sqlTx, insertedPair)
+		if err != nil {
+			return fmt.Errorf("getting SQL value for key %s "+
+				"failed: %w", insertedPair.namespacedKey(), err)
+		}
+
+		// Compare the value of the migrated entry with the original
+		// value from the KV store.
+		// NOTE: if we insert a []byte{} value into the sqldb as the
+		// entry value, and then retrieve it, the value will be
+		// returned as nil. The bytes.Equal will pass in that case,
+		// and therefore such cases won't error out. The kvdb instance
+		// can store []byte{} values.
+		if !bytes.Equal(migratedValue, insertedPair.value) {
+			return fmt.Errorf("migrated value for key %s "+
+				"does not match original value: "+
+				"migrated %x, original %x",
+				insertedPair.namespacedKey(), migratedValue,
+				insertedPair.value)
+		}
+	}
+
+	log.Infof("Migration of the KV stores to SQL completed. Total number "+
+		"of rows migrated: %d", len(pairs))
+
+	return nil
+}
+
+// collectAllPairs collects all key-value pairs from the KV store, and returns
+// them as a slice of kvEntry structs. The function expects the KV store to be
+// structured as described in the comment in the firewalldb/kvstores_kvdb.go
+// file. Any other structure will result in an error.
+// Note that this function and the subsequent functions are intentionally
+// designed to iterate over all buckets and values that exist in the KV store.
+// That ensures that we find all stores and values that exist in the KV store,
+// and can be sure that the kv store actually follows the expected structure.
+func collectAllPairs(tx *bbolt.Tx) ([]*kvEntry, error) {
+	var entries []*kvEntry
+	for _, perm := range []bool{true, false} {
+		mainBucket, err := getMainBucket(tx, false, perm)
+		if err != nil {
+			return nil, err
+		}
+
+		if mainBucket == nil {
+			// If the mainBucket doesn't exist, there are no
+			// entries to migrate under that bucket, therefore we
+			// don't error, and just proceed to not migrate any
+			// entries under that bucket.
+			continue
+		}
+
+		// Loop over each rule-name bucket.
+		err = mainBucket.ForEach(func(rule, v []byte) error {
+			if v != nil {
+				return errors.New("expected only " +
+					"buckets under main bucket")
+			}
+
+			ruleBucket := mainBucket.Bucket(rule)
+			if ruleBucket == nil {
+				return fmt.Errorf("rule bucket %s not found",
+					rule)
+			}
+
+			pairs, err := collectRulePairs(
+				ruleBucket, perm, string(rule),
+			)
+			if err != nil {
+				return err
+			}
+
+			entries = append(entries, pairs...)
+
+			return nil
+		})
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return entries, nil
+}
+
+// collectRulePairs processes a single rule bucket, which should contain the
+// global and session-kv-store key buckets.
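+// Any other bucket key found under the rule bucket will fail the migration,
+// as enforced by the verifyBktKeys call below.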
+func collectRulePairs(bkt *bbolt.Bucket, perm bool, rule string) ([]*kvEntry, + error) { + + var params []*kvEntry + + err := verifyBktKeys( + bkt, true, globalKVStoreBucketKey, sessKVStoreBucketKey, + ) + if err != nil { + return params, fmt.Errorf("verifying rule bucket %s keys "+ + "failed: %w", rule, err) + } + + if globalBkt := bkt.Bucket(globalKVStoreBucketKey); globalBkt != nil { + p, err := collectKVPairs( + globalBkt, true, perm, rule, + fn.None[[]byte](), fn.None[string](), + ) + if err != nil { + return nil, fmt.Errorf("collecting global kv pairs "+ + "failed: %w", err) + } + + params = append(params, p...) + } + + if sessBkt := bkt.Bucket(sessKVStoreBucketKey); sessBkt != nil { + err := sessBkt.ForEach(func(groupAlias, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets "+ + "under %s bucket", sessKVStoreBucketKey) + } + + groupBucket := sessBkt.Bucket(groupAlias) + if groupBucket == nil { + return fmt.Errorf("group bucket for group "+ + "alias %s not found", groupAlias) + } + + kvPairs, err := collectKVPairs( + groupBucket, false, perm, rule, + fn.Some(groupAlias), fn.None[string](), + ) + if err != nil { + return fmt.Errorf("collecting group kv "+ + "pairs failed: %w", err) + } + + params = append(params, kvPairs...) + + err = verifyBktKeys( + groupBucket, false, featureKVStoreBucketKey, + ) + if err != nil { + return fmt.Errorf("verification of group "+ + "bucket %s keys failed: %w", groupAlias, + err) + } + + ftBkt := groupBucket.Bucket(featureKVStoreBucketKey) + if ftBkt == nil { + return nil + } + + return ftBkt.ForEach(func(ftName, v []byte) error { + if v != nil { + return fmt.Errorf("expected only "+ + "buckets under %s bucket", + featureKVStoreBucketKey) + } + + // The feature name should exist, as per the + // verification above. + featureBucket := ftBkt.Bucket(ftName) + if featureBucket == nil { + return fmt.Errorf("feature bucket "+ + "%s not found", ftName) + } + + featurePairs, err := collectKVPairs( + featureBucket, true, perm, rule, + fn.Some(groupAlias), + fn.Some(string(ftName)), + ) + if err != nil { + return fmt.Errorf("collecting "+ + "feature kv pairs failed: %w", + err) + } + + params = append(params, featurePairs...) + + return nil + }) + }) + if err != nil { + return nil, fmt.Errorf("collecting session kv pairs "+ + "failed: %w", err) + } + } + + return params, nil +} + +// collectKVPairs collects all key-value pairs from the given bucket, and +// returns them as a slice of kvEntry structs. If the errorOnBuckets parameter +// is set to true, then the function will return an error if the bucket +// contains any sub-buckets. Note that when the errorOnBuckets parameter is +// set to false, the function will not collect any key-value pairs from the +// sub-buckets, and will just ignore them. +func collectKVPairs(bkt *bbolt.Bucket, errorOnBuckets, perm bool, + ruleName string, groupAlias fn.Option[[]byte], + featureName fn.Option[string]) ([]*kvEntry, error) { + + var params []*kvEntry + + return params, bkt.ForEach(func(key, value []byte) error { + // If the value is nil, then this is a bucket, which we + // don't want to process here, as we only want to collect + // the key-value pairs, not the buckets. If we should + // error on buckets, then we return an error here. 
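+		// Note that bbolt's ForEach yields a nil value for keys that
+		// are themselves sub-buckets, which is what the nil check
+		// below relies on.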
+		if value == nil {
+			if errorOnBuckets {
+				return fmt.Errorf("unexpected bucket %s found "+
+					"when collecting kv pairs", key)
+			}
+
+			return nil
+		}
+
+		params = append(params, &kvEntry{
+			perm:        perm,
+			ruleName:    ruleName,
+			key:         string(key),
+			featureName: featureName,
+			groupAlias:  groupAlias,
+			value:       value,
+		})
+
+		return nil
+	})
+}
+
+// insertPair inserts a single key-value pair into the SQL database.
+func insertPair(ctx context.Context, tx SQLQueries,
+	entry *kvEntry) (*sqlKvEntry, error) {
+
+	ruleID, err := tx.GetOrInsertRuleID(ctx, entry.ruleName)
+	if err != nil {
+		return nil, err
+	}
+
+	p := sqlc.InsertKVStoreRecordParams{
+		Perm:     entry.perm,
+		RuleID:   ruleID,
+		EntryKey: entry.key,
+		Value:    entry.value,
+	}
+
+	entry.groupAlias.WhenSome(func(alias []byte) {
+		var groupID int64
+		groupID, err = tx.GetSessionIDByAlias(ctx, alias)
+		if err != nil {
+			err = fmt.Errorf("getting group id by alias %x "+
+				"failed: %w", alias, err)
+			return
+		}
+
+		p.GroupID = sqldb.SQLInt64(groupID)
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	entry.featureName.WhenSome(func(feature string) {
+		var featureID int64
+		featureID, err = tx.GetOrInsertFeatureID(ctx, feature)
+		if err != nil {
+			err = fmt.Errorf("getting/inserting feature id for %s "+
+				"failed: %w", feature, err)
+			return
+		}
+
+		p.FeatureID = sqldb.SQLInt64(featureID)
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	err = tx.InsertKVStoreRecord(ctx, p)
+	if err != nil {
+		return nil, err
+	}
+
+	return &sqlKvEntry{
+		kvEntry:   entry,
+		ruleID:    p.RuleID,
+		groupID:   p.GroupID,
+		featureID: p.FeatureID,
+	}, nil
+}
+
+// getSQLValue retrieves the key value for the given kvEntry from the SQL
+// database.
+func getSQLValue(ctx context.Context, tx SQLQueries,
+	entry *sqlKvEntry) ([]byte, error) {
+
+	switch {
+	case entry.featureID.Valid && entry.groupID.Valid:
+		return tx.GetFeatureKVStoreRecord(
+			ctx, sqlc.GetFeatureKVStoreRecordParams{
+				Perm:      entry.perm,
+				RuleID:    entry.ruleID,
+				GroupID:   entry.groupID,
+				FeatureID: entry.featureID,
+				Key:       entry.key,
+			},
+		)
+	case entry.groupID.Valid:
+		return tx.GetGroupKVStoreRecord(
+			ctx, sqlc.GetGroupKVStoreRecordParams{
+				Perm:    entry.perm,
+				RuleID:  entry.ruleID,
+				GroupID: entry.groupID,
+				Key:     entry.key,
+			},
+		)
+	case !entry.featureID.Valid && !entry.groupID.Valid:
+		return tx.GetGlobalKVStoreRecord(
+			ctx, sqlc.GetGlobalKVStoreRecordParams{
+				Perm:   entry.perm,
+				RuleID: entry.ruleID,
+				Key:    entry.key,
+			},
+		)
+	default:
+		return nil, fmt.Errorf("invalid combination of feature and "+
+			"group ID: featureID valid: %v, groupID valid: %v",
+			entry.featureID.Valid, entry.groupID.Valid)
+	}
+}
+
+// verifyBktKeys checks that the given bucket only contains buckets with the
+// passed keys, and optionally also key-value pairs. If the errorOnKeyValues
+// parameter is set to true, the function will error if it finds key-value
+// pairs in the bucket.
+func verifyBktKeys(bkt *bbolt.Bucket, errorOnKeyValues bool,
+	keys ...[]byte) error {
+
+	return bkt.ForEach(func(key, v []byte) error {
+		if v != nil {
+			// If we allow key-values, then we can just continue
+			// to the next key. Else we need to error out, as we
+			// only expect buckets under the passed bucket.
+			if errorOnKeyValues {
+				return fmt.Errorf("unexpected key-value pair "+
+					"found: key=%s, value=%x", key, v)
+			}
+
+			return nil
+		}
+
+		for _, expectedKey := range keys {
+			if bytes.Equal(key, expectedKey) {
+				// If this is an expected key, we can continue
+				// to the next key.
+				return nil
+			}
+		}
+
+		return fmt.Errorf("unexpected key found: %s", key)
+	})
+}
diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go
new file mode 100644
index 000000000..1298e3e53
--- /dev/null
+++ b/firewalldb/sql_migration_test.go
@@ -0,0 +1,664 @@
+package firewalldb
+
+import (
+	"bytes"
+	"context"
+	"database/sql"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/lightninglabs/lightning-terminal/accounts"
+	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightninglabs/lightning-terminal/db/sqlc"
+	"github.com/lightninglabs/lightning-terminal/session"
+	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/fn"
+	"github.com/lightningnetwork/lnd/sqldb"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/exp/rand"
+)
+
+const (
+	testRuleName     = "test-rule"
+	testRuleName2    = "test-rule-2"
+	testFeatureName  = "test-feature"
+	testFeatureName2 = "test-feature-2"
+	testEntryKey     = "test-entry-key"
+	testEntryKey2    = "test-entry-key-2"
+	testEntryKey3    = "test-entry-key-3"
+	testEntryKey4    = "test-entry-key-4"
+)
+
+var (
+	testEntryValue = []byte{1, 2, 3}
+)
+
+// TestFirewallDBMigration tests the migration of firewalldb from a bolt
+// backend to a SQL database. Note that this test does not attempt to be a
+// complete migration test.
+// This test currently only tests the migration of the KV stores, but will
+// be extended in the future to also test the migration of the privacy mapper
+// and the actions store.
+func TestFirewallDBMigration(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	clock := clock.NewTestClock(time.Now())
+
+	// When using build tags that create a kvdb store for NewTestDB, we
+	// skip this test as it is only applicable for postgres and sqlite
+	// tags.
+	store := NewTestDB(t, clock)
+	if _, ok := store.(*BoltDB); ok {
+		t.Skipf("Skipping Firewall DB migration test for kvdb build")
+	}
+
+	makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB,
+		*db.TransactionExecutor[SQLQueries]) {
+
+		testDBStore := NewTestDBWithSessions(t, sessionsStore, clock)
+
+		store, ok := testDBStore.(*SQLDB)
+		require.True(t, ok)
+
+		baseDB := store.BaseDB
+
+		genericExecutor := db.NewTransactionExecutor(
+			baseDB, func(tx *sql.Tx) SQLQueries {
+				return baseDB.WithTx(tx)
+			},
+		)
+
+		return store, genericExecutor
+	}
+
+	// The assertMigrationResults function will currently assert that
+	// the migrated kv store entries in the SQLDB match the original kv
+	// store entries in the BoltDB.
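+	// It resolves rule, group and feature names to their SQL IDs once
+	// per name and caches them, as multiple entries may share the same
+	// namespace.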
+
+	// The assertMigrationResults function will currently assert that
+	// the migrated kv stores entries in the SQLDB match the original kv
+	// stores entries in the BoltDB.
+	assertMigrationResults := func(t *testing.T, sqlStore *SQLDB,
+		kvEntries []*kvEntry) {
+
+		var (
+			ruleIDs    = make(map[string]int64)
+			groupIDs   = make(map[string]int64)
+			featureIDs = make(map[string]int64)
+			err        error
+		)
+
+		getRuleID := func(ruleName string) int64 {
+			ruleID, ok := ruleIDs[ruleName]
+			if !ok {
+				ruleID, err = sqlStore.GetRuleID(ctx, ruleName)
+				require.NoError(t, err)
+
+				ruleIDs[ruleName] = ruleID
+			}
+
+			return ruleID
+		}
+
+		getGroupID := func(groupAlias []byte) int64 {
+			groupID, ok := groupIDs[string(groupAlias)]
+			if !ok {
+				groupID, err = sqlStore.GetSessionIDByAlias(
+					ctx, groupAlias,
+				)
+				require.NoError(t, err)
+
+				groupIDs[string(groupAlias)] = groupID
+			}
+
+			return groupID
+		}
+
+		getFeatureID := func(featureName string) int64 {
+			featureID, ok := featureIDs[featureName]
+			if !ok {
+				featureID, err = sqlStore.GetFeatureID(
+					ctx, featureName,
+				)
+				require.NoError(t, err)
+
+				featureIDs[featureName] = featureID
+			}
+
+			return featureID
+		}
+
+		// First we extract all migrated kv entries from the SQLDB,
+		// in order to be able to compare them to the original kv
+		// entries, to ensure that the migration was successful.
+		sqlKvEntries, err := sqlStore.ListAllKVStoresRecords(ctx)
+		require.NoError(t, err)
+		require.Equal(t, len(kvEntries), len(sqlKvEntries))
+
+		// We then iterate over the original kv entries that were
+		// migrated from the BoltDB to the SQLDB, and assert that they
+		// match the migrated SQL kv entries.
+		// NOTE: when fetching kv entries that were inserted into the
+		// sql store with the entry value []byte{}, a nil value is
+		// returned. require.Equal would therefore error in such
+		// cases, while bytes.Equal would not, which is why the
+		// comparisons below use bytes.Equal to compare the values.
+		for _, entry := range kvEntries {
+			ruleID := getRuleID(entry.ruleName)
+
+			if entry.groupAlias.IsNone() {
+				sqlVal, err := sqlStore.GetGlobalKVStoreRecord(
+					ctx,
+					sqlc.GetGlobalKVStoreRecordParams{
+						Key:    entry.key,
+						Perm:   entry.perm,
+						RuleID: ruleID,
+					},
+				)
+				require.NoError(t, err)
+				// See docs for the loop above on why
+				// bytes.Equal is used here.
+				require.True(
+					t, bytes.Equal(entry.value, sqlVal),
+				)
+			} else if entry.featureName.IsNone() {
+				groupAlias := entry.groupAlias.UnwrapOrFail(t)
+				groupID := getGroupID(groupAlias[:])
+
+				v, err := sqlStore.GetGroupKVStoreRecord(
+					ctx,
+					sqlc.GetGroupKVStoreRecordParams{
+						Key:    entry.key,
+						Perm:   entry.perm,
+						RuleID: ruleID,
+						GroupID: sql.NullInt64{
+							Int64: groupID,
+							Valid: true,
+						},
+					},
+				)
+				require.NoError(t, err)
+				// See docs for the loop above on why
+				// bytes.Equal is used here.
+				require.True(
+					t, bytes.Equal(entry.value, v),
+				)
+			} else {
+				groupAlias := entry.groupAlias.UnwrapOrFail(t)
+				groupID := getGroupID(groupAlias[:])
+				featureID := getFeatureID(
+					entry.featureName.UnwrapOrFail(t),
+				)
+
+				sqlVal, err := sqlStore.GetFeatureKVStoreRecord(
+					ctx,
+					sqlc.GetFeatureKVStoreRecordParams{
+						Key:    entry.key,
+						Perm:   entry.perm,
+						RuleID: ruleID,
+						GroupID: sql.NullInt64{
+							Int64: groupID,
+							Valid: true,
+						},
+						FeatureID: sql.NullInt64{
+							Int64: featureID,
+							Valid: true,
+						},
+					},
+				)
+				require.NoError(t, err)
+				// See docs for the loop above on why
+				// bytes.Equal is used here.
+				require.True(
+					t, bytes.Equal(entry.value, sqlVal),
+				)
+			}
+		}
+	}
+
+	// The tests slice contains all the tests that we will run for the
+	// migration of the firewalldb from a BoltDB to a SQLDB.
+ // Note that the tests currently only test the migration of the KV + // stores, but will be extended in the future to also test the migration + // of the privacy mapper and the actions store. + tests := []struct { + name string + populateDB func(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []*kvEntry + }{ + { + name: "empty", + populateDB: func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + // Don't populate the DB. + return make([]*kvEntry, 0) + }, + }, + { + name: "global entries", + populateDB: globalEntries, + }, + { + name: "session specific entries", + populateDB: sessionSpecificEntries, + }, + { + name: "feature specific entries", + populateDB: featureSpecificEntries, + }, + { + name: "all entry combinations", + populateDB: allEntryCombinations, + }, + { + name: "random entries", + populateDB: randomKVEntries, + }, + } + + for _, test := range tests { + tc := test + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // First let's create a sessions store to link to in + // the kvstores DB. In order to create the sessions + // store though, we also need to create an accounts + // store, that we link to the sessions store. + // Note that both of these stores will be sql stores due + // to the build tags enabled when running this test, + // which means we can also pass the sessions store to + // the sql version of the kv stores that we'll create + // in test, without also needing to migrate it. + accountStore := accounts.NewTestDB(t, clock) + sessionsStore := session.NewTestDBWithAccounts( + t, clock, accountStore, + ) + + // Create a new firewall store to populate with test + // data. + firewallStore, err := NewBoltDB( + t.TempDir(), DBFilename, sessionsStore, + accountStore, clock, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, firewallStore.Close()) + }) + + // Populate the kv store. + entries := test.populateDB( + t, ctx, firewallStore, sessionsStore, + ) + + // Create the SQL store that we will migrate the data + // to. + sqlStore, txEx := makeSQLDB(t, sessionsStore) + + // Perform the migration. + var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, + func(tx SQLQueries) error { + return MigrateFirewallDBToSQL( + ctx, firewallStore.DB, tx, + ) + }, + ) + require.NoError(t, err) + + // Assert migration results. + assertMigrationResults(t, sqlStore, entries) + }) + } +} + +// globalEntries populates the kv store with one global entry for the temp +// store, and one for the perm store. +func globalEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + _ session.Store) []*kvEntry { + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, fn.None[[]byte](), + fn.None[string](), testEntryKey, testEntryValue, + ) +} + +// sessionSpecificEntries populates the kv store with one session specific +// entry for the local temp store, and one session specific entry for the perm +// local store. +func sessionSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + groupAlias := getNewSessionAlias(t, ctx, sessionStore) + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, groupAlias, fn.None[string](), + testEntryKey, testEntryValue, + ) +} + +// featureSpecificEntries populates the kv store with one feature specific +// entry for the local temp store, and one feature specific entry for the perm +// local store. 
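+// Note that feature-specific entries also require a session group alias, as
+// the feature level is nested under the session level in the kv stores (see
+// the sanity check in insertKvEntry below).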
+func featureSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB,
+	sessionStore session.Store) []*kvEntry {
+
+	groupAlias := getNewSessionAlias(t, ctx, sessionStore)
+
+	return insertTempAndPermEntry(
+		t, ctx, boltDB, testRuleName, groupAlias,
+		fn.Some(testFeatureName), testEntryKey, testEntryValue,
+	)
+}
+
+// allEntryCombinations adds all types of different entries at all possible
+// levels of the kvstores, including multiple entries with the same
+// ruleName, groupAlias and featureName. The test aims to cover all possible
+// combinations of entries in the kvstores, including nil and empty entry
+// values. This ensures that the migration doesn't overwrite or miss any
+// entries when the entry set is more complex than just a single entry at
+// each level.
+func allEntryCombinations(t *testing.T, ctx context.Context, boltDB *BoltDB,
+	sessionStore session.Store) []*kvEntry {
+
+	var result []*kvEntry
+	add := func(entry []*kvEntry) {
+		result = append(result, entry...)
+	}
+
+	// First, let's create standard entries at all levels, representing
+	// the entries added by other tests.
+	add(globalEntries(t, ctx, boltDB, sessionStore))
+	add(sessionSpecificEntries(t, ctx, boltDB, sessionStore))
+	add(featureSpecificEntries(t, ctx, boltDB, sessionStore))
+
+	groupAlias := getNewSessionAlias(t, ctx, sessionStore)
+
+	// Now let's add a few more entries with different rule names and
+	// features, just to ensure that we cover entries in different rule and
+	// feature tables.
+	add(insertTempAndPermEntry(
+		t, ctx, boltDB, testRuleName2, fn.None[[]byte](),
+		fn.None[string](), testEntryKey, testEntryValue,
+	))
+	add(insertTempAndPermEntry(
+		t, ctx, boltDB, testRuleName2, groupAlias,
+		fn.None[string](), testEntryKey, testEntryValue,
+	))
+	add(insertTempAndPermEntry(
+		t, ctx, boltDB, testRuleName2, groupAlias,
+		fn.Some(testFeatureName), testEntryKey, testEntryValue,
+	))
+	// Let's also create an entry with a different feature name that's
+	// still referencing the same group ID as the previous entry.
+	add(insertTempAndPermEntry(
+		t, ctx, boltDB, testRuleName2, groupAlias,
+		fn.Some(testFeatureName2), testEntryKey, testEntryValue,
+	))
+
+	// Finally, let's add a few entries with nil and empty values set as
+	// the actual key value, at all the different levels, to ensure that
+	// the migration doesn't break when the value is nil or empty.
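+	// (Note that require.Equal treats a nil byte slice and an empty one
+	// as different values, while bytes.Equal treats them as equal, which
+	// is why the migration assertions above compare values with
+	// bytes.Equal.)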
+ var ( + nilValue []byte = nil + nilSliceValue = []byte(nil) + emptyValue = []byte{} + ) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey4, emptyValue, + )) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey4, emptyValue, + )) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey4, emptyValue, + )) + + return result +} + +func getNewSessionAlias(t *testing.T, ctx context.Context, + sessionStore session.Store) fn.Option[[]byte] { + + sess, err := sessionStore.NewSession( + ctx, "test", session.TypeAutopilot, + time.Unix(1000, 0), "something", + ) + require.NoError(t, err) + + return fn.Some(sess.GroupID[:]) +} + +// insertTempAndPermEntry populates the kv store with one entry for the temp +// store, and one entry for the perm store. Both of the entries will be inserted +// with the same groupAlias, ruleName, entryKey and entryValue. +func insertTempAndPermEntry(t *testing.T, ctx context.Context, + boltDB *BoltDB, ruleName string, groupAlias fn.Option[[]byte], + featureNameOpt fn.Option[string], entryKey string, + entryValue []byte) []*kvEntry { + + tempKvEntry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAlias, + featureName: featureNameOpt, + key: entryKey, + value: entryValue, + perm: false, + } + + insertKvEntry(t, ctx, boltDB, tempKvEntry) + + permKvEntry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAlias, + featureName: featureNameOpt, + key: entryKey, + value: entryValue, + perm: true, + } + + insertKvEntry(t, ctx, boltDB, permKvEntry) + + return []*kvEntry{tempKvEntry, permKvEntry} +} + +// insertKvEntry populates the kv store with passed entry, and asserts that the +// entry is inserted correctly. +func insertKvEntry(t *testing.T, ctx context.Context, + boltDB *BoltDB, entry *kvEntry) { + + if entry.groupAlias.IsNone() && entry.featureName.IsSome() { + t.Fatalf("cannot set both global and feature specific at the " + + "same time") + } + + // We get the kv stores that the entry will be inserted into. Note that + // we set an empty group ID if the entry is global, as the group ID + // will not be used when fetching the actual kv store that's used for + // global entries. 
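+	// A group alias is the raw bytes of a 4-byte session.ID, which is
+	// why it is copied into a [4]byte array here.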
+ groupID := [4]byte{} + if entry.groupAlias.IsSome() { + copy(groupID[:], entry.groupAlias.UnwrapOrFail(t)) + } + + kvStores := boltDB.GetKVStores( + entry.ruleName, groupID, entry.featureName.UnwrapOr(""), + ) + + err := kvStores.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + + store := tx.Global() + + switch { + case entry.groupAlias.IsNone() && !entry.perm: + store = tx.GlobalTemp() + case entry.groupAlias.IsSome() && !entry.perm: + store = tx.LocalTemp() + case entry.groupAlias.IsSome() && entry.perm: + store = tx.Local() + } + + return store.Set(ctx, entry.key, entry.value) + }) + require.NoError(t, err) +} + +// randomKVEntries populates the kv store with random kv entries that span +// across all possible combinations of different levels of entries in the kv +// store. All values and different bucket names are randomly generated. +func randomKVEntries(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []*kvEntry { + + var ( + // We set the number of entries to insert to 1000, as that + // should be enough to cover as many different + // combinations of entries as possible, while still being + // fast enough to run in a reasonable time. + numberOfEntries = 1000 + insertedEntries = make([]*kvEntry, 0) + ruleName = "initial-rule" + groupAlias []byte + featureName = "initial-feature" + ) + + // Create a random session that we can reference for the initial group + // ID. + sess, err := sessionStore.NewSession( + ctx, "initial-session", session.Type(1), time.Unix(1000, 0), + "serverAddr.test", + ) + require.NoError(t, err) + + groupAlias = sess.GroupID[:] + + // Generate random entries. Note that many entries will use the same + // rule name, group ID and feature name, to simulate the real world + // usage of the kv stores as much as possible. + for i := 0; i < numberOfEntries; i++ { + // On average, we will generate a new rule which will be used + // for the kv store entry 10% of the time. + if rand.Intn(10) == 0 { + ruleName = fmt.Sprintf( + "rule-%s-%d", randomString(rand.Intn(30)+1), i, + ) + } + + // On average, we use the global store 25% of the time. + global := rand.Intn(4) == 0 + + // We'll use the perm store 50% of the time. + perm := rand.Intn(2) == 0 + + // For the non-global entries, we will generate a new group + // alias 25% of the time. + if !global && rand.Intn(4) == 0 { + newSess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), + randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupAlias = newSess.GroupID[:] + } + + featureNameOpt := fn.None[string]() + + // For 50% of the non-global entries, we insert a feature + // specific entry. The other 50% will be session specific + // entries. + if !global && rand.Intn(2) == 0 { + // 25% of the time, we will generate a new feature name. + if rand.Intn(4) == 0 { + featureName = fmt.Sprintf( + "feature-%s-%d", + randomString(rand.Intn(30)+1), i, + ) + } + + featureNameOpt = fn.Some(featureName) + } + + groupAliasOpt := fn.None[[]byte]() + if !global { + // If the entry is not global, we set the group ID + // to the latest session's group ID. + groupAliasOpt = fn.Some(groupAlias[:]) + } + + entry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAliasOpt, + featureName: featureNameOpt, + key: fmt.Sprintf("key-%d", i), + perm: perm, + } + + // When setting a value for the entry, 25% of the time, we will + // set a nil or empty value. 
+		if rand.Intn(4) == 0 {
+			// In 50% of these cases, we set the value to nil, and
+			// in the other 50% we set it to an empty value.
+			if rand.Intn(2) == 0 {
+				entry.value = nil
+			} else {
+				entry.value = []byte{}
+			}
+		} else {
+			// Otherwise, we generate a random value for the entry.
+			entry.value = []byte(randomString(rand.Intn(100) + 1))
+		}
+
+		// Insert the entry into the kv store.
+		insertKvEntry(t, ctx, boltDB, entry)
+
+		// Add the entry to the list of inserted entries.
+		insertedEntries = append(insertedEntries, entry)
+	}
+
+	return insertedEntries
+}
+
+// randomString generates a random string of the passed length n.
+func randomString(n int) string {
+	letterBytes := "abcdefghijklmnopqrstuvwxyz"
+
+	b := make([]byte, n)
+	for i := range b {
+		b[i] = letterBytes[rand.Intn(len(letterBytes))]
+	}
+	return string(b)
+}

From 3d3960da39679f015b10091c66a9b0f9b736d3b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Wed, 4 Jun 2025 19:17:45 +0200
Subject: [PATCH 08/25] mod: go get sqldb/v2

---
 go.mod | 12 +++++++++---
 go.sum | 12 ++++++------
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/go.mod b/go.mod
index c0a0c05c5..cdcb07912 100644
--- a/go.mod
+++ b/go.mod
@@ -34,14 +34,15 @@ require (
 	github.com/lightninglabs/pool/poolrpc v1.0.1
 	github.com/lightninglabs/taproot-assets v0.6.0-rc3
 	github.com/lightninglabs/taproot-assets/taprpc v1.0.6
-	github.com/lightningnetwork/lnd v0.19.1-beta.rc1
+	github.com/lightningnetwork/lnd v0.19.1-beta
 	github.com/lightningnetwork/lnd/cert v1.2.2
 	github.com/lightningnetwork/lnd/clock v1.1.1
 	github.com/lightningnetwork/lnd/fn v1.2.3
 	github.com/lightningnetwork/lnd/fn/v2 v2.0.8
 	github.com/lightningnetwork/lnd/kvdb v1.4.16
-	github.com/lightningnetwork/lnd/sqldb v1.0.9
-	github.com/lightningnetwork/lnd/tlv v1.3.1
+	github.com/lightningnetwork/lnd/sqldb v1.0.10
+	github.com/lightningnetwork/lnd/sqldb/v2 v2.0.0-00010101000000-000000000000
+	github.com/lightningnetwork/lnd/tlv v1.3.2
 	github.com/lightningnetwork/lnd/tor v1.1.6
 	github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f
 	github.com/mwitkow/grpc-proxy v0.0.0-20230212185441-f345521cb9c9
@@ -247,3 +248,8 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d
 // it is a replace in the tapd repository, it doesn't get propagated here
 // automatically, so we need to add it manually.
replace github.com/golang-migrate/migrate/v4 => github.com/lightninglabs/migrate/v4 v4.18.2-9023d66a-fork-pr-2 + +replace github.com/lightningnetwork/lnd => github.com/ViktorTigerstrom/lnd v0.0.0-20250710121612-a88fb038013b + +// TODO: replace this with your own local fork +replace github.com/lightningnetwork/lnd/sqldb/v2 => ../../lnd_forked/lnd/sqldb diff --git a/go.sum b/go.sum index 36adbb9e0..df0a58695 100644 --- a/go.sum +++ b/go.sum @@ -616,6 +616,8 @@ github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82/go.mod h1:Gbu github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/ViktorTigerstrom/lnd v0.0.0-20250710121612-a88fb038013b h1:2UbeVKTtFvxxtHTzHOmQJvBeQ87eTdYMrAs4O8ZuIi8= +github.com/ViktorTigerstrom/lnd v0.0.0-20250710121612-a88fb038013b/go.mod h1:FLpPqYTU7KKJ87mNsqp/DsJwmpenK0zhuLBO+beWCl4= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344 h1:cDVUiFo+npB0ZASqnw4q90ylaVAbnYyx0JYqK4YcGok= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344/go.mod h1:9pIqrY6SXNL8vjRQE5Hd/OL5GyK/9MrGUWs87z/eFfk= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= @@ -1178,8 +1180,6 @@ github.com/lightninglabs/taproot-assets/taprpc v1.0.6 h1:h8tf4y7U5/3A9WNAs7HBTL8 github.com/lightninglabs/taproot-assets/taprpc v1.0.6/go.mod h1:vOM2Ap2wYhEZjiJU7bNNg+e5tDxkvRAuyXwf/KQ4tgo= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb h1:yfM05S8DXKhuCBp5qSMZdtSwvJ+GFzl94KbXMNB1JDY= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb/go.mod h1:c0kvRShutpj3l6B9WtTsNTBUtjSmjZXbJd9ZBRQOSKI= -github.com/lightningnetwork/lnd v0.19.1-beta.rc1 h1:VV7xpS1g7OpDhlGYIa50Ac4I7TDZhvHjZ2Mmf+L4a7Y= -github.com/lightningnetwork/lnd v0.19.1-beta.rc1/go.mod h1:iHZ/FHFK00BqV6qgDkZZfqWE3LGtgE0U5KdO5WrM+eQ= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= @@ -1194,12 +1194,12 @@ github.com/lightningnetwork/lnd/kvdb v1.4.16 h1:9BZgWdDfjmHRHLS97cz39bVuBAqMc4/p github.com/lightningnetwork/lnd/kvdb v1.4.16/go.mod h1:HW+bvwkxNaopkz3oIgBV6NEnV4jCEZCACFUcNg4xSjM= github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI= github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4= -github.com/lightningnetwork/lnd/sqldb v1.0.9 h1:7OHi+Hui823mB/U9NzCdlZTAGSVdDCbjp33+6d/Q+G0= -github.com/lightningnetwork/lnd/sqldb v1.0.9/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4= +github.com/lightningnetwork/lnd/sqldb v1.0.10 h1:ZLV7TGwjnKupVfCd+DJ43MAc9BKVSFCnvhpSPGKdN3M= +github.com/lightningnetwork/lnd/sqldb v1.0.10/go.mod h1:c/vWoQfcxu6FAfHzGajkIQi7CEIeIZFhhH4DYh1BJpc= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= -github.com/lightningnetwork/lnd/tlv v1.3.1 h1:o7CZg06y+rJZfUMAo0WzBLr0pgBWCzrt0f9gpujYUzk= -github.com/lightningnetwork/lnd/tlv v1.3.1/go.mod 
h1:pJuiBj1ecr1WWLOtcZ+2+hu9Ey25aJWFIsjmAoPPnmc=
+github.com/lightningnetwork/lnd/tlv v1.3.2 h1:MO4FCk7F4k5xPMqVZF6Nb/kOpxlwPrUQpYjmyKny5s0=
+github.com/lightningnetwork/lnd/tlv v1.3.2/go.mod h1:pJuiBj1ecr1WWLOtcZ+2+hu9Ey25aJWFIsjmAoPPnmc=
 github.com/lightningnetwork/lnd/tor v1.1.6 h1:WHUumk7WgU6BUFsqHuqszI9P6nfhMeIG+rjJBlVE6OE=
 github.com/lightningnetwork/lnd/tor v1.1.6/go.mod h1:qSRB8llhAK+a6kaTPWOLLXSZc6Hg8ZC0mq1sUQ/8JfI=
 github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw=

From 2820392ebd593c67ef21737bd036502965201524 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Tue, 10 Jun 2025 17:48:09 +0200
Subject: [PATCH 09/25] multi: use sqldb v2 in litd

This commit updates litd to use the new sqldb v2 package. Note that
with just this commit, litd will not utilize the capabilities of sqldb
v2 to run specific post-migration steps (such as migrating the kvdb to
SQL). That functionality will be added in later commits. Instead, this
commit just focuses on adding support for the new sqldb v2 package; the
functionality of the SQL stores is expected to remain the same as prior
to this commit.
---
 accounts/sql_migration_test.go | 17 +++------
 accounts/store_sql.go          | 66 ++++++++++++++++++++++-----------
 accounts/test_sql.go           | 11 ++++--
 accounts/test_sqlite.go        | 12 ++++--
 config_dev.go                  | 67 ++++++++++++++++++++++++++++++----
 db/interfaces.go               |  7 ++--
 db/postgres.go                 | 21 +++--------
 db/sql_migrations.go           | 30 +++++++++++++++
 db/sqlc/db_custom.go           | 34 ++++-------------
 db/sqlite.go                   | 14 ++-----
 firewalldb/actions_sql.go      |  8 ++--
 firewalldb/kvstores_sql.go     |  3 +-
 firewalldb/sql_store.go        | 47 ++++++++++++++++-------
 firewalldb/test_sql.go         | 10 +++--
 firewalldb/test_sqlite.go      | 16 +++++---
 session/sql_store.go           | 64 +++++++++++++++++++++-----------
 session/test_sql.go            | 11 ++++--
 session/test_sqlite.go         | 12 ++++--
 18 files changed, 293 insertions(+), 157 deletions(-)
 create mode 100644 db/sql_migrations.go

diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go
index bfa508df5..382d23661 100644
--- a/accounts/sql_migration_test.go
+++ b/accounts/sql_migration_test.go
@@ -2,18 +2,17 @@ package accounts
 
 import (
 	"context"
-	"database/sql"
 	"fmt"
 	"testing"
 	"time"
 
-	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightninglabs/lightning-terminal/db/sqlc"
 	"github.com/lightningnetwork/lnd/clock"
 	"github.com/lightningnetwork/lnd/fn"
 	"github.com/lightningnetwork/lnd/lnrpc"
 	"github.com/lightningnetwork/lnd/lntypes"
 	"github.com/lightningnetwork/lnd/lnwire"
-	"github.com/lightningnetwork/lnd/sqldb"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/exp/rand"
 	"pgregory.net/rapid"
@@ -36,7 +35,7 @@ func TestAccountStoreMigration(t *testing.T) {
 	}
 
 	makeSQLDB := func(t *testing.T) (*SQLStore,
-		*db.TransactionExecutor[SQLQueries]) {
+		*SQLQueriesExecutor[SQLQueries]) {
 
 		testDBStore := NewTestDB(t, clock)
 
@@ -45,13 +44,9 @@ func TestAccountStoreMigration(t *testing.T) {
 
 		baseDB := store.BaseDB
 
-		genericExecutor := db.NewTransactionExecutor(
-			baseDB, func(tx *sql.Tx) SQLQueries {
-				return baseDB.WithTx(tx)
-			},
-		)
+		queries := sqlc.NewForType(baseDB, baseDB.BackendType)
 
-		return store, genericExecutor
+		return store, NewSQLQueriesExecutor(baseDB, queries)
 	}
 
 	assertMigrationResults := func(t *testing.T, sqlStore *SQLStore,
@@ -343,7 +338,7 @@ func TestAccountStoreMigration(t *testing.T) {
 					return MigrateAccountStoreToSQL(
 						ctx, kvStore.db, tx,
 					)
-				},
+				}, 
sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/accounts/store_sql.go b/accounts/store_sql.go index 830f16587..c7e8ab070 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -33,6 +34,8 @@ const ( // //nolint:lll type SQLQueries interface { + sqldb.BaseQuerier + AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error @@ -53,12 +56,13 @@ type SQLQueries interface { GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error) } -// BatchedSQLQueries is a version of the SQLQueries that's capable -// of batched database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -68,19 +72,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. 
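+// The queries instance is created once per database connection and is shared
+// by the other litd stores that operate on the same database (see NewStores
+// in config_dev.go).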
+func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -157,7 +179,7 @@ func (s *SQLStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -299,7 +321,7 @@ func (s *SQLStore) AddAccountInvoice(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, acctID) - }) + }, sqldb.NoOpReset) } func getAccountIDByAlias(ctx context.Context, db SQLQueries, alias AccountID) ( @@ -377,7 +399,7 @@ func (s *SQLStore) UpdateAccountBalanceAndExpiry(ctx context.Context, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // CreditAccount increases the balance of the account with the given alias by @@ -412,7 +434,7 @@ func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DebitAccount decreases the balance of the account with the given alias by the @@ -453,7 +475,7 @@ func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // Account retrieves an account from the SQL store and un-marshals it. If the @@ -475,7 +497,7 @@ func (s *SQLStore) Account(ctx context.Context, alias AccountID) ( account, err = getAndMarshalAccount(ctx, db, id) return err - }) + }, sqldb.NoOpReset) return account, err } @@ -507,7 +529,7 @@ func (s *SQLStore) Accounts(ctx context.Context) ([]*OffChainBalanceAccount, } return nil - }) + }, sqldb.NoOpReset) return accounts, err } @@ -524,7 +546,7 @@ func (s *SQLStore) RemoveAccount(ctx context.Context, alias AccountID) error { } return db.DeleteAccount(ctx, id) - }) + }, sqldb.NoOpReset) } // UpsertAccountPayment updates or inserts a payment entry for the given @@ -634,7 +656,7 @@ func (s *SQLStore) UpsertAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DeleteAccountPayment removes a payment entry from the account with the given @@ -677,7 +699,7 @@ func (s *SQLStore) DeleteAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // LastIndexes returns the last invoice add and settle index or @@ -704,7 +726,7 @@ func (s *SQLStore) LastIndexes(ctx context.Context) (uint64, uint64, error) { } return err - }) + }, sqldb.NoOpReset) return uint64(addIndex), uint64(settleIndex), err } @@ -729,7 +751,7 @@ func (s *SQLStore) StoreLastIndexes(ctx context.Context, addIndex, Name: settleIndexName, Value: int64(settleIndex), }) - }) + }, sqldb.NoOpReset) } // Close closes the underlying store. diff --git a/accounts/test_sql.go b/accounts/test_sql.go index 3c1ee7f16..ca2f43d6f 100644 --- a/accounts/test_sql.go +++ b/accounts/test_sql.go @@ -5,15 +5,20 @@ package accounts import ( "testing" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. 
-func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
-	store := NewSQLStore(sqlDB, clock)
+func createStore(t *testing.T, sqlDB *sqldb.BaseDB,
+	clock clock.Clock) *SQLStore {
+
+	queries := sqlc.NewForType(sqlDB, sqlDB.BackendType)
+
+	store := NewSQLStore(sqlDB, queries, clock)
 	t.Cleanup(func() {
 		require.NoError(t, store.Close())
 	})
diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go
index 9d899b3e2..a31f990a6 100644
--- a/accounts/test_sqlite.go
+++ b/accounts/test_sqlite.go
@@ -8,6 +8,7 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 // ErrDBClosed is an error that is returned when a database operation is
@@ -16,7 +17,10 @@ var ErrDBClosed = errors.New("database is closed")
 
 // NewTestDB is a helper function that creates an SQLStore database for testing.
 func NewTestDB(t *testing.T, clock clock.Clock) Store {
-	return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock)
+	return createStore(
+		t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB,
+		clock,
+	)
 }
 
 // NewTestDBFromPath is a helper function that creates a new SQLStore with a
@@ -24,7 +28,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store {
 func NewTestDBFromPath(t *testing.T, dbPath string,
 	clock clock.Clock) Store {
 
-	return createStore(
-		t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock,
-	)
+	tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams)
+
+	return createStore(t, tDb.BaseDB, clock)
 }
diff --git a/config_dev.go b/config_dev.go
index 90b8b290f..82bd85cf0 100644
--- a/config_dev.go
+++ b/config_dev.go
@@ -8,9 +8,11 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/accounts"
 	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightninglabs/lightning-terminal/db/sqlc"
 	"github.com/lightninglabs/lightning-terminal/firewalldb"
 	"github.com/lightninglabs/lightning-terminal/session"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 const (
@@ -101,14 +103,36 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
 		return stores, err
 	}
 
-	sqlStore, err := db.NewSqliteStore(cfg.Sqlite)
+	sqlStore, err := sqldb.NewSqliteStore(&sqldb.SqliteConfig{
+		SkipMigrations:        cfg.Sqlite.SkipMigrations,
+		SkipMigrationDbBackup: cfg.Sqlite.SkipMigrationDbBackup,
+	}, cfg.Sqlite.DatabaseFileName)
 	if err != nil {
 		return stores, err
 	}
 
-	acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock)
-	sessStore := session.NewSQLStore(sqlStore.BaseDB, clock)
-	firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock)
+	if !cfg.Sqlite.SkipMigrations {
+		err = sqldb.ApplyAllMigrations(
+			sqlStore, db.LitdMigrationStreams,
+		)
+		if err != nil {
+			return stores, fmt.Errorf("error applying "+
+				"migrations to SQLite store: %w", err,
+			)
+		}
+	}
+
+	queries := sqlc.NewForType(sqlStore, sqlStore.BackendType)
+
+	acctStore := accounts.NewSQLStore(
+		sqlStore.BaseDB, queries, clock,
+	)
+	sessStore := session.NewSQLStore(
+		sqlStore.BaseDB, queries, clock,
+	)
+	firewallStore := firewalldb.NewSQLDB(
+		sqlStore.BaseDB, queries, clock,
+	)
 
 	stores.accounts = acctStore
 	stores.sessions = sessStore
@@ -116,14 +140,41 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
 		stores.closeFns["sqlite"] = sqlStore.BaseDB.Close
 
 	case DatabaseBackendPostgres:
-		sqlStore, err := db.NewPostgresStore(cfg.Postgres)
+		sqlStore, err := sqldb.NewPostgresStore(&sqldb.PostgresConfig{
+			Dsn:                cfg.Postgres.DSN(false),
+			
MaxOpenConnections: cfg.Postgres.MaxOpenConnections, + MaxIdleConnections: cfg.Postgres.MaxIdleConnections, + ConnMaxLifetime: cfg.Postgres.ConnMaxLifetime, + ConnMaxIdleTime: cfg.Postgres.ConnMaxIdleTime, + RequireSSL: cfg.Postgres.RequireSSL, + SkipMigrations: cfg.Postgres.SkipMigrations, + }) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Postgres.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, db.LitdMigrationStreams, + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to Postgres store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore diff --git a/db/interfaces.go b/db/interfaces.go index ba64520b4..bb39df9ea 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -8,6 +8,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" ) var ( @@ -56,7 +57,7 @@ type BatchedTx[Q any] interface { txBody func(Q) error) error // Backend returns the type of the database backend used. - Backend() sqlc.BackendType + Backend() sqldb.BackendType } // Tx represents a database transaction that can be committed or rolled back. @@ -277,7 +278,7 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context, } // Backend returns the type of the database backend used. -func (t *TransactionExecutor[Q]) Backend() sqlc.BackendType { +func (t *TransactionExecutor[Q]) Backend() sqldb.BackendType { return t.BatchedQuerier.Backend() } @@ -301,7 +302,7 @@ func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) { } // Backend returns the type of the database backend used. -func (s *BaseDB) Backend() sqlc.BackendType { +func (s *BaseDB) Backend() sqldb.BackendType { return s.Queries.Backend() } diff --git a/db/postgres.go b/db/postgres.go index 16e41dc09..962629be6 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -9,6 +9,7 @@ import ( postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -119,7 +120,7 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { rawDb.SetConnMaxLifetime(connMaxLifetime) rawDb.SetConnMaxIdleTime(connMaxIdleTime) - queries := sqlc.NewPostgres(rawDb) + queries := sqlc.NewForType(rawDb, sqldb.BackendTypePostgres) s := &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -128,15 +129,6 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { }, } - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(TargetLatest); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - return s, nil } @@ -166,20 +158,17 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, // NewTestPostgresDB is a helper function that creates a Postgres database for // testing. 
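+// The underlying fixture is torn down automatically via t.Cleanup once the
+// test completes.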
-func NewTestPostgresDB(t *testing.T) *PostgresStore { +func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { t.Helper() t.Logf("Creating new Postgres DB for testing") - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - store, err := NewPostgresStore(sqlFixture.GetConfig()) - require.NoError(t, err) - + sqlFixture := sqldb.NewTestPgFixture(t, DefaultPostgresFixtureLifetime) t.Cleanup(func() { sqlFixture.TearDown(t) }) - return store + return sqldb.NewTestPostgresDB(t, sqlFixture, LitdMigrationStreams) } // NewTestPostgresDBWithVersion is a helper function that creates a Postgres diff --git a/db/sql_migrations.go b/db/sql_migrations.go new file mode 100644 index 000000000..57d283aa3 --- /dev/null +++ b/db/sql_migrations.go @@ -0,0 +1,30 @@ +package db + +import ( + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +var ( + LitdMigrationStream = sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable, + SQLFileDirectory: "sqlc/migrations", + Schemas: sqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for + // the daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion: LatestMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return make(map[uint]migrate.PostStepCallback), nil + }, + } + LitdMigrationStreams = []sqldb.MigrationStream{LitdMigrationStream} +) diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f4bf7f611..af556eae7 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -2,21 +2,8 @@ package sqlc import ( "context" -) - -// BackendType is an enum that represents the type of database backend we're -// using. -type BackendType uint8 - -const ( - // BackendTypeUnknown indicates we're using an unknown backend. - BackendTypeUnknown BackendType = iota - // BackendTypeSqlite indicates we're using a SQLite backend. - BackendTypeSqlite - - // BackendTypePostgres indicates we're using a Postgres backend. - BackendTypePostgres + "github.com/lightningnetwork/lnd/sqldb/v2" ) // wrappedTX is a wrapper around a DBTX that also stores the database backend @@ -24,29 +11,24 @@ const ( type wrappedTX struct { DBTX - backendType BackendType + backendType sqldb.BackendType } // Backend returns the type of database backend we're using. -func (q *Queries) Backend() BackendType { +func (q *Queries) Backend() sqldb.BackendType { wtx, ok := q.db.(*wrappedTX) if !ok { // Shouldn't happen unless a new database backend type is added // but not initialized correctly. - return BackendTypeUnknown + return sqldb.BackendTypeUnknown } return wtx.backendType } -// NewSqlite creates a new Queries instance for a SQLite database. -func NewSqlite(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypeSqlite}} -} - -// NewPostgres creates a new Queries instance for a Postgres database. -func NewPostgres(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypePostgres}} +// NewForType creates a new Queries instance for the given database type. 
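+// For example: queries := sqlc.NewForType(baseDB, sqldb.BackendTypeSqlite).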
+func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} } // CustomQueries defines a set of custom queries that we define in addition @@ -62,5 +44,5 @@ type CustomQueries interface { arg ListActionsParams) ([]Action, error) // Backend returns the type of the database backend used. - Backend() BackendType + Backend() sqldb.BackendType } diff --git a/db/sqlite.go b/db/sqlite.go index 803362fa8..6f69a7e5b 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -11,6 +11,7 @@ import ( "github.com/golang-migrate/migrate/v4" sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" _ "modernc.org/sqlite" // Register relevant drivers. ) @@ -132,7 +133,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(defaultConnMaxLifetime) - queries := sqlc.NewSqlite(db) + queries := sqlc.NewForType(db, sqldb.BackendTypeSqlite) s := &SqliteStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -140,16 +141,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { Queries: queries, }, } - - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(s.backupAndMigrate); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - + return s, nil } diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go index 75c9d0a6d..4d5448313 100644 --- a/firewalldb/actions_sql.go +++ b/firewalldb/actions_sql.go @@ -12,7 +12,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLAccountQueries is a subset of the sqlc.Queries interface that can be used @@ -167,7 +167,7 @@ func (s *SQLDB) AddAction(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, Valid: errReason != "", }, }) - }) + }, sqldb.NoOpReset) } // ListActions returns a list of Actions. 
The query IndexOffset and MaxNum @@ -350,7 +350,7 @@ func (s *SQLDB) ListActions(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) return actions, lastIndex, uint64(totalCount), err } diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 248892130..ce7714549 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -11,6 +11,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be @@ -45,7 +46,7 @@ func (s *SQLDB) DeleteTempKVStores(ctx context.Context) error { return s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error { return tx.DeleteAllTempKVStores(ctx) - }) + }, sqldb.NoOpReset) } // GetKVStores constructs a new rules.KVStores in a namespace defined by the diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index f17010f2c..1be887ace 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -5,7 +5,9 @@ import ( "database/sql" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLSessionQueries is a subset of the sqlc.Queries interface that can be used @@ -18,17 +20,20 @@ type SQLSessionQueries interface { // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { + sqldb.BaseQuerier + SQLKVStoreQueries SQLPrivacyPairQueries SQLActionQueries } -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLDB represents a storage backend. @@ -38,11 +43,31 @@ type SQLDB struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + // A compile-time assertion to ensure that SQLDB implements the RulesDB // interface. var _ RulesDB = (*SQLDB)(nil) @@ -53,12 +78,10 @@ var _ ActionDB = (*SQLDB)(nil) // NewSQLDB creates a new SQLStore instance given an open SQLQueries // storage backend. 
-func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) - }, - ) +func NewSQLDB(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLDB { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLDB{ db: executor, @@ -88,7 +111,7 @@ func (e *sqlExecutor[T]) Update(ctx context.Context, var txOpts db.QueriesTxOptions return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } // View opens a database read transaction and executes the function f with the @@ -104,5 +127,5 @@ func (e *sqlExecutor[T]) View(ctx context.Context, return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index a412441f8..b7e3d9052 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -6,10 +6,12 @@ import ( "testing" "time" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -55,8 +57,10 @@ func assertEqualActions(t *testing.T, expected, got *Action) { // createStore is a helper function that creates a new SQLDB and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - store := NewSQLDB(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, clock clock.Clock) *SQLDB { + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLDB(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 49b956d7d..ab184b5a6 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -7,17 +7,23 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // NewTestDB is a helper function that creates an BBolt database for testing. func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. 
-func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, - ) +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + + tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/session/sql_store.go b/session/sql_store.go index b1d366fe7..26662a574 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -14,6 +14,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" ) @@ -21,6 +22,8 @@ import ( // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with session related tables. type SQLQueries interface { + sqldb.BaseQuerier + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) GetSessionByID(ctx context.Context, id int64) (sqlc.Session, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlc.Session, error) @@ -51,12 +54,13 @@ type SQLQueries interface { var _ Store = (*SQLStore)(nil) -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -66,19 +70,37 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) }, ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. 
+func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -281,7 +303,7 @@ func (s *SQLStore) NewSession(ctx context.Context, label string, typ Type, } return nil - }) + }, sqldb.NoOpReset) if err != nil { mappedSQLErr := db.MapSQLError(err) var uniqueConstraintErr *db.ErrSqlUniqueConstraintViolation @@ -325,7 +347,7 @@ func (s *SQLStore) ListSessionsByType(ctx context.Context, t Type) ([]*Session, } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -358,7 +380,7 @@ func (s *SQLStore) ListSessionsByState(ctx context.Context, state State) ( } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -417,7 +439,7 @@ func (s *SQLStore) ShiftState(ctx context.Context, alias ID, dest State) error { State: int16(dest), }, ) - }) + }, sqldb.NoOpReset) } // DeleteReservedSessions deletes all sessions that are in the StateReserved @@ -428,7 +450,7 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error { var writeTxOpts db.QueriesTxOptions return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { return db.DeleteSessionsWithState(ctx, int16(StateReserved)) - }) + }, sqldb.NoOpReset) } // GetSessionByLocalPub fetches the session with the given local pub key. @@ -458,7 +480,7 @@ func (s *SQLStore) GetSessionByLocalPub(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -491,7 +513,7 @@ func (s *SQLStore) ListAllSessions(ctx context.Context) ([]*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -521,7 +543,7 @@ func (s *SQLStore) UpdateSessionRemotePubKey(ctx context.Context, alias ID, RemotePublicKey: remoteKey, }, ) - }) + }, sqldb.NoOpReset) } // getSqlUnusedAliasAndKeyPair can be used to generate a new, unused, local @@ -576,7 +598,7 @@ func (s *SQLStore) GetSession(ctx context.Context, alias ID) (*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sess, err } @@ -617,7 +639,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) { legacyGroupID, err = IDFromBytes(legacyGroupIDB) return err - }) + }, sqldb.NoOpReset) if err != nil { return ID{}, err } @@ -666,7 +688,7 @@ func (s *SQLStore) GetSessionIDs(ctx context.Context, legacyGroupID ID) ([]ID, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } diff --git a/session/test_sql.go b/session/test_sql.go index a83186069..5623c8207 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -22,8 +23,12 @@ func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. 
-func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
-	store := NewSQLStore(sqlDB, clock)
+func createStore(t *testing.T, sqlDB *sqldb.BaseDB,
+	clock clock.Clock) *SQLStore {
+
+	queries := sqlc.NewForType(sqlDB, sqlDB.BackendType)
+
+	store := NewSQLStore(sqlDB, queries, clock)
 	t.Cleanup(func() {
 		require.NoError(t, store.Close())
 	})
diff --git a/session/test_sqlite.go b/session/test_sqlite.go
index 0ceb0e046..84d946ce2 100644
--- a/session/test_sqlite.go
+++ b/session/test_sqlite.go
@@ -8,6 +8,7 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 // ErrDBClosed is an error that is returned when a database operation is
@@ -16,7 +17,10 @@ var ErrDBClosed = errors.New("database is closed")
 
 // NewTestDB is a helper function that creates an SQLStore database for testing.
 func NewTestDB(t *testing.T, clock clock.Clock) Store {
-	return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock)
+	return createStore(
+		t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB,
+		clock,
+	)
 }
 
 // NewTestDBFromPath is a helper function that creates a new SQLStore with a
@@ -24,7 +28,7 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store {
 func NewTestDBFromPath(t *testing.T, dbPath string,
 	clock clock.Clock) Store {
 
-	return createStore(
-		t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock,
-	)
+	tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams)
+
+	return createStore(t, tDb.BaseDB, clock)
 }

From b8c772b569c0154a8a7e4c8f318b920731e8121a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Fri, 11 Jul 2025 01:58:22 +0200
Subject: [PATCH 10/25] firewalldb: add `ListAllKVStoresRecords` to queries

---
 firewalldb/kvstores_sql.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go
index ce7714549..0c1847706 100644
--- a/firewalldb/kvstores_sql.go
+++ b/firewalldb/kvstores_sql.go
@@ -31,6 +31,7 @@ type SQLKVStoreQueries interface {
 	UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlc.UpdateGlobalKVStoreRecordParams) error
 	UpdateGroupKVStoreRecord(ctx context.Context, arg sqlc.UpdateGroupKVStoreRecordParams) error
 	InsertKVStoreRecord(ctx context.Context, arg sqlc.InsertKVStoreRecordParams) error
+	ListAllKVStoresRecords(ctx context.Context) ([]sqlc.Kvstore, error)
 	DeleteAllTempKVStores(ctx context.Context) error
 	GetOrInsertFeatureID(ctx context.Context, name string) (int64, error)
 	GetOrInsertRuleID(ctx context.Context, name string) (int64, error)

From 2bb06d5b27496a0162280a0b07e04013eb88d693 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Fri, 11 Jul 2025 12:17:17 +0200
Subject: [PATCH 11/25] firewalldb: rename sqlStore to store in mig test

Rename `sqlStore` to `store` in the firewalldb sql migration test file,
to make the name shorter. This is done in preparation for future
commits which will lengthen the lines where `sqlStore` is used, and
which would otherwise make the lines exceed the 80 character limit.
--- firewalldb/sql_migration_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go index 1298e3e53..787dc8051 100644 --- a/firewalldb/sql_migration_test.go +++ b/firewalldb/sql_migration_test.go @@ -75,7 +75,7 @@ func TestFirewallDBMigration(t *testing.T) { // The assertMigrationResults function will currently assert that // the migrated kv stores entries in the SQLDB match the original kv // stores entries in the BoltDB. - assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + assertMigrationResults := func(t *testing.T, store *SQLDB, kvEntries []*kvEntry) { var ( @@ -88,7 +88,7 @@ func TestFirewallDBMigration(t *testing.T) { getRuleID := func(ruleName string) int64 { ruleID, ok := ruleIDs[ruleName] if !ok { - ruleID, err = sqlStore.GetRuleID(ctx, ruleName) + ruleID, err = store.GetRuleID(ctx, ruleName) require.NoError(t, err) ruleIDs[ruleName] = ruleID @@ -100,7 +100,7 @@ func TestFirewallDBMigration(t *testing.T) { getGroupID := func(groupAlias []byte) int64 { groupID, ok := groupIDs[string(groupAlias)] if !ok { - groupID, err = sqlStore.GetSessionIDByAlias( + groupID, err = store.GetSessionIDByAlias( ctx, groupAlias, ) require.NoError(t, err) @@ -114,7 +114,7 @@ func TestFirewallDBMigration(t *testing.T) { getFeatureID := func(featureName string) int64 { featureID, ok := featureIDs[featureName] if !ok { - featureID, err = sqlStore.GetFeatureID( + featureID, err = store.GetFeatureID( ctx, featureName, ) require.NoError(t, err) @@ -128,7 +128,7 @@ func TestFirewallDBMigration(t *testing.T) { // First we extract all migrated kv entries from the SQLDB, // in order to be able to compare them to the original kv // entries, to ensure that the migration was successful. 
- sqlKvEntries, err := sqlStore.ListAllKVStoresRecords(ctx) + sqlKvEntries, err := store.ListAllKVStoresRecords(ctx) require.NoError(t, err) require.Equal(t, len(kvEntries), len(sqlKvEntries)) @@ -144,7 +144,7 @@ func TestFirewallDBMigration(t *testing.T) { ruleID := getRuleID(entry.ruleName) if entry.groupAlias.IsNone() { - sqlVal, err := sqlStore.GetGlobalKVStoreRecord( + sqlVal, err := store.GetGlobalKVStoreRecord( ctx, sqlc.GetGlobalKVStoreRecordParams{ Key: entry.key, @@ -162,7 +162,7 @@ func TestFirewallDBMigration(t *testing.T) { groupAlias := entry.groupAlias.UnwrapOrFail(t) groupID := getGroupID(groupAlias[:]) - v, err := sqlStore.GetGroupKVStoreRecord( + v, err := store.GetGroupKVStoreRecord( ctx, sqlc.GetGroupKVStoreRecordParams{ Key: entry.key, @@ -187,7 +187,7 @@ func TestFirewallDBMigration(t *testing.T) { entry.featureName.UnwrapOrFail(t), ) - sqlVal, err := sqlStore.GetFeatureKVStoreRecord( + sqlVal, err := store.GetFeatureKVStoreRecord( ctx, sqlc.GetFeatureKVStoreRecordParams{ Key: entry.key, From efe82a6da6bf1a56bf469b35ba41519e886d9401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 12:21:00 +0200 Subject: [PATCH 12/25] firewalldb: use sqldb/v2 in the firewalldb mig test --- firewalldb/sql_migration_test.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go index 787dc8051..2dd3604e7 100644 --- a/firewalldb/sql_migration_test.go +++ b/firewalldb/sql_migration_test.go @@ -9,12 +9,11 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "golang.org/x/exp/rand" ) @@ -54,7 +53,7 @@ func TestFirewallDBMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, - *db.TransactionExecutor[SQLQueries]) { + *SQLQueriesExecutor[SQLQueries]) { testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) @@ -63,13 +62,9 @@ func TestFirewallDBMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlc.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLQueriesExecutor(baseDB, queries) } // The assertMigrationResults function will currently assert that @@ -88,7 +83,9 @@ func TestFirewallDBMigration(t *testing.T) { getRuleID := func(ruleName string) int64 { ruleID, ok := ruleIDs[ruleName] if !ok { - ruleID, err = store.GetRuleID(ctx, ruleName) + ruleID, err = store.db.GetRuleID( + ctx, ruleName, + ) require.NoError(t, err) ruleIDs[ruleName] = ruleID @@ -100,7 +97,7 @@ func TestFirewallDBMigration(t *testing.T) { getGroupID := func(groupAlias []byte) int64 { groupID, ok := groupIDs[string(groupAlias)] if !ok { - groupID, err = store.GetSessionIDByAlias( + groupID, err = store.db.GetSessionIDByAlias( ctx, groupAlias, ) require.NoError(t, err) @@ -114,7 +111,7 @@ func TestFirewallDBMigration(t *testing.T) { getFeatureID := func(featureName string) int64 { featureID, ok := featureIDs[featureName] if !ok { - featureID, err = store.GetFeatureID( + 
featureID, err = store.db.GetFeatureID( ctx, featureName, ) require.NoError(t, err) @@ -128,7 +125,7 @@ func TestFirewallDBMigration(t *testing.T) { // First we extract all migrated kv entries from the SQLDB, // in order to be able to compare them to the original kv // entries, to ensure that the migration was successful. - sqlKvEntries, err := store.ListAllKVStoresRecords(ctx) + sqlKvEntries, err := store.db.ListAllKVStoresRecords(ctx) require.NoError(t, err) require.Equal(t, len(kvEntries), len(sqlKvEntries)) @@ -144,7 +141,7 @@ func TestFirewallDBMigration(t *testing.T) { ruleID := getRuleID(entry.ruleName) if entry.groupAlias.IsNone() { - sqlVal, err := store.GetGlobalKVStoreRecord( + sqlVal, err := store.db.GetGlobalKVStoreRecord( ctx, sqlc.GetGlobalKVStoreRecordParams{ Key: entry.key, @@ -162,7 +159,7 @@ func TestFirewallDBMigration(t *testing.T) { groupAlias := entry.groupAlias.UnwrapOrFail(t) groupID := getGroupID(groupAlias[:]) - v, err := store.GetGroupKVStoreRecord( + v, err := store.db.GetGroupKVStoreRecord( ctx, sqlc.GetGroupKVStoreRecordParams{ Key: entry.key, @@ -187,7 +184,7 @@ func TestFirewallDBMigration(t *testing.T) { entry.featureName.UnwrapOrFail(t), ) - sqlVal, err := store.GetFeatureKVStoreRecord( + sqlVal, err := store.db.GetFeatureKVStoreRecord( ctx, sqlc.GetFeatureKVStoreRecordParams{ Key: entry.key, @@ -302,7 +299,7 @@ func TestFirewallDBMigration(t *testing.T) { return MigrateFirewallDBToSQL( ctx, firewallStore.DB, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) From f153c6579a3a103124d6fa5754117c12f5fc639e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 12:21:59 +0200 Subject: [PATCH 13/25] session: use sqldb/v2 in the session mig test --- session/sql_migration_test.go | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/session/sql_migration_test.go b/session/sql_migration_test.go index dfe495628..20cd1092b 100644 --- a/session/sql_migration_test.go +++ b/session/sql_migration_test.go @@ -2,17 +2,16 @@ package session import ( "context" - "database/sql" "fmt" "testing" "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "go.etcd.io/bbolt" "golang.org/x/exp/rand" @@ -38,7 +37,7 @@ func TestSessionsStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, acctStore accounts.Store) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLQueriesExecutor[SQLQueries]) { // Create a sql store with a linked account store. 
 	testDBStore := NewTestDBWithAccounts(t, clock, acctStore)
@@ -48,13 +47,9 @@ func TestSessionsStoreMigration(t *testing.T) {
 
 		baseDB := store.BaseDB
 
-		genericExecutor := db.NewTransactionExecutor(
-			baseDB, func(tx *sql.Tx) SQLQueries {
-				return baseDB.WithTx(tx)
-			},
-		)
+		queries := sqlc.NewForType(baseDB, baseDB.BackendType)
 
-		return store, genericExecutor
+		return store, NewSQLQueriesExecutor(baseDB, queries)
 	}
 
 	// assertMigrationResults asserts that the sql store contains the
@@ -375,7 +370,7 @@ func TestSessionsStoreMigration(t *testing.T) {
 					return MigrateSessionStoreToSQL(
 						ctx, kvStore.DB, tx,
 					)
-				},
+				}, sqldb.NoOpReset,
 			)
 			require.NoError(t, err)

From 3031c8393edde48fe3ed4786ba08819c52480fb9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Fri, 11 Jul 2025 01:16:52 +0200
Subject: [PATCH 14/25] sqlcmig6: add `sqlcmig6` package

This commit introduces the `sqlcmig6` package, which at the time of
this commit contains the same queries and models as the `sqlc` package.
Importantly though, once the kvdb to sql migration is made available in
production, the `sqlcmig6` package will not change, as it is intended
to represent the sql db as it was at the time of the migration.

The sqlcmig6 package is therefore intended to be used in the kvdb to
sql migration code, as it will always be compatible with the sql
database when all sql migrations prior to the kvdb to sql migration are
applied. When additional sql migrations are added in the future, they
may affect the `sqlc` package in such a way that the standard `sqlc`
queries and models are no longer compatible with the kvdb to sql
migration code. By preserving the `sqlcmig6` package, we ensure that
the kvdb to sql migration code can always use the same queries and
models that were available at the time of the migration, even if the
`sqlc` package changes in the future.

Note that the `sqlcmig6` package has not been generated by `sqlc` (the
queries and models are copied from the `sqlc` package), as it is not
intended to be changed in the future.
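As a purely illustrative sketch (hypothetical, not code added by this
commit), migration code would obtain its queries from the frozen
package rather than from the evolving `sqlc` package:

    // Hypothetical helper: pin the kvdb to sql migration's queries to
    // the frozen sqlcmig6 package, so that future changes to the
    // regular sqlc package cannot affect the migration code.
    func migrationQueries(base *sqldb.BaseDB) *sqlcmig6.Queries {
            return sqlcmig6.NewForType(base, base.BackendType)
    }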
--- db/sqlcmig6/accounts.sql.go | 390 ++++++++++++++++++ db/sqlcmig6/actions.sql.go | 73 ++++ db/sqlcmig6/actions_custom.go | 210 ++++++++++ db/sqlcmig6/db.go | 27 ++ db/sqlcmig6/db_custom.go | 48 +++ db/sqlcmig6/kvstores.sql.go | 376 +++++++++++++++++ db/sqlcmig6/models.go | 124 ++++++ db/sqlcmig6/privacy_paris.sql.go | 91 +++++ db/sqlcmig6/querier.go | 75 ++++ db/sqlcmig6/sessions.sql.go | 675 +++++++++++++++++++++++++++++++ 10 files changed, 2089 insertions(+) create mode 100644 db/sqlcmig6/accounts.sql.go create mode 100644 db/sqlcmig6/actions.sql.go create mode 100644 db/sqlcmig6/actions_custom.go create mode 100644 db/sqlcmig6/db.go create mode 100644 db/sqlcmig6/db_custom.go create mode 100644 db/sqlcmig6/kvstores.sql.go create mode 100644 db/sqlcmig6/models.go create mode 100644 db/sqlcmig6/privacy_paris.sql.go create mode 100644 db/sqlcmig6/querier.go create mode 100644 db/sqlcmig6/sessions.sql.go diff --git a/db/sqlcmig6/accounts.sql.go b/db/sqlcmig6/accounts.sql.go new file mode 100644 index 000000000..479c82b36 --- /dev/null +++ b/db/sqlcmig6/accounts.sql.go @@ -0,0 +1,390 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const addAccountInvoice = `-- name: AddAccountInvoice :exec +INSERT INTO account_invoices (account_id, hash) +VALUES ($1, $2) +` + +type AddAccountInvoiceParams struct { + AccountID int64 + Hash []byte +} + +func (q *Queries) AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error { + _, err := q.db.ExecContext(ctx, addAccountInvoice, arg.AccountID, arg.Hash) + return err +} + +const deleteAccount = `-- name: DeleteAccount :exec +DELETE FROM accounts +WHERE id = $1 +` + +func (q *Queries) DeleteAccount(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteAccount, id) + return err +} + +const deleteAccountPayment = `-- name: DeleteAccountPayment :exec +DELETE FROM account_payments +WHERE hash = $1 +AND account_id = $2 +` + +type DeleteAccountPaymentParams struct { + Hash []byte + AccountID int64 +} + +func (q *Queries) DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error { + _, err := q.db.ExecContext(ctx, deleteAccountPayment, arg.Hash, arg.AccountID) + return err +} + +const getAccount = `-- name: GetAccount :one +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +WHERE id = $1 +` + +func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccount, id) + var i Account + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ) + return i, err +} + +const getAccountByLabel = `-- name: GetAccountByLabel :one +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +WHERE label = $1 +` + +func (q *Queries) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccountByLabel, label) + var i Account + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ) + return i, err +} + +const getAccountIDByAlias = `-- name: GetAccountIDByAlias :one +SELECT id +FROM accounts +WHERE alias = $1 +` + +func (q *Queries) GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIDByAlias, alias) + 
var id int64 + err := row.Scan(&id) + return id, err +} + +const getAccountIndex = `-- name: GetAccountIndex :one +SELECT value +FROM account_indices +WHERE name = $1 +` + +func (q *Queries) GetAccountIndex(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIndex, name) + var value int64 + err := row.Scan(&value) + return value, err +} + +const getAccountInvoice = `-- name: GetAccountInvoice :one +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 + AND hash = $2 +` + +type GetAccountInvoiceParams struct { + AccountID int64 + Hash []byte +} + +func (q *Queries) GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) { + row := q.db.QueryRowContext(ctx, getAccountInvoice, arg.AccountID, arg.Hash) + var i AccountInvoice + err := row.Scan(&i.AccountID, &i.Hash) + return i, err +} + +const getAccountPayment = `-- name: GetAccountPayment :one +SELECT account_id, hash, status, full_amount_msat FROM account_payments +WHERE hash = $1 +AND account_id = $2 +` + +type GetAccountPaymentParams struct { + Hash []byte + AccountID int64 +} + +func (q *Queries) GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) { + row := q.db.QueryRowContext(ctx, getAccountPayment, arg.Hash, arg.AccountID) + var i AccountPayment + err := row.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ) + return i, err +} + +const insertAccount = `-- name: InsertAccount :one +INSERT INTO accounts (type, initial_balance_msat, current_balance_msat, last_updated, label, alias, expiration) +VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id +` + +type InsertAccountParams struct { + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Label sql.NullString + Alias int64 + Expiration time.Time +} + +func (q *Queries) InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAccount, + arg.Type, + arg.InitialBalanceMsat, + arg.CurrentBalanceMsat, + arg.LastUpdated, + arg.Label, + arg.Alias, + arg.Expiration, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const listAccountInvoices = `-- name: ListAccountInvoices :many +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 +` + +func (q *Queries) ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) { + rows, err := q.db.QueryContext(ctx, listAccountInvoices, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountInvoice + for rows.Next() { + var i AccountInvoice + if err := rows.Scan(&i.AccountID, &i.Hash); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAccountPayments = `-- name: ListAccountPayments :many +SELECT account_id, hash, status, full_amount_msat +FROM account_payments +WHERE account_id = $1 +` + +func (q *Queries) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) { + rows, err := q.db.QueryContext(ctx, listAccountPayments, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountPayment + for rows.Next() { + var i AccountPayment + if err := rows.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if 
err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAllAccounts = `-- name: ListAllAccounts :many +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +ORDER BY id +` + +func (q *Queries) ListAllAccounts(ctx context.Context) ([]Account, error) { + rows, err := q.db.QueryContext(ctx, listAllAccounts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Account + for rows.Next() { + var i Account + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setAccountIndex = `-- name: SetAccountIndex :exec +INSERT INTO account_indices (name, value) +VALUES ($1, $2) + ON CONFLICT (name) +DO UPDATE SET value = $2 +` + +type SetAccountIndexParams struct { + Name string + Value int64 +} + +func (q *Queries) SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error { + _, err := q.db.ExecContext(ctx, setAccountIndex, arg.Name, arg.Value) + return err +} + +const updateAccountBalance = `-- name: UpdateAccountBalance :one +UPDATE accounts +SET current_balance_msat = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountBalanceParams struct { + CurrentBalanceMsat int64 + ID int64 +} + +func (q *Queries) UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountBalance, arg.CurrentBalanceMsat, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountExpiry = `-- name: UpdateAccountExpiry :one +UPDATE accounts +SET expiration = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountExpiryParams struct { + Expiration time.Time + ID int64 +} + +func (q *Queries) UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountExpiry, arg.Expiration, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountLastUpdate = `-- name: UpdateAccountLastUpdate :one +UPDATE accounts +SET last_updated = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountLastUpdateParams struct { + LastUpdated time.Time + ID int64 +} + +func (q *Queries) UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountLastUpdate, arg.LastUpdated, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const upsertAccountPayment = `-- name: UpsertAccountPayment :exec +INSERT INTO account_payments (account_id, hash, status, full_amount_msat) +VALUES ($1, $2, $3, $4) +ON CONFLICT (account_id, hash) +DO UPDATE SET status = $3, full_amount_msat = $4 +` + +type UpsertAccountPaymentParams struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} + +func (q *Queries) UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error { + _, err := q.db.ExecContext(ctx, upsertAccountPayment, + arg.AccountID, + arg.Hash, + arg.Status, + arg.FullAmountMsat, + ) + return err +} diff --git a/db/sqlcmig6/actions.sql.go b/db/sqlcmig6/actions.sql.go new file mode 100644 index 000000000..a39d51e5d --- 
/dev/null +++ b/db/sqlcmig6/actions.sql.go @@ -0,0 +1,73 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const insertAction = `-- name: InsertAction :one +INSERT INTO actions ( + session_id, account_id, macaroon_identifier, actor_name, feature_name, action_trigger, + intent, structured_json_data, rpc_method, rpc_params_json, created_at, + action_state, error_reason +) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13 +) RETURNING id +` + +type InsertActionParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +func (q *Queries) InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAction, + arg.SessionID, + arg.AccountID, + arg.MacaroonIdentifier, + arg.ActorName, + arg.FeatureName, + arg.ActionTrigger, + arg.Intent, + arg.StructuredJsonData, + arg.RpcMethod, + arg.RpcParamsJson, + arg.CreatedAt, + arg.ActionState, + arg.ErrorReason, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const setActionState = `-- name: SetActionState :exec +UPDATE actions +SET action_state = $1, + error_reason = $2 +WHERE id = $3 +` + +type SetActionStateParams struct { + ActionState int16 + ErrorReason sql.NullString + ID int64 +} + +func (q *Queries) SetActionState(ctx context.Context, arg SetActionStateParams) error { + _, err := q.db.ExecContext(ctx, setActionState, arg.ActionState, arg.ErrorReason, arg.ID) + return err +} diff --git a/db/sqlcmig6/actions_custom.go b/db/sqlcmig6/actions_custom.go new file mode 100644 index 000000000..f01772d51 --- /dev/null +++ b/db/sqlcmig6/actions_custom.go @@ -0,0 +1,210 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "strconv" + "strings" +) + +// ActionQueryParams defines the parameters for querying actions. +type ActionQueryParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + FeatureName sql.NullString + ActorName sql.NullString + RpcMethod sql.NullString + State sql.NullInt16 + EndTime sql.NullTime + StartTime sql.NullTime + GroupID sql.NullInt64 +} + +// ListActionsParams defines the parameters for listing actions, including +// the ActionQueryParams for filtering and a Pagination struct for +// pagination. The Reversed field indicates whether the results should be +// returned in reverse order based on the created_at timestamp. +type ListActionsParams struct { + ActionQueryParams + Reversed bool + *Pagination +} + +// Pagination defines the pagination parameters for listing actions. +type Pagination struct { + NumOffset int32 + NumLimit int32 +} + +// ListActions retrieves a list of actions based on the provided +// ListActionsParams. +func (q *Queries) ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) { + + query, args := buildListActionsQuery(arg) + rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...) 
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Action
+	for rows.Next() {
+		var i Action
+		if err := rows.Scan(
+			&i.ID,
+			&i.SessionID,
+			&i.AccountID,
+			&i.MacaroonIdentifier,
+			&i.ActorName,
+			&i.FeatureName,
+			&i.ActionTrigger,
+			&i.Intent,
+			&i.StructuredJsonData,
+			&i.RpcMethod,
+			&i.RpcParamsJson,
+			&i.CreatedAt,
+			&i.ActionState,
+			&i.ErrorReason,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+// CountActions returns the number of actions that match the provided
+// ActionQueryParams.
+func (q *Queries) CountActions(ctx context.Context,
+	arg ActionQueryParams) (int64, error) {
+
+	query, args := buildActionsQuery(arg, true)
+	row := q.db.QueryRowContext(ctx, query, args...)
+
+	var count int64
+	err := row.Scan(&count)
+
+	return count, err
+}
+
+// buildActionsQuery constructs a SQL query to retrieve actions based on the
+// provided parameters. We do this manually so that if, for example, we have
+// a sessionID we are filtering by, then this appears in the query as:
+// `WHERE a.session_id = ?` which will properly make use of the underlying
+// index. If we were instead to use a single SQLC query, it would include many
+// WHERE clauses like:
+// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)".
+// This would use the index if run against postgres but not when run against
+// sqlite.
+//
+// The 'count' param indicates whether the query should return a count of
+// actions that match the criteria or the actions themselves.
+func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) {
+	var (
+		conditions []string
+		args       []any
+	)
+
+	if params.SessionID.Valid {
+		conditions = append(conditions, "a.session_id = ?")
+		args = append(args, params.SessionID.Int64)
+	}
+	if params.AccountID.Valid {
+		conditions = append(conditions, "a.account_id = ?")
+		args = append(args, params.AccountID.Int64)
+	}
+	if params.FeatureName.Valid {
+		conditions = append(conditions, "a.feature_name = ?")
+		args = append(args, params.FeatureName.String)
+	}
+	if params.ActorName.Valid {
+		conditions = append(conditions, "a.actor_name = ?")
+		args = append(args, params.ActorName.String)
+	}
+	if params.RpcMethod.Valid {
+		conditions = append(conditions, "a.rpc_method = ?")
+		args = append(args, params.RpcMethod.String)
+	}
+	if params.State.Valid {
+		conditions = append(conditions, "a.action_state = ?")
+		args = append(args, params.State.Int16)
+	}
+	if params.EndTime.Valid {
+		conditions = append(conditions, "a.created_at <= ?")
+		args = append(args, params.EndTime.Time)
+	}
+	if params.StartTime.Valid {
+		conditions = append(conditions, "a.created_at >= ?")
+		args = append(args, params.StartTime.Time)
+	}
+	if params.GroupID.Valid {
+		conditions = append(conditions, `
+			EXISTS (
+				SELECT 1
+				FROM sessions s
+				WHERE s.id = a.session_id AND s.group_id = ?
+			)`)
+		args = append(args, params.GroupID.Int64)
+	}
+
+	query := "SELECT a.* FROM actions a"
+	if count {
+		query = "SELECT COUNT(*) FROM actions a"
+	}
+	if len(conditions) > 0 {
+		query += " WHERE " + strings.Join(conditions, " AND ")
+	}
+
+	return query, args
+}
+
+// buildListActionsQuery constructs a SQL query to retrieve a list of actions
+// based on the provided parameters. It builds upon the `buildActionsQuery`
+// function, adding pagination and ordering based on the reversed parameter.
+func buildListActionsQuery(params ListActionsParams) (string, []interface{}) { + query, args := buildActionsQuery(params.ActionQueryParams, false) + + // Determine order direction. + order := "ASC" + if params.Reversed { + order = "DESC" + } + query += " ORDER BY a.created_at " + order + + // Maybe paginate. + if params.Pagination != nil { + query += " LIMIT ? OFFSET ?" + args = append(args, params.NumLimit, params.NumOffset) + } + + return query, args +} + +// fillPlaceHolders replaces all '?' placeholders in the SQL query with +// positional placeholders like $1, $2, etc. This is necessary for +// compatibility with Postgres. +func fillPlaceHolders(query string) string { + var ( + sb strings.Builder + argNum = 1 + ) + + for i := range len(query) { + if query[i] != '?' { + sb.WriteByte(query[i]) + continue + } + + sb.WriteString("$") + sb.WriteString(strconv.Itoa(argNum)) + argNum++ + } + + return sb.String() +} diff --git a/db/sqlcmig6/db.go b/db/sqlcmig6/db.go new file mode 100644 index 000000000..82ff72dd8 --- /dev/null +++ b/db/sqlcmig6/db.go @@ -0,0 +1,27 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/db/sqlcmig6/db_custom.go b/db/sqlcmig6/db_custom.go new file mode 100644 index 000000000..e128cb290 --- /dev/null +++ b/db/sqlcmig6/db_custom.go @@ -0,0 +1,48 @@ +package sqlcmig6 + +import ( + "context" + + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +// wrappedTX is a wrapper around a DBTX that also stores the database backend +// type. +type wrappedTX struct { + DBTX + + backendType sqldb.BackendType +} + +// Backend returns the type of database backend we're using. +func (q *Queries) Backend() sqldb.BackendType { + wtx, ok := q.db.(*wrappedTX) + if !ok { + // Shouldn't happen unless a new database backend type is added + // but not initialized correctly. + return sqldb.BackendTypeUnknown + } + + return wtx.backendType +} + +// NewForType creates a new Queries instance for the given database type. +func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} +} + +// CustomQueries defines a set of custom queries that we define in addition +// to the ones generated by sqlc. +type CustomQueries interface { + // CountActions returns the number of actions that match the provided + // ActionQueryParams. + CountActions(ctx context.Context, arg ActionQueryParams) (int64, error) + + // ListActions retrieves a list of actions based on the provided + // ListActionsParams. + ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) + + // Backend returns the type of the database backend used. 
+ Backend() sqldb.BackendType +} diff --git a/db/sqlcmig6/kvstores.sql.go b/db/sqlcmig6/kvstores.sql.go new file mode 100644 index 000000000..3c076c5a4 --- /dev/null +++ b/db/sqlcmig6/kvstores.sql.go @@ -0,0 +1,376 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +const deleteAllTempKVStores = `-- name: DeleteAllTempKVStores :exec +DELETE FROM kvstores +WHERE perm = false +` + +func (q *Queries) DeleteAllTempKVStores(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, deleteAllTempKVStores) + return err +} + +const deleteFeatureKVStoreRecord = `-- name: DeleteFeatureKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id = $5 +` + +type DeleteFeatureKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteFeatureKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + return err +} + +const deleteGlobalKVStoreRecord = `-- name: DeleteGlobalKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id IS NULL + AND feature_id IS NULL +` + +type DeleteGlobalKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGlobalKVStoreRecord, arg.Key, arg.RuleID, arg.Perm) + return err +} + +const deleteGroupKVStoreRecord = `-- name: DeleteGroupKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type DeleteGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + return err +} + +const getFeatureID = `-- name: GetFeatureID :one +SELECT id +FROM features +WHERE name = $1 +` + +func (q *Queries) GetFeatureID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getFeatureID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getFeatureKVStoreRecord = `-- name: GetFeatureKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id = $5 +` + +type GetFeatureKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getFeatureKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getGlobalKVStoreRecord = `-- name: GetGlobalKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id IS NULL + AND feature_id IS NULL +` + +type GetGlobalKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) GetGlobalKVStoreRecord(ctx 
context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGlobalKVStoreRecord, arg.Key, arg.RuleID, arg.Perm) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getGroupKVStoreRecord = `-- name: GetGroupKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type GetGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getOrInsertFeatureID = `-- name: GetOrInsertFeatureID :one +INSERT INTO features (name) +VALUES ($1) +ON CONFLICT(name) DO UPDATE SET name = excluded.name +RETURNING id +` + +func (q *Queries) GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getOrInsertFeatureID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getOrInsertRuleID = `-- name: GetOrInsertRuleID :one +INSERT INTO rules (name) +VALUES ($1) +ON CONFLICT(name) DO UPDATE SET name = excluded.name +RETURNING id +` + +func (q *Queries) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getOrInsertRuleID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getRuleID = `-- name: GetRuleID :one +SELECT id +FROM rules +WHERE name = $1 +` + +func (q *Queries) GetRuleID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getRuleID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertKVStoreRecord = `-- name: InsertKVStoreRecord :exec +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) +VALUES ($1, $2, $3, $4, $5, $6) +` + +type InsertKVStoreRecordParams struct { + Perm bool + RuleID int64 + GroupID sql.NullInt64 + FeatureID sql.NullInt64 + EntryKey string + Value []byte +} + +func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, insertKVStoreRecord, + arg.Perm, + arg.RuleID, + arg.GroupID, + arg.FeatureID, + arg.EntryKey, + arg.Value, + ) + return err +} + +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, group_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.GroupID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id = $5 + AND feature_id = $6 +` + +type UpdateFeatureKVStoreRecordParams struct { + Value []byte + Key string + RuleID 
int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateFeatureKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + return err +} + +const updateGlobalKVStoreRecord = `-- name: UpdateGlobalKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id IS NULL + AND feature_id IS NULL +` + +type UpdateGlobalKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGlobalKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + ) + return err +} + +const updateGroupKVStoreRecord = `-- name: UpdateGroupKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id = $5 + AND feature_id IS NULL +` + +type UpdateGroupKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGroupKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + return err +} diff --git a/db/sqlcmig6/models.go b/db/sqlcmig6/models.go new file mode 100644 index 000000000..9e57c28eb --- /dev/null +++ b/db/sqlcmig6/models.go @@ -0,0 +1,124 @@ +package sqlcmig6 + +import ( + "database/sql" + "time" +) + +type Account struct { + ID int64 + Alias int64 + Label sql.NullString + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Expiration time.Time +} + +type AccountIndex struct { + Name string + Value int64 +} + +type AccountInvoice struct { + AccountID int64 + Hash []byte +} + +type AccountPayment struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} + +type Action struct { + ID int64 + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +type Feature struct { + ID int64 + Name string +} + +type Kvstore struct { + ID int64 + Perm bool + RuleID int64 + GroupID sql.NullInt64 + FeatureID sql.NullInt64 + EntryKey string + Value []byte +} + +type PrivacyPair struct { + GroupID int64 + RealVal string + PseudoVal string +} + +type Rule struct { + ID int64 + Name string +} + +type Session struct { + ID int64 + Alias []byte + Label string + State int16 + Type int16 + Expiry time.Time + CreatedAt time.Time + RevokedAt sql.NullTime + ServerAddress string + DevServer bool + MacaroonRootKey int64 + PairingSecret []byte + LocalPrivateKey []byte + LocalPublicKey []byte + RemotePublicKey []byte + Privacy bool + AccountID sql.NullInt64 + GroupID sql.NullInt64 +} + +type SessionFeatureConfig struct { + SessionID int64 + FeatureName string + Config []byte +} + +type SessionMacaroonCaveat struct { + ID int64 + SessionID int64 + CaveatID []byte + VerificationID []byte + Location sql.NullString +} + +type SessionMacaroonPermission struct { + ID int64 
+ SessionID int64 + Entity string + Action string +} + +type SessionPrivacyFlag struct { + SessionID int64 + Flag int32 +} diff --git a/db/sqlcmig6/privacy_paris.sql.go b/db/sqlcmig6/privacy_paris.sql.go new file mode 100644 index 000000000..20af09a2f --- /dev/null +++ b/db/sqlcmig6/privacy_paris.sql.go @@ -0,0 +1,91 @@ +package sqlcmig6 + +import ( + "context" +) + +const getAllPrivacyPairs = `-- name: GetAllPrivacyPairs :many +SELECT real_val, pseudo_val +FROM privacy_pairs +WHERE group_id = $1 +` + +type GetAllPrivacyPairsRow struct { + RealVal string + PseudoVal string +} + +func (q *Queries) GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) { + rows, err := q.db.QueryContext(ctx, getAllPrivacyPairs, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllPrivacyPairsRow + for rows.Next() { + var i GetAllPrivacyPairsRow + if err := rows.Scan(&i.RealVal, &i.PseudoVal); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getPseudoForReal = `-- name: GetPseudoForReal :one +SELECT pseudo_val +FROM privacy_pairs +WHERE group_id = $1 AND real_val = $2 +` + +type GetPseudoForRealParams struct { + GroupID int64 + RealVal string +} + +func (q *Queries) GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) { + row := q.db.QueryRowContext(ctx, getPseudoForReal, arg.GroupID, arg.RealVal) + var pseudo_val string + err := row.Scan(&pseudo_val) + return pseudo_val, err +} + +const getRealForPseudo = `-- name: GetRealForPseudo :one +SELECT real_val +FROM privacy_pairs +WHERE group_id = $1 AND pseudo_val = $2 +` + +type GetRealForPseudoParams struct { + GroupID int64 + PseudoVal string +} + +func (q *Queries) GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) { + row := q.db.QueryRowContext(ctx, getRealForPseudo, arg.GroupID, arg.PseudoVal) + var real_val string + err := row.Scan(&real_val) + return real_val, err +} + +const insertPrivacyPair = `-- name: InsertPrivacyPair :exec +INSERT INTO privacy_pairs (group_id, real_val, pseudo_val) +VALUES ($1, $2, $3) +` + +type InsertPrivacyPairParams struct { + GroupID int64 + RealVal string + PseudoVal string +} + +func (q *Queries) InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error { + _, err := q.db.ExecContext(ctx, insertPrivacyPair, arg.GroupID, arg.RealVal, arg.PseudoVal) + return err +} diff --git a/db/sqlcmig6/querier.go b/db/sqlcmig6/querier.go new file mode 100644 index 000000000..57e229b5f --- /dev/null +++ b/db/sqlcmig6/querier.go @@ -0,0 +1,75 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +type Querier interface { + AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error + DeleteAccount(ctx context.Context, id int64) error + DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error + DeleteAllTempKVStores(ctx context.Context) error + DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error + DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error + DeleteSessionsWithState(ctx context.Context, state int16) error + GetAccount(ctx context.Context, id int64) (Account, error) + GetAccountByLabel(ctx context.Context, label 
sql.NullString) (Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccountIndex(ctx context.Context, name string) (int64, error) + GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) + GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) + GetFeatureID(ctx context.Context, name string) (int64, error) + GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) + GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) + GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) + GetOrInsertRuleID(ctx context.Context, name string) (int64, error) + GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) + GetRuleID(ctx context.Context, name string) (int64, error) + GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) + GetSessionByAlias(ctx context.Context, alias []byte) (Session, error) + GetSessionByID(ctx context.Context, id int64) (Session, error) + GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) + GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) + GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) + GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) + GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) + GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) + GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) + InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) + InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) + InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error + InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error + InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) + InsertSessionFeatureConfig(ctx context.Context, arg InsertSessionFeatureConfigParams) error + InsertSessionMacaroonCaveat(ctx context.Context, arg InsertSessionMacaroonCaveatParams) error + InsertSessionMacaroonPermission(ctx context.Context, arg InsertSessionMacaroonPermissionParams) error + InsertSessionPrivacyFlag(ctx context.Context, arg InsertSessionPrivacyFlagParams) error + ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) + ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) + ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) + ListSessions(ctx context.Context) ([]Session, error) + ListSessionsByState(ctx context.Context, state int16) ([]Session, error) + ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) + SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error + SetActionState(ctx context.Context, arg SetActionStateParams) error + SetSessionGroupID(ctx 
context.Context, arg SetSessionGroupIDParams) error + SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error + SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error + UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) + UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) + UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) + UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error + UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error + UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error + UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error +} + +var _ Querier = (*Queries)(nil) diff --git a/db/sqlcmig6/sessions.sql.go b/db/sqlcmig6/sessions.sql.go new file mode 100644 index 000000000..bc492043e --- /dev/null +++ b/db/sqlcmig6/sessions.sql.go @@ -0,0 +1,675 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const deleteSessionsWithState = `-- name: DeleteSessionsWithState :exec +DELETE FROM sessions +WHERE state = $1 +` + +func (q *Queries) DeleteSessionsWithState(ctx context.Context, state int16) error { + _, err := q.db.ExecContext(ctx, deleteSessionsWithState, state) + return err +} + +const getAliasBySessionID = `-- name: GetAliasBySessionID :one +SELECT alias FROM sessions +WHERE id = $1 +` + +func (q *Queries) GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getAliasBySessionID, id) + var alias []byte + err := row.Scan(&alias) + return alias, err +} + +const getSessionAliasesInGroup = `-- name: GetSessionAliasesInGroup :many +SELECT alias FROM sessions +WHERE group_id = $1 +` + +func (q *Queries) GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) { + rows, err := q.db.QueryContext(ctx, getSessionAliasesInGroup, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items [][]byte + for rows.Next() { + var alias []byte + if err := rows.Scan(&alias); err != nil { + return nil, err + } + items = append(items, alias) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionByAlias = `-- name: GetSessionByAlias :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE alias = $1 +` + +func (q *Queries) GetSessionByAlias(ctx context.Context, alias []byte) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByAlias, alias) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionByID = `-- name: GetSessionByID :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, 
local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE id = $1 +` + +func (q *Queries) GetSessionByID(ctx context.Context, id int64) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByID, id) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionByLocalPublicKey = `-- name: GetSessionByLocalPublicKey :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE local_public_key = $1 +` + +func (q *Queries) GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByLocalPublicKey, localPublicKey) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionFeatureConfigs = `-- name: GetSessionFeatureConfigs :many +SELECT session_id, feature_name, config FROM session_feature_configs +WHERE session_id = $1 +` + +func (q *Queries) GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) { + rows, err := q.db.QueryContext(ctx, getSessionFeatureConfigs, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionFeatureConfig + for rows.Next() { + var i SessionFeatureConfig + if err := rows.Scan(&i.SessionID, &i.FeatureName, &i.Config); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionIDByAlias = `-- name: GetSessionIDByAlias :one +SELECT id FROM sessions +WHERE alias = $1 +` + +func (q *Queries) GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) { + row := q.db.QueryRowContext(ctx, getSessionIDByAlias, alias) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getSessionMacaroonCaveats = `-- name: GetSessionMacaroonCaveats :many +SELECT id, session_id, caveat_id, verification_id, location FROM session_macaroon_caveats +WHERE session_id = $1 +` + +func (q *Queries) GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) { + rows, err := q.db.QueryContext(ctx, getSessionMacaroonCaveats, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionMacaroonCaveat + for rows.Next() { + var i SessionMacaroonCaveat + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.CaveatID, + &i.VerificationID, + &i.Location, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionMacaroonPermissions = `-- name: GetSessionMacaroonPermissions :many +SELECT id, session_id, entity, 
action FROM session_macaroon_permissions +WHERE session_id = $1 +` + +func (q *Queries) GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) { + rows, err := q.db.QueryContext(ctx, getSessionMacaroonPermissions, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionMacaroonPermission + for rows.Next() { + var i SessionMacaroonPermission + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Entity, + &i.Action, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionPrivacyFlags = `-- name: GetSessionPrivacyFlags :many +SELECT session_id, flag FROM session_privacy_flags +WHERE session_id = $1 +` + +func (q *Queries) GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) { + rows, err := q.db.QueryContext(ctx, getSessionPrivacyFlags, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionPrivacyFlag + for rows.Next() { + var i SessionPrivacyFlag + if err := rows.Scan(&i.SessionID, &i.Flag); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionsInGroup = `-- name: GetSessionsInGroup :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE group_id = $1 +` + +func (q *Queries) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, getSessionsInGroup, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertSession = `-- name: InsertSession :one +INSERT INTO sessions ( + alias, label, state, type, expiry, created_at, + server_address, dev_server, macaroon_root_key, pairing_secret, + local_private_key, local_public_key, remote_public_key, privacy, group_id, account_id +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, $10, $11, $12, + $13, $14, $15, $16 +) RETURNING id +` + +type InsertSessionParams struct { + Alias []byte + Label string + State int16 + Type int16 + Expiry time.Time + CreatedAt time.Time + ServerAddress string + DevServer bool + MacaroonRootKey int64 + PairingSecret []byte + LocalPrivateKey []byte + LocalPublicKey []byte + RemotePublicKey []byte + Privacy bool + GroupID sql.NullInt64 + AccountID sql.NullInt64 +} + +func (q *Queries) InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertSession, + arg.Alias, + arg.Label, + arg.State, + arg.Type, + arg.Expiry, 
+ arg.CreatedAt, + arg.ServerAddress, + arg.DevServer, + arg.MacaroonRootKey, + arg.PairingSecret, + arg.LocalPrivateKey, + arg.LocalPublicKey, + arg.RemotePublicKey, + arg.Privacy, + arg.GroupID, + arg.AccountID, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertSessionFeatureConfig = `-- name: InsertSessionFeatureConfig :exec +INSERT INTO session_feature_configs ( + session_id, feature_name, config +) VALUES ( + $1, $2, $3 +) +` + +type InsertSessionFeatureConfigParams struct { + SessionID int64 + FeatureName string + Config []byte +} + +func (q *Queries) InsertSessionFeatureConfig(ctx context.Context, arg InsertSessionFeatureConfigParams) error { + _, err := q.db.ExecContext(ctx, insertSessionFeatureConfig, arg.SessionID, arg.FeatureName, arg.Config) + return err +} + +const insertSessionMacaroonCaveat = `-- name: InsertSessionMacaroonCaveat :exec +INSERT INTO session_macaroon_caveats ( + session_id, caveat_id, verification_id, location +) VALUES ( + $1, $2, $3, $4 +) +` + +type InsertSessionMacaroonCaveatParams struct { + SessionID int64 + CaveatID []byte + VerificationID []byte + Location sql.NullString +} + +func (q *Queries) InsertSessionMacaroonCaveat(ctx context.Context, arg InsertSessionMacaroonCaveatParams) error { + _, err := q.db.ExecContext(ctx, insertSessionMacaroonCaveat, + arg.SessionID, + arg.CaveatID, + arg.VerificationID, + arg.Location, + ) + return err +} + +const insertSessionMacaroonPermission = `-- name: InsertSessionMacaroonPermission :exec +INSERT INTO session_macaroon_permissions ( + session_id, entity, action +) VALUES ( + $1, $2, $3 +) +` + +type InsertSessionMacaroonPermissionParams struct { + SessionID int64 + Entity string + Action string +} + +func (q *Queries) InsertSessionMacaroonPermission(ctx context.Context, arg InsertSessionMacaroonPermissionParams) error { + _, err := q.db.ExecContext(ctx, insertSessionMacaroonPermission, arg.SessionID, arg.Entity, arg.Action) + return err +} + +const insertSessionPrivacyFlag = `-- name: InsertSessionPrivacyFlag :exec +INSERT INTO session_privacy_flags ( + session_id, flag +) VALUES ( + $1, $2 +) +` + +type InsertSessionPrivacyFlagParams struct { + SessionID int64 + Flag int32 +} + +func (q *Queries) InsertSessionPrivacyFlag(ctx context.Context, arg InsertSessionPrivacyFlagParams) error { + _, err := q.db.ExecContext(ctx, insertSessionPrivacyFlag, arg.SessionID, arg.Flag) + return err +} + +const listSessions = `-- name: ListSessions :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +ORDER BY created_at +` + +func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessions) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSessionsByState = `-- name: 
ListSessionsByState :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE state = $1 +ORDER BY created_at +` + +func (q *Queries) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessionsByState, state) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSessionsByType = `-- name: ListSessionsByType :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE type = $1 +ORDER BY created_at +` + +func (q *Queries) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessionsByType, type_) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setSessionGroupID = `-- name: SetSessionGroupID :exec +UPDATE sessions +SET group_id = $1 +WHERE id = $2 +` + +type SetSessionGroupIDParams struct { + GroupID sql.NullInt64 + ID int64 +} + +func (q *Queries) SetSessionGroupID(ctx context.Context, arg SetSessionGroupIDParams) error { + _, err := q.db.ExecContext(ctx, setSessionGroupID, arg.GroupID, arg.ID) + return err +} + +const setSessionRemotePublicKey = `-- name: SetSessionRemotePublicKey :exec +UPDATE sessions +SET remote_public_key = $1 +WHERE id = $2 +` + +type SetSessionRemotePublicKeyParams struct { + RemotePublicKey []byte + ID int64 +} + +func (q *Queries) SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error { + _, err := q.db.ExecContext(ctx, setSessionRemotePublicKey, arg.RemotePublicKey, arg.ID) + return err +} + +const setSessionRevokedAt = `-- name: SetSessionRevokedAt :exec +UPDATE sessions +SET revoked_at = $1 +WHERE id = $2 +` + +type SetSessionRevokedAtParams struct { + RevokedAt sql.NullTime + ID int64 +} + +func (q *Queries) SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error { + _, err := q.db.ExecContext(ctx, setSessionRevokedAt, arg.RevokedAt, arg.ID) + return err +} + +const updateSessionState = `-- name: UpdateSessionState :exec 
+UPDATE sessions +SET state = $1 +WHERE id = $2 +` + +type UpdateSessionStateParams struct { + State int16 + ID int64 +} + +func (q *Queries) UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error { + _, err := q.db.ExecContext(ctx, updateSessionState, arg.State, arg.ID) + return err +} From 8d558ce62eb004d7b83c27e7d2a4cdf6da2fc21d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 01:44:01 +0200 Subject: [PATCH 15/25] accounts: add SQLMig6Queries to `accounts` --- accounts/store_sql.go | 48 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/accounts/store_sql.go b/accounts/store_sql.go index c7e8ab070..13422315c 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -11,6 +11,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" @@ -56,6 +57,33 @@ type SQLQueries interface { GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error) } +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used +// to interact with accounts related tables. +// +//nolint:lll +type SQLMig6Queries interface { + sqldb.BaseQuerier + + AddAccountInvoice(ctx context.Context, arg sqlcmig6.AddAccountInvoiceParams) error + DeleteAccount(ctx context.Context, id int64) error + DeleteAccountPayment(ctx context.Context, arg sqlcmig6.DeleteAccountPaymentParams) error + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) + GetAccountByLabel(ctx context.Context, label sql.NullString) (sqlcmig6.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccountIndex(ctx context.Context, name string) (int64, error) + GetAccountPayment(ctx context.Context, arg sqlcmig6.GetAccountPaymentParams) (sqlcmig6.AccountPayment, error) + InsertAccount(ctx context.Context, arg sqlcmig6.InsertAccountParams) (int64, error) + ListAccountInvoices(ctx context.Context, id int64) ([]sqlcmig6.AccountInvoice, error) + ListAccountPayments(ctx context.Context, id int64) ([]sqlcmig6.AccountPayment, error) + ListAllAccounts(ctx context.Context) ([]sqlcmig6.Account, error) + SetAccountIndex(ctx context.Context, arg sqlcmig6.SetAccountIndexParams) error + UpdateAccountBalance(ctx context.Context, arg sqlcmig6.UpdateAccountBalanceParams) (int64, error) + UpdateAccountExpiry(ctx context.Context, arg sqlcmig6.UpdateAccountExpiryParams) (int64, error) + UpdateAccountLastUpdate(ctx context.Context, arg sqlcmig6.UpdateAccountLastUpdateParams) (int64, error) + UpsertAccountPayment(ctx context.Context, arg sqlcmig6.UpsertAccountPaymentParams) error + GetAccountInvoice(ctx context.Context, arg sqlcmig6.GetAccountInvoiceParams) (sqlcmig6.AccountInvoice, error) +} + // BatchedSQLQueries combines the SQLQueries interface with the BatchedTx // interface, allowing for multiple queries to be executed in single SQL // transaction. 
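
Aside: SQLMig6Queries above narrows the dedicated sqlcmig6 query type down to
just the methods the accounts migration actually calls. A minimal,
self-contained sketch of that Go pattern follows; it is not part of the patch,
and the names in it (accountIndexSource, fakeQueries, printAddIndex) are
illustrative only — only the GetAccountIndex signature mirrors the interface
above:

package main

import (
	"context"
	"fmt"
)

// accountIndexSource is a hypothetical, trimmed-down analogue of the
// SQLMig6Queries interface: the helper below only depends on the single
// query method it actually calls.
type accountIndexSource interface {
	GetAccountIndex(ctx context.Context, name string) (int64, error)
}

// fakeQueries stands in for a generated queries struct. Any type with a
// matching method set satisfies the narrowed interface, which is what makes
// code written against it easy to exercise in isolation.
type fakeQueries struct {
	indices map[string]int64
}

func (f *fakeQueries) GetAccountIndex(_ context.Context,
	name string) (int64, error) {

	return f.indices[name], nil
}

// printAddIndex is written against the interface, not the concrete type.
func printAddIndex(ctx context.Context, q accountIndexSource) error {
	idx, err := q.GetAccountIndex(ctx, "last_add_index")
	if err != nil {
		return err
	}

	fmt.Println("last add index:", idx)

	return nil
}

func main() {
	q := &fakeQueries{indices: map[string]int64{"last_add_index": 42}}
	if err := printAddIndex(context.Background(), q); err != nil {
		panic(err)
	}
}
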
@@ -97,6 +125,26 @@ func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB,
 	}
 }
 
+type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct {
+	*sqldb.TransactionExecutor[T]
+
+	SQLMig6Queries
+}
+
+func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB,
+	queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] {
+
+	executor := sqldb.NewTransactionExecutor(
+		baseDB, func(tx *sql.Tx) SQLMig6Queries {
+			return queries.WithTx(tx)
+		},
+	)
+	return &SQLMig6QueriesExecutor[SQLMig6Queries]{
+		TransactionExecutor: executor,
+		SQLMig6Queries:      queries,
+	}
+}
+
 // NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
 // storage backend.
 func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries,

From 7c1a6975f241e76fe581a0a39248b0fcf4c4bd0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?=
Date: Fri, 11 Jul 2025 01:49:37 +0200
Subject: [PATCH 16/25] accounts: use `sqlcmig6` for kvdb to sql migration

---
 accounts/sql_migration.go      | 87 +++++++++++++++++++++++++++++-----
 accounts/sql_migration_test.go | 10 ++--
 2 files changed, 81 insertions(+), 16 deletions(-)

diff --git a/accounts/sql_migration.go b/accounts/sql_migration.go
index c36b51c6f..94f88ac28 100644
--- a/accounts/sql_migration.go
+++ b/accounts/sql_migration.go
@@ -11,8 +11,11 @@ import (
 	"time"
 
 	"github.com/davecgh/go-spew/spew"
-	"github.com/lightninglabs/lightning-terminal/db/sqlc"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
 	"github.com/lightningnetwork/lnd/kvdb"
+	"github.com/lightningnetwork/lnd/lnrpc"
+	"github.com/lightningnetwork/lnd/lntypes"
+	"github.com/lightningnetwork/lnd/lnwire"
 	"github.com/pmezard/go-difflib/difflib"
 )
 
@@ -27,7 +30,7 @@ var (
 // the KV database to the SQL database. The migration is done in a single
 // transaction to ensure that all accounts are migrated or none at all.
 func MigrateAccountStoreToSQL(ctx context.Context, kvStore kvdb.Backend,
-	tx SQLQueries) error {
+	tx SQLMig6Queries) error {
 
 	log.Infof("Starting migration of the KV accounts store to SQL")
 
@@ -50,7 +53,7 @@ func MigrateAccountStoreToSQL(ctx context.Context, kvStore kvdb.Backend,
 // to the SQL database. The migration is done in a single transaction to ensure
 // that all accounts are migrated or none at all.
 func migrateAccountsToSQL(ctx context.Context, kvStore kvdb.Backend,
-	tx SQLQueries) error {
+	tx SQLMig6Queries) error {
 
 	log.Infof("Starting migration of accounts from KV to SQL")
 
@@ -68,7 +71,7 @@ func migrateAccountsToSQL(ctx context.Context, kvStore kvdb.Backend,
 			kvAccount.ID, err)
 		}
 
-		migratedAccount, err := getAndMarshalAccount(
+		migratedAccount, err := getAndMarshalMig6Account(
 			ctx, tx, migratedAccountID,
 		)
 		if err != nil {
@@ -151,17 +154,79 @@ func getBBoltAccounts(db kvdb.Backend) ([]*OffChainBalanceAccount, error) {
 	return accounts, nil
 }
 
+// getAndMarshalMig6Account retrieves the account with the given ID. If the
+// account cannot be found, then ErrAccNotFound is returned.
+func getAndMarshalMig6Account(ctx context.Context, db SQLMig6Queries, + id int64) (*OffChainBalanceAccount, error) { + + dbAcct, err := db.GetAccount(ctx, id) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrAccNotFound + } else if err != nil { + return nil, err + } + + return marshalDBMig6Account(ctx, db, dbAcct) +} + +func marshalDBMig6Account(ctx context.Context, db SQLMig6Queries, + dbAcct sqlcmig6.Account) (*OffChainBalanceAccount, error) { + + alias, err := AccountIDFromInt64(dbAcct.Alias) + if err != nil { + return nil, err + } + + account := &OffChainBalanceAccount{ + ID: alias, + Type: AccountType(dbAcct.Type), + InitialBalance: lnwire.MilliSatoshi(dbAcct.InitialBalanceMsat), + CurrentBalance: dbAcct.CurrentBalanceMsat, + LastUpdate: dbAcct.LastUpdated.UTC(), + ExpirationDate: dbAcct.Expiration.UTC(), + Invoices: make(AccountInvoices), + Payments: make(AccountPayments), + Label: dbAcct.Label.String, + } + + invoices, err := db.ListAccountInvoices(ctx, dbAcct.ID) + if err != nil { + return nil, err + } + for _, invoice := range invoices { + var hash lntypes.Hash + copy(hash[:], invoice.Hash) + account.Invoices[hash] = struct{}{} + } + + payments, err := db.ListAccountPayments(ctx, dbAcct.ID) + if err != nil { + return nil, err + } + + for _, payment := range payments { + var hash lntypes.Hash + copy(hash[:], payment.Hash) + account.Payments[hash] = &PaymentEntry{ + Status: lnrpc.Payment_PaymentStatus(payment.Status), + FullAmount: lnwire.MilliSatoshi(payment.FullAmountMsat), + } + } + + return account, nil +} + // migrateSingleAccountToSQL runs the migration for a single account from the // KV database to the SQL database. func migrateSingleAccountToSQL(ctx context.Context, - tx SQLQueries, account *OffChainBalanceAccount) (int64, error) { + tx SQLMig6Queries, account *OffChainBalanceAccount) (int64, error) { accountAlias, err := account.ID.ToInt64() if err != nil { return 0, err } - insertAccountParams := sqlc.InsertAccountParams{ + insertAccountParams := sqlcmig6.InsertAccountParams{ Type: int16(account.Type), InitialBalanceMsat: int64(account.InitialBalance), CurrentBalanceMsat: account.CurrentBalance, @@ -180,7 +245,7 @@ func migrateSingleAccountToSQL(ctx context.Context, } for hash := range account.Invoices { - addInvoiceParams := sqlc.AddAccountInvoiceParams{ + addInvoiceParams := sqlcmig6.AddAccountInvoiceParams{ AccountID: sqlId, Hash: hash[:], } @@ -192,7 +257,7 @@ func migrateSingleAccountToSQL(ctx context.Context, } for hash, paymentEntry := range account.Payments { - upsertPaymentParams := sqlc.UpsertAccountPaymentParams{ + upsertPaymentParams := sqlcmig6.UpsertAccountPaymentParams{ AccountID: sqlId, Hash: hash[:], Status: int16(paymentEntry.Status), @@ -211,7 +276,7 @@ func migrateSingleAccountToSQL(ctx context.Context, // migrateAccountsIndicesToSQL runs the migration for the account indices from // the KV database to the SQL database. 
func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend, - tx SQLQueries) error { + tx SQLMig6Queries) error { log.Infof("Starting migration of accounts indices from KV to SQL") @@ -233,7 +298,7 @@ func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend, settleIndexName, settleIndex) } - setAddIndexParams := sqlc.SetAccountIndexParams{ + setAddIndexParams := sqlcmig6.SetAccountIndexParams{ Name: addIndexName, Value: int64(addIndex), } @@ -243,7 +308,7 @@ func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend, return err } - setSettleIndexParams := sqlc.SetAccountIndexParams{ + setSettleIndexParams := sqlcmig6.SetAccountIndexParams{ Name: settleIndexName, Value: int64(settleIndex), } diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go index 382d23661..6697f74e8 100644 --- a/accounts/sql_migration_test.go +++ b/accounts/sql_migration_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" @@ -35,7 +35,7 @@ func TestAccountStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T) (*SQLStore, - *SQLQueriesExecutor[SQLQueries]) { + *SQLMig6QueriesExecutor[SQLMig6Queries]) { testDBStore := NewTestDB(t, clock) @@ -44,9 +44,9 @@ func TestAccountStoreMigration(t *testing.T) { baseDB := store.BaseDB - queries := sqlc.NewForType(baseDB, baseDB.BackendType) + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) - return store, NewSQLQueriesExecutor(baseDB, queries) + return store, NewSQLMig6QueriesExecutor(baseDB, queries) } assertMigrationResults := func(t *testing.T, sqlStore *SQLStore, @@ -334,7 +334,7 @@ func TestAccountStoreMigration(t *testing.T) { // Perform the migration. var opts sqldb.MigrationTxOptions err = txEx.ExecTx(ctx, &opts, - func(tx SQLQueries) error { + func(tx SQLMig6Queries) error { return MigrateAccountStoreToSQL( ctx, kvStore.db, tx, ) From 721af2b753bb5652dd20bf6f9586cd657baa1af0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 01:52:08 +0200 Subject: [PATCH 17/25] session: add SQLMig6Queries to `session` --- session/sql_store.go | 54 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/session/sql_store.go b/session/sql_store.go index 26662a574..a169dff9b 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -12,6 +12,7 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/sqldb/v2" @@ -52,6 +53,39 @@ type SQLQueries interface { GetAccount(ctx context.Context, id int64) (sqlc.Account, error) } +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used to +// interact with session related tables. 
+type SQLMig6Queries interface { + sqldb.BaseQuerier + + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) + GetSessionByID(ctx context.Context, id int64) (sqlcmig6.Session, error) + GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlcmig6.Session, error) + GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) + GetSessionByAlias(ctx context.Context, legacyID []byte) (sqlcmig6.Session, error) + GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (sqlcmig6.Session, error) + GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionFeatureConfig, error) + GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionMacaroonCaveat, error) + GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) + GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionMacaroonPermission, error) + GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionPrivacyFlag, error) + InsertSessionFeatureConfig(ctx context.Context, arg sqlcmig6.InsertSessionFeatureConfigParams) error + SetSessionRevokedAt(ctx context.Context, arg sqlcmig6.SetSessionRevokedAtParams) error + InsertSessionMacaroonCaveat(ctx context.Context, arg sqlcmig6.InsertSessionMacaroonCaveatParams) error + InsertSessionMacaroonPermission(ctx context.Context, arg sqlcmig6.InsertSessionMacaroonPermissionParams) error + InsertSessionPrivacyFlag(ctx context.Context, arg sqlcmig6.InsertSessionPrivacyFlagParams) error + InsertSession(ctx context.Context, arg sqlcmig6.InsertSessionParams) (int64, error) + ListSessions(ctx context.Context) ([]sqlcmig6.Session, error) + ListSessionsByType(ctx context.Context, sessionType int16) ([]sqlcmig6.Session, error) + ListSessionsByState(ctx context.Context, state int16) ([]sqlcmig6.Session, error) + SetSessionRemotePublicKey(ctx context.Context, arg sqlcmig6.SetSessionRemotePublicKeyParams) error + SetSessionGroupID(ctx context.Context, arg sqlcmig6.SetSessionGroupIDParams) error + UpdateSessionState(ctx context.Context, arg sqlcmig6.UpdateSessionStateParams) error + DeleteSessionsWithState(ctx context.Context, state int16) error + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) +} + var _ Store = (*SQLStore)(nil) // BatchedSQLQueries combines the SQLQueries interface with the BatchedTx @@ -95,6 +129,26 @@ func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, } } +type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLMig6Queries +} + +func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLMig6Queries { + return queries.WithTx(tx) + }, + ) + return &SQLMig6QueriesExecutor[SQLMig6Queries]{ + TransactionExecutor: executor, + SQLMig6Queries: queries, + } +} + // NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries // storage backend. 
func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, From ed9d317823411d8a9c191e06f2b7f7c23f9fd333 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 01:56:25 +0200 Subject: [PATCH 18/25] session: use `sqlcmig6` for kvdb to sql migration --- session/sql_migration.go | 236 ++++++++++++++++++++++++++++++---- session/sql_migration_test.go | 10 +- 2 files changed, 218 insertions(+), 28 deletions(-) diff --git a/session/sql_migration.go b/session/sql_migration.go index 428cc0fce..b1caebeb4 100644 --- a/session/sql_migration.go +++ b/session/sql_migration.go @@ -9,12 +9,17 @@ import ( "reflect" "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" + "github.com/lightninglabs/lightning-node-connect/mailbox" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/sqldb" "github.com/pmezard/go-difflib/difflib" "go.etcd.io/bbolt" + "gopkg.in/macaroon-bakery.v2/bakery" + "gopkg.in/macaroon.v2" ) var ( @@ -31,7 +36,7 @@ var ( // NOTE: As sessions may contain linked accounts, the accounts sql migration // MUST be run prior to this migration. func MigrateSessionStoreToSQL(ctx context.Context, kvStore *bbolt.DB, - tx SQLQueries) error { + tx SQLMig6Queries) error { log.Infof("Starting migration of the KV sessions store to SQL") @@ -118,7 +123,7 @@ func getBBoltSessions(db *bbolt.DB) ([]*Session, error) { // from the KV database to the SQL database, and validates that the migrated // sessions match the original sessions. func migrateSessionsToSQLAndValidate(ctx context.Context, - tx SQLQueries, kvSessions []*Session) error { + tx SQLMig6Queries, kvSessions []*Session) error { for _, kvSession := range kvSessions { err := migrateSingleSessionToSQL(ctx, tx, kvSession) @@ -127,18 +132,9 @@ func migrateSessionsToSQLAndValidate(ctx context.Context, kvSession.ID, err) } - // Validate that the session was correctly migrated and matches - // the original session in the kv store. - sqlSess, err := tx.GetSessionByAlias(ctx, kvSession.ID[:]) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - err = ErrSessionNotFound - } - return fmt.Errorf("unable to get migrated session "+ - "from sql store: %w", err) - } - - migratedSession, err := unmarshalSession(ctx, tx, sqlSess) + migratedSession, err := getAndUnmarshalSession( + ctx, tx, kvSession.ID[:], + ) if err != nil { return fmt.Errorf("unable to unmarshal migrated "+ "session: %w", err) @@ -172,12 +168,206 @@ func migrateSessionsToSQLAndValidate(ctx context.Context, return nil } +func getAndUnmarshalSession(ctx context.Context, + tx SQLMig6Queries, legacyID []byte) (*Session, error) { + + // Validate that the session was correctly migrated and matches + // the original session in the kv store. 
+ sqlSess, err := tx.GetSessionByAlias(ctx, legacyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + err = ErrSessionNotFound + } + + return nil, fmt.Errorf("unable to get migrated session "+ + "from sql store: %w", err) + } + + migratedSession, err := unmarshalMig6Session(ctx, tx, sqlSess) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal migrated "+ + "session: %w", err) + } + + return migratedSession, nil + +} + +func unmarshalMig6Session(ctx context.Context, db SQLMig6Queries, + dbSess sqlcmig6.Session) (*Session, error) { + + var legacyGroupID ID + if dbSess.GroupID.Valid { + groupID, err := db.GetAliasBySessionID( + ctx, dbSess.GroupID.Int64, + ) + if err != nil { + return nil, fmt.Errorf("unable to get legacy group "+ + "Alias: %v", err) + } + + legacyGroupID, err = IDFromBytes(groupID) + if err != nil { + return nil, fmt.Errorf("unable to get legacy Alias: %v", + err) + } + } + + var acctAlias fn.Option[accounts.AccountID] + if dbSess.AccountID.Valid { + account, err := db.GetAccount(ctx, dbSess.AccountID.Int64) + if err != nil { + return nil, fmt.Errorf("unable to get account: %v", err) + } + + accountAlias, err := accounts.AccountIDFromInt64(account.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get account ID: %v", err) + } + acctAlias = fn.Some(accountAlias) + } + + legacyID, err := IDFromBytes(dbSess.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get legacy Alias: %v", err) + } + + var revokedAt time.Time + if dbSess.RevokedAt.Valid { + revokedAt = dbSess.RevokedAt.Time + } + + localPriv, localPub := btcec.PrivKeyFromBytes(dbSess.LocalPrivateKey) + + var remotePub *btcec.PublicKey + if len(dbSess.RemotePublicKey) != 0 { + remotePub, err = btcec.ParsePubKey(dbSess.RemotePublicKey) + if err != nil { + return nil, fmt.Errorf("unable to parse remote "+ + "public key: %v", err) + } + } + + // Get the macaroon permissions if they exist. + perms, err := db.GetSessionMacaroonPermissions(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get macaroon "+ + "permissions: %v", err) + } + + // Get the macaroon caveats if they exist. + caveats, err := db.GetSessionMacaroonCaveats(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get macaroon "+ + "caveats: %v", err) + } + + var macRecipe *MacaroonRecipe + if perms != nil || caveats != nil { + macRecipe = &MacaroonRecipe{ + Permissions: unmarshalMig6MacPerms(perms), + Caveats: unmarshalMig6MacCaveats(caveats), + } + } + + // Get the feature configs if they exist. + featureConfigs, err := db.GetSessionFeatureConfigs(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get feature configs: %v", err) + } + + var featureCfgs *FeaturesConfig + if featureConfigs != nil { + featureCfgs = unmarshalMig6FeatureConfigs(featureConfigs) + } + + // Get the privacy flags if they exist. 
+ privacyFlags, err := db.GetSessionPrivacyFlags(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get privacy flags: %v", err) + } + + var privFlags PrivacyFlags + if privacyFlags != nil { + privFlags = unmarshalMig6PrivacyFlags(privacyFlags) + } + + var pairingSecret [mailbox.NumPassphraseEntropyBytes]byte + copy(pairingSecret[:], dbSess.PairingSecret) + + return &Session{ + ID: legacyID, + Label: dbSess.Label, + State: State(dbSess.State), + Type: Type(dbSess.Type), + Expiry: dbSess.Expiry, + CreatedAt: dbSess.CreatedAt, + RevokedAt: revokedAt, + ServerAddr: dbSess.ServerAddress, + DevServer: dbSess.DevServer, + MacaroonRootKey: uint64(dbSess.MacaroonRootKey), + PairingSecret: pairingSecret, + LocalPrivateKey: localPriv, + LocalPublicKey: localPub, + RemotePublicKey: remotePub, + WithPrivacyMapper: dbSess.Privacy, + GroupID: legacyGroupID, + PrivacyFlags: privFlags, + MacaroonRecipe: macRecipe, + FeatureConfig: featureCfgs, + AccountID: acctAlias, + }, nil +} + +func unmarshalMig6MacPerms(dbPerms []sqlcmig6.SessionMacaroonPermission) []bakery.Op { + ops := make([]bakery.Op, len(dbPerms)) + for i, dbPerm := range dbPerms { + ops[i] = bakery.Op{ + Entity: dbPerm.Entity, + Action: dbPerm.Action, + } + } + + return ops +} + +func unmarshalMig6MacCaveats(dbCaveats []sqlcmig6.SessionMacaroonCaveat) []macaroon.Caveat { + caveats := make([]macaroon.Caveat, len(dbCaveats)) + for i, dbCaveat := range dbCaveats { + caveats[i] = macaroon.Caveat{ + Id: dbCaveat.CaveatID, + VerificationId: dbCaveat.VerificationID, + Location: dbCaveat.Location.String, + } + } + + return caveats +} + +func unmarshalMig6FeatureConfigs(dbConfigs []sqlcmig6.SessionFeatureConfig) *FeaturesConfig { + configs := make(FeaturesConfig, len(dbConfigs)) + for _, dbConfig := range dbConfigs { + configs[dbConfig.FeatureName] = dbConfig.Config + } + + return &configs +} + +func unmarshalMig6PrivacyFlags(dbFlags []sqlcmig6.SessionPrivacyFlag) PrivacyFlags { + flags := make(PrivacyFlags, len(dbFlags)) + for i, dbFlag := range dbFlags { + flags[i] = PrivacyFlag(dbFlag.Flag) + } + + return flags +} + // migrateSingleSessionToSQL runs the migration for a single session from the // KV database to the SQL database. Note that if the session links to an // account, the linked accounts store MUST have been migrated before that // session is migrated. func migrateSingleSessionToSQL(ctx context.Context, - tx SQLQueries, session *Session) error { + tx SQLMig6Queries, session *Session) error { var ( acctID sql.NullInt64 @@ -213,7 +403,7 @@ func migrateSingleSessionToSQL(ctx context.Context, } // Proceed to insert the session into the sql db. - sqlId, err := tx.InsertSession(ctx, sqlc.InsertSessionParams{ + sqlId, err := tx.InsertSession(ctx, sqlcmig6.InsertSessionParams{ Alias: session.ID[:], Label: session.Label, State: int16(session.State), @@ -239,7 +429,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // has been created. if !session.RevokedAt.IsZero() { err = tx.SetSessionRevokedAt( - ctx, sqlc.SetSessionRevokedAtParams{ + ctx, sqlcmig6.SetSessionRevokedAtParams{ ID: sqlId, RevokedAt: sqldb.SQLTime( session.RevokedAt.UTC(), @@ -265,7 +455,7 @@ func migrateSingleSessionToSQL(ctx context.Context, } // Now lets set the group ID for the session. 
- err = tx.SetSessionGroupID(ctx, sqlc.SetSessionGroupIDParams{ + err = tx.SetSessionGroupID(ctx, sqlcmig6.SetSessionGroupIDParams{ ID: sqlId, GroupID: sqldb.SQLInt64(groupID), }) @@ -279,7 +469,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // We start by inserting the macaroon permissions. for _, sessionPerm := range session.MacaroonRecipe.Permissions { err = tx.InsertSessionMacaroonPermission( - ctx, sqlc.InsertSessionMacaroonPermissionParams{ + ctx, sqlcmig6.InsertSessionMacaroonPermissionParams{ SessionID: sqlId, Entity: sessionPerm.Entity, Action: sessionPerm.Action, @@ -293,7 +483,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // Next we insert the macaroon caveats. for _, caveat := range session.MacaroonRecipe.Caveats { err = tx.InsertSessionMacaroonCaveat( - ctx, sqlc.InsertSessionMacaroonCaveatParams{ + ctx, sqlcmig6.InsertSessionMacaroonCaveatParams{ SessionID: sqlId, CaveatID: caveat.Id, VerificationID: caveat.VerificationId, @@ -312,7 +502,7 @@ func migrateSingleSessionToSQL(ctx context.Context, if session.FeatureConfig != nil { for featureName, config := range *session.FeatureConfig { err = tx.InsertSessionFeatureConfig( - ctx, sqlc.InsertSessionFeatureConfigParams{ + ctx, sqlcmig6.InsertSessionFeatureConfigParams{ SessionID: sqlId, FeatureName: featureName, Config: config, @@ -327,7 +517,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // Finally we insert the privacy flags. for _, privacyFlag := range session.PrivacyFlags { err = tx.InsertSessionPrivacyFlag( - ctx, sqlc.InsertSessionPrivacyFlagParams{ + ctx, sqlcmig6.InsertSessionPrivacyFlagParams{ SessionID: sqlId, Flag: int32(privacyFlag), }, diff --git a/session/sql_migration_test.go b/session/sql_migration_test.go index 20cd1092b..9ac75f30b 100644 --- a/session/sql_migration_test.go +++ b/session/sql_migration_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" @@ -37,7 +37,7 @@ func TestSessionsStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, acctStore accounts.Store) (*SQLStore, - *SQLQueriesExecutor[SQLQueries]) { + *SQLMig6QueriesExecutor[SQLMig6Queries]) { // Create a sql store with a linked account store. 
testDBStore := NewTestDBWithAccounts(t, clock, acctStore) @@ -47,9 +47,9 @@ func TestSessionsStoreMigration(t *testing.T) { baseDB := store.BaseDB - queries := sqlc.NewForType(baseDB, baseDB.BackendType) + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) - return store, NewSQLQueriesExecutor(baseDB, queries) + return store, NewSQLMig6QueriesExecutor(baseDB, queries) } // assertMigrationResults asserts that the sql store contains the @@ -366,7 +366,7 @@ func TestSessionsStoreMigration(t *testing.T) { var opts sqldb.MigrationTxOptions err = txEx.ExecTx( - ctx, &opts, func(tx SQLQueries) error { + ctx, &opts, func(tx SQLMig6Queries) error { return MigrateSessionStoreToSQL( ctx, kvStore.DB, tx, ) From ba82ac009ea20a8db0e33ccaaaa4a05ba40ce991 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 02:02:51 +0200 Subject: [PATCH 19/25] firewalldb: add SQLMig6Queries to `firewalldb` --- firewalldb/actions_sql.go | 22 ++++++++++++++++++++++ firewalldb/kvstores_sql.go | 27 +++++++++++++++++++++++++++ firewalldb/privacy_mapper_sql.go | 14 ++++++++++++++ firewalldb/sql_store.go | 31 +++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+) diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go index 4d5448313..eb0294af7 100644 --- a/firewalldb/actions_sql.go +++ b/firewalldb/actions_sql.go @@ -10,6 +10,7 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/sqldb/v2" @@ -22,6 +23,13 @@ type SQLAccountQueries interface { GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) } +// SQLMig6AccountQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with the accounts table. +type SQLMig6AccountQueries interface { + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) +} + // SQLActionQueries is a subset of the sqlc.Queries interface that can be used // to interact with action related tables. // @@ -36,6 +44,20 @@ type SQLActionQueries interface { CountActions(ctx context.Context, arg sqlc.ActionQueryParams) (int64, error) } +// SQLMig6ActionQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with action related tables. +// +//nolint:lll +type SQLMig6ActionQueries interface { + SQLSessionQueries + SQLMig6AccountQueries + + InsertAction(ctx context.Context, arg sqlcmig6.InsertActionParams) (int64, error) + SetActionState(ctx context.Context, arg sqlcmig6.SetActionStateParams) error + ListActions(ctx context.Context, arg sqlcmig6.ListActionsParams) ([]sqlcmig6.Action, error) + CountActions(ctx context.Context, arg sqlcmig6.ActionQueryParams) (int64, error) +} + // sqlActionLocator helps us find an action in the SQL DB. type sqlActionLocator struct { // id is the DB level ID of the action. 
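
Aside: each of these narrowed SQLMig6 interfaces is consumed through a
transaction executor, and the later patches in this series run the whole
kvdb-to-SQL migration inside a single ExecTx closure so that it either commits
or rolls back as one unit. The toy model below imitates that commit-or-discard
behavior without a real database; it is a sketch, not litd code, and every
identifier in it is hypothetical:

package main

import (
	"context"
	"errors"
	"fmt"
)

// queries is a hypothetical stand-in for a transaction-scoped query handle.
type queries struct {
	store map[string]string
}

func (q *queries) insertKVStoreRecord(key, val string) {
	q.store[key] = val
}

// execTx mimics the shape of the ExecTx calls in the migration tests: the
// callback receives a tx-scoped handle, and all of its writes are applied
// atomically or not at all. Commit/rollback is simulated by working on a
// copy of the map.
func execTx(_ context.Context, db *queries, f func(tx *queries) error) error {
	tmp := &queries{store: make(map[string]string, len(db.store))}
	for k, v := range db.store {
		tmp.store[k] = v
	}

	if err := f(tmp); err != nil {
		// "Rollback": tmp is discarded, db is untouched.
		return err
	}

	// "Commit": all writes become visible at once.
	db.store = tmp.store

	return nil
}

func main() {
	ctx := context.Background()
	db := &queries{store: make(map[string]string)}

	// A migration that fails validation leaves the target store empty.
	err := execTx(ctx, db, func(tx *queries) error {
		tx.insertKVStoreRecord("rule/global/key", "value")
		return errors.New("migrated value mismatch")
	})
	fmt.Println(err != nil, len(db.store)) // true 0

	// A successful run applies every record as one unit.
	err = execTx(ctx, db, func(tx *queries) error {
		tx.insertKVStoreRecord("rule/global/key", "value")
		return nil
	})
	fmt.Println(err == nil, len(db.store)) // true 1
}
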
diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go
index 0c1847706..f815742ac 100644
--- a/firewalldb/kvstores_sql.go
+++ b/firewalldb/kvstores_sql.go
@@ -9,6 +9,7 @@ import (
 
 	"github.com/lightninglabs/lightning-terminal/db"
 	"github.com/lightninglabs/lightning-terminal/db/sqlc"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
 	"github.com/lightninglabs/lightning-terminal/session"
 	"github.com/lightningnetwork/lnd/fn"
 	"github.com/lightningnetwork/lnd/sqldb/v2"
@@ -39,6 +40,32 @@ type SQLKVStoreQueries interface {
 	GetRuleID(ctx context.Context, name string) (int64, error)
 }
 
+// SQLMig6KVStoreQueries is a subset of the sqlcmig6.Queries
+// interface that can be used to interact with the kvstore
+// tables.
+//
+//nolint:lll
+type SQLMig6KVStoreQueries interface {
+	SQLSessionQueries
+
+	DeleteFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteFeatureKVStoreRecordParams) error
+	DeleteGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteGlobalKVStoreRecordParams) error
+	DeleteGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteGroupKVStoreRecordParams) error
+	GetFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.GetFeatureKVStoreRecordParams) ([]byte, error)
+	GetGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.GetGlobalKVStoreRecordParams) ([]byte, error)
+	GetGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.GetGroupKVStoreRecordParams) ([]byte, error)
+	UpdateFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateFeatureKVStoreRecordParams) error
+	UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateGlobalKVStoreRecordParams) error
+	UpdateGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateGroupKVStoreRecordParams) error
+	InsertKVStoreRecord(ctx context.Context, arg sqlcmig6.InsertKVStoreRecordParams) error
+	ListAllKVStoresRecords(ctx context.Context) ([]sqlcmig6.Kvstore, error)
+	DeleteAllTempKVStores(ctx context.Context) error
+	GetOrInsertFeatureID(ctx context.Context, name string) (int64, error)
+	GetOrInsertRuleID(ctx context.Context, name string) (int64, error)
+	GetFeatureID(ctx context.Context, name string) (int64, error)
+	GetRuleID(ctx context.Context, name string) (int64, error)
+}
+
 // DeleteTempKVStores deletes all temporary kv stores.
 //
 // NOTE: part of the RulesDB interface.
diff --git a/firewalldb/privacy_mapper_sql.go b/firewalldb/privacy_mapper_sql.go
index 8a4863a6c..ff7f70f94 100644
--- a/firewalldb/privacy_mapper_sql.go
+++ b/firewalldb/privacy_mapper_sql.go
@@ -6,6 +6,7 @@ import (
 	"errors"
 
 	"github.com/lightninglabs/lightning-terminal/db/sqlc"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
 	"github.com/lightninglabs/lightning-terminal/session"
 )
 
@@ -22,6 +23,19 @@ type SQLPrivacyPairQueries interface {
 	GetRealForPseudo(ctx context.Context, arg sqlc.GetRealForPseudoParams) (string, error)
 }
 
+// SQLMig6PrivacyPairQueries is a subset of the sqlcmig6.Queries interface that
+// can be used to interact with the privacy map table.
+// +//nolint:lll +type SQLMig6PrivacyPairQueries interface { + SQLSessionQueries + + InsertPrivacyPair(ctx context.Context, arg sqlcmig6.InsertPrivacyPairParams) error + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]sqlcmig6.GetAllPrivacyPairsRow, error) + GetPseudoForReal(ctx context.Context, arg sqlcmig6.GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg sqlcmig6.GetRealForPseudoParams) (string, error) +} + // PrivacyDB constructs a PrivacyMapDB that will be indexed under the given // group ID key. // diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index 1be887ace..dcf201f1c 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -6,6 +6,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -27,6 +28,16 @@ type SQLQueries interface { SQLActionQueries } +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used +// to interact with various firewalldb tables. +type SQLMig6Queries interface { + sqldb.BaseQuerier + + SQLMig6KVStoreQueries + SQLMig6PrivacyPairQueries + SQLMig6ActionQueries +} + // BatchedSQLQueries combines the SQLQueries interface with the BatchedTx // interface, allowing for multiple queries to be executed in single SQL // transaction. @@ -68,6 +79,26 @@ func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, } } +type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLMig6Queries +} + +func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLMig6Queries { + return queries.WithTx(tx) + }, + ) + return &SQLMig6QueriesExecutor[SQLMig6Queries]{ + TransactionExecutor: executor, + SQLMig6Queries: queries, + } +} + // A compile-time assertion to ensure that SQLDB implements the RulesDB // interface. var _ RulesDB = (*SQLDB)(nil) From cd19ab9af82a28ec87e4253ea7a8eaf4d9f57eb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 11:17:41 +0200 Subject: [PATCH 20/25] firewalldb: use queries to assert migration results --- firewalldb/sql_migration_test.go | 39 +++++++++++++++++--------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go index 2dd3604e7..ad76f6e7b 100644 --- a/firewalldb/sql_migration_test.go +++ b/firewalldb/sql_migration_test.go @@ -52,8 +52,8 @@ func TestFirewallDBMigration(t *testing.T) { t.Skipf("Skipping Firewall DB migration test for kvdb build") } - makeSQLDB := func(t *testing.T, sessionsStore session.Store) (*SQLDB, - *SQLQueriesExecutor[SQLQueries]) { + makeSQLDB := func(t *testing.T, + sessionsStore session.Store) *SQLQueriesExecutor[SQLQueries] { testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) @@ -64,13 +64,13 @@ func TestFirewallDBMigration(t *testing.T) { queries := sqlc.NewForType(baseDB, baseDB.BackendType) - return store, NewSQLQueriesExecutor(baseDB, queries) + return NewSQLQueriesExecutor(baseDB, queries) } // The assertMigrationResults function will currently assert that // the migrated kv stores entries in the SQLDB match the original kv // stores entries in the BoltDB. 
- assertMigrationResults := func(t *testing.T, store *SQLDB, + assertMigrationResults := func(t *testing.T, store SQLQueries, kvEntries []*kvEntry) { var ( @@ -83,9 +83,7 @@ func TestFirewallDBMigration(t *testing.T) { getRuleID := func(ruleName string) int64 { ruleID, ok := ruleIDs[ruleName] if !ok { - ruleID, err = store.db.GetRuleID( - ctx, ruleName, - ) + ruleID, err = store.GetRuleID(ctx, ruleName) require.NoError(t, err) ruleIDs[ruleName] = ruleID @@ -97,7 +95,7 @@ func TestFirewallDBMigration(t *testing.T) { getGroupID := func(groupAlias []byte) int64 { groupID, ok := groupIDs[string(groupAlias)] if !ok { - groupID, err = store.db.GetSessionIDByAlias( + groupID, err = store.GetSessionIDByAlias( ctx, groupAlias, ) require.NoError(t, err) @@ -111,7 +109,7 @@ func TestFirewallDBMigration(t *testing.T) { getFeatureID := func(featureName string) int64 { featureID, ok := featureIDs[featureName] if !ok { - featureID, err = store.db.GetFeatureID( + featureID, err = store.GetFeatureID( ctx, featureName, ) require.NoError(t, err) @@ -125,7 +123,7 @@ func TestFirewallDBMigration(t *testing.T) { // First we extract all migrated kv entries from the SQLDB, // in order to be able to compare them to the original kv // entries, to ensure that the migration was successful. - sqlKvEntries, err := store.db.ListAllKVStoresRecords(ctx) + sqlKvEntries, err := store.ListAllKVStoresRecords(ctx) require.NoError(t, err) require.Equal(t, len(kvEntries), len(sqlKvEntries)) @@ -141,7 +139,7 @@ func TestFirewallDBMigration(t *testing.T) { ruleID := getRuleID(entry.ruleName) if entry.groupAlias.IsNone() { - sqlVal, err := store.db.GetGlobalKVStoreRecord( + sqlVal, err := store.GetGlobalKVStoreRecord( ctx, sqlc.GetGlobalKVStoreRecordParams{ Key: entry.key, @@ -159,7 +157,7 @@ func TestFirewallDBMigration(t *testing.T) { groupAlias := entry.groupAlias.UnwrapOrFail(t) groupID := getGroupID(groupAlias[:]) - v, err := store.db.GetGroupKVStoreRecord( + v, err := store.GetGroupKVStoreRecord( ctx, sqlc.GetGroupKVStoreRecordParams{ Key: entry.key, @@ -184,7 +182,7 @@ func TestFirewallDBMigration(t *testing.T) { entry.featureName.UnwrapOrFail(t), ) - sqlVal, err := store.db.GetFeatureKVStoreRecord( + sqlVal, err := store.GetFeatureKVStoreRecord( ctx, sqlc.GetFeatureKVStoreRecordParams{ Key: entry.key, @@ -290,21 +288,26 @@ func TestFirewallDBMigration(t *testing.T) { // Create the SQL store that we will migrate the data // to. - sqlStore, txEx := makeSQLDB(t, sessionsStore) + txEx := makeSQLDB(t, sessionsStore) // Perform the migration. var opts sqldb.MigrationTxOptions err = txEx.ExecTx(ctx, &opts, func(tx SQLQueries) error { - return MigrateFirewallDBToSQL( + err = MigrateFirewallDBToSQL( ctx, firewallStore.DB, tx, ) + if err != nil { + return err + } + + // Assert migration results. + assertMigrationResults(t, tx, entries) + + return nil }, sqldb.NoOpReset, ) require.NoError(t, err) - - // Assert migration results. 
- assertMigrationResults(t, sqlStore, entries) }) } } From b70d7bbf3135db59e901d48d77b4fe6707f5b5f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 02:16:16 +0200 Subject: [PATCH 21/25] firewalldb: use `sqlcmig6` for kvdb to sql migration --- firewalldb/sql_migration.go | 18 +++++++++--------- firewalldb/sql_migration_test.go | 20 ++++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go index 1e114c12c..f898993e3 100644 --- a/firewalldb/sql_migration.go +++ b/firewalldb/sql_migration.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/sqldb" "go.etcd.io/bbolt" @@ -78,7 +78,7 @@ func (e *kvEntry) namespacedKey() string { // NOTE: As sessions may contain linked sessions and accounts, the sessions and // accounts sql migration MUST be run prior to this migration. func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, - sqlTx SQLQueries) error { + sqlTx SQLMig6Queries) error { log.Infof("Starting migration of the rules DB to SQL") @@ -99,7 +99,7 @@ func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, // database to the SQL database. The function also asserts that the // migrated values match the original values in the KV store. func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB, - sqlTx SQLQueries) error { + sqlTx SQLMig6Queries) error { log.Infof("Starting migration of the KV stores to SQL") @@ -361,7 +361,7 @@ func collectKVPairs(bkt *bbolt.Bucket, errorOnBuckets, perm bool, } // insertPair inserts a single key-value pair into the SQL database. -func insertPair(ctx context.Context, tx SQLQueries, +func insertPair(ctx context.Context, tx SQLMig6Queries, entry *kvEntry) (*sqlKvEntry, error) { ruleID, err := tx.GetOrInsertRuleID(ctx, entry.ruleName) @@ -369,7 +369,7 @@ func insertPair(ctx context.Context, tx SQLQueries, return nil, err } - p := sqlc.InsertKVStoreRecordParams{ + p := sqlcmig6.InsertKVStoreRecordParams{ Perm: entry.perm, RuleID: ruleID, EntryKey: entry.key, @@ -421,13 +421,13 @@ func insertPair(ctx context.Context, tx SQLQueries, // getSQLValue retrieves the key value for the given kvEntry from the SQL // database. 
-func getSQLValue(ctx context.Context, tx SQLQueries, +func getSQLValue(ctx context.Context, tx SQLMig6Queries, entry *sqlKvEntry) ([]byte, error) { switch { case entry.featureID.Valid && entry.groupID.Valid: return tx.GetFeatureKVStoreRecord( - ctx, sqlc.GetFeatureKVStoreRecordParams{ + ctx, sqlcmig6.GetFeatureKVStoreRecordParams{ Perm: entry.perm, RuleID: entry.ruleID, GroupID: entry.groupID, @@ -437,7 +437,7 @@ func getSQLValue(ctx context.Context, tx SQLQueries, ) case entry.groupID.Valid: return tx.GetGroupKVStoreRecord( - ctx, sqlc.GetGroupKVStoreRecordParams{ + ctx, sqlcmig6.GetGroupKVStoreRecordParams{ Perm: entry.perm, RuleID: entry.ruleID, GroupID: entry.groupID, @@ -446,7 +446,7 @@ func getSQLValue(ctx context.Context, tx SQLQueries, ) case !entry.featureID.Valid && !entry.groupID.Valid: return tx.GetGlobalKVStoreRecord( - ctx, sqlc.GetGlobalKVStoreRecordParams{ + ctx, sqlcmig6.GetGlobalKVStoreRecordParams{ Perm: entry.perm, RuleID: entry.ruleID, Key: entry.key, diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go index ad76f6e7b..ab8ea01da 100644 --- a/firewalldb/sql_migration_test.go +++ b/firewalldb/sql_migration_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" @@ -53,24 +53,24 @@ func TestFirewallDBMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, - sessionsStore session.Store) *SQLQueriesExecutor[SQLQueries] { + sStore session.Store) *SQLMig6QueriesExecutor[SQLMig6Queries] { - testDBStore := NewTestDBWithSessions(t, sessionsStore, clock) + testDBStore := NewTestDBWithSessions(t, sStore, clock) store, ok := testDBStore.(*SQLDB) require.True(t, ok) baseDB := store.BaseDB - queries := sqlc.NewForType(baseDB, baseDB.BackendType) + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) - return NewSQLQueriesExecutor(baseDB, queries) + return NewSQLMig6QueriesExecutor(baseDB, queries) } // The assertMigrationResults function will currently assert that // the migrated kv stores entries in the SQLDB match the original kv // stores entries in the BoltDB. - assertMigrationResults := func(t *testing.T, store SQLQueries, + assertMigrationResults := func(t *testing.T, store SQLMig6Queries, kvEntries []*kvEntry) { var ( @@ -141,7 +141,7 @@ func TestFirewallDBMigration(t *testing.T) { if entry.groupAlias.IsNone() { sqlVal, err := store.GetGlobalKVStoreRecord( ctx, - sqlc.GetGlobalKVStoreRecordParams{ + sqlcmig6.GetGlobalKVStoreRecordParams{ Key: entry.key, Perm: entry.perm, RuleID: ruleID, @@ -159,7 +159,7 @@ func TestFirewallDBMigration(t *testing.T) { v, err := store.GetGroupKVStoreRecord( ctx, - sqlc.GetGroupKVStoreRecordParams{ + sqlcmig6.GetGroupKVStoreRecordParams{ Key: entry.key, Perm: entry.perm, RuleID: ruleID, @@ -184,7 +184,7 @@ func TestFirewallDBMigration(t *testing.T) { sqlVal, err := store.GetFeatureKVStoreRecord( ctx, - sqlc.GetFeatureKVStoreRecordParams{ + sqlcmig6.GetFeatureKVStoreRecordParams{ Key: entry.key, Perm: entry.perm, RuleID: ruleID, @@ -293,7 +293,7 @@ func TestFirewallDBMigration(t *testing.T) { // Perform the migration. 
var opts sqldb.MigrationTxOptions err = txEx.ExecTx(ctx, &opts, - func(tx SQLQueries) error { + func(tx SQLMig6Queries) error { err = MigrateFirewallDBToSQL( ctx, firewallStore.DB, tx, ) From 71ddae3528479ef7c11fa6928228e81d24e389f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 20:05:45 +0200 Subject: [PATCH 22/25] migrationstreams: introduce `migrationstreams` pkg --- accounts/test_sqlite.go | 9 +++++--- config_dev.go | 5 +++-- db/migrations.go | 2 +- db/migrationstreams/log.go | 25 +++++++++++++++++++++ db/{ => migrationstreams}/sql_migrations.go | 7 +++--- db/postgres.go | 7 ++++-- db/schemas.go | 2 +- db/sqlite.go | 2 +- firewalldb/test_sqlite.go | 9 +++++--- log.go | 5 +++++ session/test_sqlite.go | 9 +++++--- 11 files changed, 63 insertions(+), 19 deletions(-) create mode 100644 db/migrationstreams/log.go rename db/{ => migrationstreams}/sql_migrations.go (82%) diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index a31f990a6..e2d281e4c 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -6,7 +6,7 @@ import ( "errors" "testing" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -18,7 +18,8 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { return createStore( - t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + t, + sqldb.NewTestSqliteDB(t, migrationstreams.LitdMigrationStreams).BaseDB, clock, ) } @@ -28,7 +29,9 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + tDb := sqldb.NewTestSqliteDBFromPath( + t, dbPath, migrationstreams.LitdMigrationStreams, + ) return createStore(t, tDb.BaseDB, clock) } diff --git a/config_dev.go b/config_dev.go index 82bd85cf0..d40acaa2e 100644 --- a/config_dev.go +++ b/config_dev.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/firewalldb" "github.com/lightninglabs/lightning-terminal/session" @@ -113,7 +114,7 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if !cfg.Sqlite.SkipMigrations { err = sqldb.ApplyAllMigrations( - sqlStore, db.LitdMigrationStreams, + sqlStore, migrationstreams.LitdMigrationStreams, ) if err != nil { return stores, fmt.Errorf("error applying "+ @@ -155,7 +156,7 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if !cfg.Postgres.SkipMigrations { err = sqldb.ApplyAllMigrations( - sqlStore, db.LitdMigrationStreams, + sqlStore, migrationstreams.LitdMigrationStreams, ) if err != nil { return stores, fmt.Errorf("error applying "+ diff --git a/db/migrations.go b/db/migrations.go index 79d63587e..55825e12f 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -120,7 +120,7 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, targetVersion MigrationTarget, opts *migrateOptions) error { // With the migrate instance open, we'll create a new migration source - // using the embedded file system 
stored in sqlSchemas. The library + // using the embedded file system stored in SqlSchemas. The library // we're using can't handle a raw file system interface, so we wrap it // in this intermediate layer. migrateFileServer, err := httpfs.New(http.FS(fs), path) diff --git a/db/migrationstreams/log.go b/db/migrationstreams/log.go new file mode 100644 index 000000000..0a3e80731 --- /dev/null +++ b/db/migrationstreams/log.go @@ -0,0 +1,25 @@ +package migrationstreams + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/build" +) + +const Subsystem = "MIGS" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/db/sql_migrations.go b/db/migrationstreams/sql_migrations.go similarity index 82% rename from db/sql_migrations.go rename to db/migrationstreams/sql_migrations.go index 57d283aa3..ce9602698 100644 --- a/db/sql_migrations.go +++ b/db/migrationstreams/sql_migrations.go @@ -1,8 +1,9 @@ -package db +package migrationstreams import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -10,14 +11,14 @@ var ( LitdMigrationStream = sqldb.MigrationStream{ MigrateTableName: pgx.DefaultMigrationsTable, SQLFileDirectory: "sqlc/migrations", - Schemas: sqlSchemas, + Schemas: db.SqlSchemas, // LatestMigrationVersion is the latest migration version of the // database. This is used to implement downgrade protection for // the daemon. // // NOTE: This MUST be updated when a new migration is added. 
- LatestMigrationVersion: LatestMigrationVersion, + LatestMigrationVersion: db.LatestMigrationVersion, MakePostMigrationChecks: func( db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, diff --git a/db/postgres.go b/db/postgres.go index 962629be6..546f6311d 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -8,6 +8,7 @@ import ( postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" @@ -149,7 +150,7 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, return fmt.Errorf("error creating postgres migration: %w", err) } - postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) + postgresFS := newReplacerFS(SqlSchemas, postgresSchemaReplacements) return applyMigrations( postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target, opts, @@ -168,7 +169,9 @@ func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { sqlFixture.TearDown(t) }) - return sqldb.NewTestPostgresDB(t, sqlFixture, LitdMigrationStreams) + return sqldb.NewTestPostgresDB( + t, sqlFixture, migrationstreams.LitdMigrationStreams, + ) } // NewTestPostgresDBWithVersion is a helper function that creates a Postgres diff --git a/db/schemas.go b/db/schemas.go index 1a7a2096f..dce7fa84c 100644 --- a/db/schemas.go +++ b/db/schemas.go @@ -6,4 +6,4 @@ import ( ) //go:embed sqlc/migrations/*.*.sql -var sqlSchemas embed.FS +var SqlSchemas embed.FS diff --git a/db/sqlite.go b/db/sqlite.go index 6f69a7e5b..a65271548 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -235,7 +235,7 @@ func (s *SqliteStore) ExecuteMigrations(target MigrationTarget, return fmt.Errorf("error creating sqlite migration: %w", err) } - sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) + sqliteFS := newReplacerFS(SqlSchemas, sqliteSchemaReplacements) return applyMigrations( sqliteFS, driver, "sqlc/migrations", "sqlite", target, opts, ) diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index ab184b5a6..844833a1e 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -5,7 +5,7 @@ package firewalldb import ( "testing" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -13,7 +13,8 @@ import ( // NewTestDB is a helper function that creates an BBolt database for testing. 
func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return createStore( - t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB, + t, + sqldb.NewTestSqliteDB(t, migrationstreams.LitdMigrationStreams).BaseDB, clock, ) } @@ -23,7 +24,9 @@ func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { - tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams) + tDb := sqldb.NewTestSqliteDBFromPath( + t, dbPath, migrationstreams.LitdMigrationStreams, + ) return createStore(t, tDb.BaseDB, clock) } diff --git a/log.go b/log.go index b803f72d2..0535a819a 100644 --- a/log.go +++ b/log.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/autopilotserver" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightninglabs/lightning-terminal/firewall" "github.com/lightninglabs/lightning-terminal/firewalldb" mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware" @@ -91,6 +92,10 @@ func SetupLoggers(root *build.SubLoggerManager, intercept signal.Interceptor) { root, subservers.Subsystem, intercept, subservers.UseLogger, ) lnd.AddSubLogger(root, db.Subsystem, intercept, db.UseLogger) + lnd.AddSubLogger( + root, migrationstreams.Subsystem, intercept, + migrationstreams.UseLogger, + ) // Add daemon loggers to lnd's root logger. faraday.SetupLoggers(root, intercept) diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 84d946ce2..1c3702444 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -6,7 +6,7 @@ import ( "errors" "testing" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -18,7 +18,8 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. 
 func NewTestDB(t *testing.T, clock clock.Clock) Store {
 	return createStore(
-		t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB,
+		t,
+		sqldb.NewTestSqliteDB(t, migrationstreams.LitdMigrationStreams).BaseDB,
 		clock,
 	)
 }
@@ -28,7 +29,9 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store {
 func NewTestDBFromPath(t *testing.T, dbPath string,
 	clock clock.Clock) Store {
 
-	tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams)
+	tDb := sqldb.NewTestSqliteDBFromPath(
+		t, dbPath, migrationstreams.LitdMigrationStreams,
+	)
 
 	return createStore(t, tDb.BaseDB, clock)
 }

From 1d9100ec9b677ebe2274dbfd33f6b067be8c70d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3%B6m?=
Date: Fri, 11 Jul 2025 21:38:29 +0200
Subject: [PATCH 23/25] multi: introduce dev migrations

---
 db/migrations.go                           |  9 +++
 db/migrationstreams/sql_migrations.go      |  2 +
 db/migrationstreams/sql_migrations_dev.go  | 55 +++++++++++++++++++
 db/schemas.go                              |  2 +-
 .../000001_dev_test_migration.down.sql     |  1 +
 .../000001_dev_test_migration.up.sql       |  1 +
 scripts/gen_sqlc_docker.sh                 |  4 +-
 7 files changed, 71 insertions(+), 3 deletions(-)
 create mode 100644 db/migrationstreams/sql_migrations_dev.go
 create mode 100644 db/sqlc/migrations_dev/000001_dev_test_migration.down.sql
 create mode 100644 db/sqlc/migrations_dev/000001_dev_test_migration.up.sql

diff --git a/db/migrations.go b/db/migrations.go
index 55825e12f..87811330f 100644
--- a/db/migrations.go
+++ b/db/migrations.go
@@ -23,6 +23,15 @@ const (
 	//
 	// NOTE: This MUST be updated when a new migration is added.
 	LatestMigrationVersion = 5
+
+	// LatestDevMigrationVersion is the latest dev migration version of the
+	// database. This is used to implement downgrade protection for the
+	// daemon. This represents the latest number used in the migrations_dev
+	// directory.
+	//
+	// NOTE: This MUST be updated when a migration is added to or removed
+	// from the migrations_dev directory.
+	LatestDevMigrationVersion = 1
 )
 
 // MigrationTarget is a functional option that can be passed to applyMigrations
diff --git a/db/migrationstreams/sql_migrations.go b/db/migrationstreams/sql_migrations.go
index ce9602698..5bcf8ba76 100644
--- a/db/migrationstreams/sql_migrations.go
+++ b/db/migrationstreams/sql_migrations.go
@@ -1,3 +1,5 @@
+//go:build !dev
+
 package migrationstreams
 
 import (
diff --git a/db/migrationstreams/sql_migrations_dev.go b/db/migrationstreams/sql_migrations_dev.go
new file mode 100644
index 000000000..bc32b39f0
--- /dev/null
+++ b/db/migrationstreams/sql_migrations_dev.go
@@ -0,0 +1,55 @@
+//go:build dev
+
+package migrationstreams
+
+import (
+	"github.com/golang-migrate/migrate/v4"
+	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
+	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
+)
+
+var (
+	// Create the prod migration stream.
+	migStream = sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable,
+		SQLFileDirectory: "sqlc/migrations",
+		Schemas:          db.SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// database. This is used to implement downgrade protection for
+		// the daemon.
+		//
+		// NOTE: This MUST be updated when a new migration is added.
+		LatestMigrationVersion: db.LatestMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+
+	// Create the dev migration stream.
+	migStreamDev = sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable + "_dev",
+		SQLFileDirectory: "sqlc/migrations_dev",
+		Schemas:          db.SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// dev migrations database. This is used to implement downgrade
+		// protection for the daemon.
+		//
+		// NOTE: This MUST be updated when a new dev migration is added.
+		LatestMigrationVersion: db.LatestDevMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+	LitdMigrationStreams = []sqldb.MigrationStream{migStream, migStreamDev}
+)
diff --git a/db/schemas.go b/db/schemas.go
index dce7fa84c..565fb5615 100644
--- a/db/schemas.go
+++ b/db/schemas.go
@@ -5,5 +5,5 @@ import (
 	_ "embed"
 )
 
-//go:embed sqlc/migrations/*.*.sql
+//go:embed sqlc/migration*/*.*.sql
 var SqlSchemas embed.FS
diff --git a/db/sqlc/migrations_dev/000001_dev_test_migration.down.sql b/db/sqlc/migrations_dev/000001_dev_test_migration.down.sql
new file mode 100644
index 000000000..0d246b2dd
--- /dev/null
+++ b/db/sqlc/migrations_dev/000001_dev_test_migration.down.sql
@@ -0,0 +1 @@
+-- Comment to ensure the file is created and picked up in the migration stream.
\ No newline at end of file
diff --git a/db/sqlc/migrations_dev/000001_dev_test_migration.up.sql b/db/sqlc/migrations_dev/000001_dev_test_migration.up.sql
new file mode 100644
index 000000000..0d246b2dd
--- /dev/null
+++ b/db/sqlc/migrations_dev/000001_dev_test_migration.up.sql
@@ -0,0 +1 @@
+-- Comment to ensure the file is created and picked up in the migration stream.
\ No newline at end of file
diff --git a/scripts/gen_sqlc_docker.sh b/scripts/gen_sqlc_docker.sh
index 16db97f2c..3d93f37ff 100755
--- a/scripts/gen_sqlc_docker.sh
+++ b/scripts/gen_sqlc_docker.sh
@@ -5,7 +5,7 @@ set -e
 # restore_files is a function to restore original schema files.
 restore_files() {
   echo "Restoring SQLite bigint patch..."
-  for file in db/sqlc/migrations/*.up.sql.bak; do
+  for file in db/sqlc/{migrations,migrations_dev}/*.up.sql.bak; do
     mv "$file" "${file%.bak}"
   done
 }
@@ -30,7 +30,7 @@ GOMODCACHE=$(go env GOMODCACHE)
 # source schema SQL files to use "BIGINT PRIMARY KEY" instead of "INTEGER
 # PRIMARY KEY".
 echo "Applying SQLite bigint patch..."
-for file in db/sqlc/migrations/*.up.sql; do
+for file in db/sqlc/{migrations,migrations_dev}/*.up.sql; do
   echo "Patching $file"
   sed -i.bak -E 's/INTEGER PRIMARY KEY/BIGINT PRIMARY KEY/g' "$file"
 done

From e231ccf41b86408ef486a58e9d5089d572ef9278 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tigerstr=C3%B6m?=
Date: Mon, 9 Jun 2025 14:39:51 +0200
Subject: [PATCH 24/25] accounts: export kvdb DB

---
 accounts/service.go            |  4 ++--
 accounts/service_test.go       |  2 +-
 accounts/sql_migration_test.go |  6 +++---
 accounts/store_kvdb.go         | 20 ++++++++++----------
 accounts/test_kvdb.go          |  2 +-
 5 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/accounts/service.go b/accounts/service.go
index 102b9ea84..697cf7518 100644
--- a/accounts/service.go
+++ b/accounts/service.go
@@ -361,7 +361,7 @@ func (s *InterceptorService) CreditAccount(ctx context.Context,
 		return nil, ErrAccountServiceDisabled
 	}
 
-	// Credit the account in the db.
+	// Credit the account in the DB.
 	err := s.store.CreditAccount(ctx, accountID, amount)
 	if err != nil {
 		return nil, fmt.Errorf("unable to credit account: %w", err)
@@ -386,7 +386,7 @@ func (s *InterceptorService) DebitAccount(ctx context.Context,
 		return nil, ErrAccountServiceDisabled
 	}
 
-	// Debit the account in the db.
+	// Debit the account in the DB.
 	err := s.store.DebitAccount(ctx, accountID, amount)
 	if err != nil {
 		return nil, fmt.Errorf("unable to debit account: %w", err)
diff --git a/accounts/service_test.go b/accounts/service_test.go
index 1d4388664..f69d0fe93 100644
--- a/accounts/service_test.go
+++ b/accounts/service_test.go
@@ -246,7 +246,7 @@ func TestAccountService(t *testing.T) {
 
 			// Ensure that the service was started successfully and
 			// still running though, despite the closing of the
-			// db store.
+			// DB store.
 			require.True(t, s.IsRunning())
 
 			// Now let's send the invoice update, which should fail.
diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go
index 6697f74e8..e4717531a 100644
--- a/accounts/sql_migration_test.go
+++ b/accounts/sql_migration_test.go
@@ -301,7 +301,7 @@ func TestAccountStoreMigration(t *testing.T) {
 			)
 			require.NoError(t, err)
 			t.Cleanup(func() {
-				require.NoError(t, kvStore.db.Close())
+				require.NoError(t, kvStore.DB.Close())
 			})
 
 			// Populate the kv store.
@@ -336,7 +336,7 @@ func TestAccountStoreMigration(t *testing.T) {
 			err = txEx.ExecTx(ctx, &opts,
 				func(tx SQLMig6Queries) error {
 					return MigrateAccountStoreToSQL(
-						ctx, kvStore.db, tx,
+						ctx, kvStore.DB, tx,
 					)
 				}, sqldb.NoOpReset,
 			)
@@ -440,7 +440,7 @@ func rapidRandomizeAccounts(t *testing.T, kvStore *BoltStore) {
 		acct := makeAccountGen().Draw(t, "account")
 
 		// Then proceed to insert the account with its invoices and
-		// payments into the db
+		// payments into the DB
 		newAcct, err := kvStore.NewAccount(
 			ctx, acct.balance, acct.expiry, acct.label,
 		)
diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go
index a419017a8..c8f0282ee 100644
--- a/accounts/store_kvdb.go
+++ b/accounts/store_kvdb.go
@@ -60,7 +60,7 @@ var (
 
 // BoltStore wraps the bolt DB that stores all accounts and their balances.
 type BoltStore struct {
-	db    kvdb.Backend
+	DB    kvdb.Backend
 	clock clock.Clock
 }
 
@@ -101,7 +101,7 @@ func NewBoltStore(dir, fileName string, clock clock.Clock) (*BoltStore, error) {
 
 	// Return the DB wrapped in a BoltStore object.
 	return &BoltStore{
-		db:    db,
+		DB:    db,
 		clock: clock,
 	}, nil
 }
@@ -110,7 +110,7 @@ func NewBoltStore(dir, fileName string, clock clock.Clock) (*BoltStore, error) {
 //
 // NOTE: This is part of the Store interface.
 func (s *BoltStore) Close() error {
-	return s.db.Close()
+	return s.DB.Close()
 }
 
 // NewAccount creates a new OffChainBalanceAccount with the given balance and a
@@ -162,7 +162,7 @@ func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi,
 
 	// Try storing the account in the account database, so we can keep track
 	// of its balance.
- err := s.db.Update(func(tx walletdb.ReadWriteTx) error { + err := s.DB.Update(func(tx walletdb.ReadWriteTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -364,7 +364,7 @@ func (s *BoltStore) DeleteAccountPayment(_ context.Context, id AccountID, func (s *BoltStore) updateAccount(id AccountID, updateFn func(*OffChainBalanceAccount) error) error { - return s.db.Update(func(tx kvdb.RwTx) error { + return s.DB.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -451,7 +451,7 @@ func (s *BoltStore) Account(_ context.Context, id AccountID) ( // Try looking up and reading the account by its ID from the local // bolt DB. var accountBinary []byte - err := s.db.View(func(tx kvdb.RTx) error { + err := s.DB.View(func(tx kvdb.RTx) error { bucket := tx.ReadBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -487,7 +487,7 @@ func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, error) { var accounts []*OffChainBalanceAccount - err := s.db.View(func(tx kvdb.RTx) error { + err := s.DB.View(func(tx kvdb.RTx) error { // This function will be called in the ForEach and receive // the key and value of each account in the DB. The key, which // is also the ID is not used because it is also marshaled into @@ -531,7 +531,7 @@ func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, // // NOTE: This is part of the Store interface. func (s *BoltStore) RemoveAccount(_ context.Context, id AccountID) error { - return s.db.Update(func(tx kvdb.RwTx) error { + return s.DB.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -554,7 +554,7 @@ func (s *BoltStore) LastIndexes(_ context.Context) (uint64, uint64, error) { var ( addValue, settleValue []byte ) - err := s.db.View(func(tx kvdb.RTx) error { + err := s.DB.View(func(tx kvdb.RTx) error { bucket := tx.ReadBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -592,7 +592,7 @@ func (s *BoltStore) StoreLastIndexes(_ context.Context, addIndex, byteOrder.PutUint64(addValue, addIndex) byteOrder.PutUint64(settleValue, settleIndex) - return s.db.Update(func(tx kvdb.RwTx) error { + return s.DB.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound diff --git a/accounts/test_kvdb.go b/accounts/test_kvdb.go index 546c1eee7..1d181928c 100644 --- a/accounts/test_kvdb.go +++ b/accounts/test_kvdb.go @@ -28,7 +28,7 @@ func NewTestDBFromPath(t *testing.T, dbPath string, require.NoError(t, err) t.Cleanup(func() { - require.NoError(t, store.db.Close()) + require.NoError(t, store.DB.Close()) }) return store From 78d4b401e6ca5b03dfa9968b25c5cc7bf20c5265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 11 Jul 2025 21:55:59 +0200 Subject: [PATCH 25/25] multi: add kvdb to sql dev migration --- accounts/test_sqlite.go | 6 +- config_dev.go | 15 ++- config_prod.go | 5 +- db/migrations.go | 53 +++++++++ .../post_migration_callbacks_dev.go | 107 ++++++++++++++++++ db/migrationstreams/sql_migrations.go | 13 ++- db/migrationstreams/sql_migrations_dev.go | 41 ++++++- db/postgres.go | 3 +- ...00001_code_migration_kvdb_to_sql.down.sql} | 0 ... 
000001_code_migration_kvdb_to_sql.up.sql} | 0 firewalldb/test_sqlite.go | 6 +- session/test_sqlite.go | 6 +- terminal.go | 2 +- 13 files changed, 231 insertions(+), 26 deletions(-) create mode 100644 db/migrationstreams/post_migration_callbacks_dev.go rename db/sqlc/migrations_dev/{000001_dev_test_migration.down.sql => 000001_code_migration_kvdb_to_sql.down.sql} (100%) rename db/sqlc/migrations_dev/{000001_dev_test_migration.up.sql => 000001_code_migration_kvdb_to_sql.up.sql} (100%) diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index e2d281e4c..b1e8be871 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -6,7 +6,7 @@ import ( "errors" "testing" - "github.com/lightninglabs/lightning-terminal/db/migrationstreams" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -19,7 +19,7 @@ var ErrDBClosed = errors.New("database is closed") func NewTestDB(t *testing.T, clock clock.Clock) Store { return createStore( t, - sqldb.NewTestSqliteDB(t, migrationstreams.LitdMigrationStreams).BaseDB, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, clock, ) } @@ -30,7 +30,7 @@ func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { tDb := sqldb.NewTestSqliteDBFromPath( - t, dbPath, migrationstreams.LitdMigrationStreams, + t, dbPath, db.MakeTestMigrationStreams(), ) return createStore(t, tDb.BaseDB, clock) diff --git a/config_dev.go b/config_dev.go index d40acaa2e..a2856979e 100644 --- a/config_dev.go +++ b/config_dev.go @@ -3,6 +3,7 @@ package terminal import ( + "context" "fmt" "path/filepath" @@ -87,7 +88,9 @@ func defaultDevConfig() *DevConfig { } // NewStores creates a new stores instance based on the chosen database backend. -func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { +func NewStores(ctx context.Context, cfg *Config, + clock clock.Clock) (*stores, error) { + var ( networkDir = filepath.Join(cfg.LitDir, cfg.Network) stores = &stores{ @@ -114,7 +117,10 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if !cfg.Sqlite.SkipMigrations { err = sqldb.ApplyAllMigrations( - sqlStore, migrationstreams.LitdMigrationStreams, + sqlStore, + migrationstreams.MakeMigrationStreams( + ctx, cfg.MacaroonPath, clock, + ), ) if err != nil { return stores, fmt.Errorf("error applying "+ @@ -156,7 +162,10 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if !cfg.Postgres.SkipMigrations { err = sqldb.ApplyAllMigrations( - sqlStore, migrationstreams.LitdMigrationStreams, + sqlStore, + migrationstreams.MakeMigrationStreams( + ctx, cfg.MacaroonPath, clock, + ), ) if err != nil { return stores, fmt.Errorf("error applying "+ diff --git a/config_prod.go b/config_prod.go index ac6e6d996..1f11507a6 100644 --- a/config_prod.go +++ b/config_prod.go @@ -3,6 +3,7 @@ package terminal import ( + "context" "fmt" "path/filepath" @@ -29,7 +30,9 @@ func (c *DevConfig) Validate(_, _ string) error { // NewStores creates a new instance of the stores struct using the default Bolt // backend since in production, this is currently the only backend supported. 
-func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
+func NewStores(_ context.Context, cfg *Config,
+	clock clock.Clock) (*stores, error) {
+
 	networkDir := filepath.Join(cfg.LitDir, cfg.Network)
 
 	stores := &stores{
diff --git a/db/migrations.go b/db/migrations.go
index 87811330f..64d7b7d9b 100644
--- a/db/migrations.go
+++ b/db/migrations.go
@@ -12,8 +12,10 @@ import (
 	"github.com/btcsuite/btclog/v2"
 	"github.com/golang-migrate/migrate/v4"
 	"github.com/golang-migrate/migrate/v4/database"
+	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
 	"github.com/golang-migrate/migrate/v4/source/httpfs"
 	"github.com/lightninglabs/taproot-assets/fn"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
 const (
@@ -288,3 +290,54 @@ func (t *replacerFile) Close() error {
 	// instance, so there's nothing to do for us here.
 	return nil
 }
+
+// MakeTestMigrationStreams creates the migration streams for the unit test
+// environment.
+//
+// NOTE: This function is not located in the migrationstreams package to avoid
+// cyclic dependencies. This test migration stream does not run the kvdb to sql
+// migration, as we already have separate unit tests that test the migration.
+func MakeTestMigrationStreams() []sqldb.MigrationStream {
+	migStream := sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable,
+		SQLFileDirectory: "sqlc/migrations",
+		Schemas:          SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// database. This is used to implement downgrade protection for
+		// the daemon.
+		//
+		// NOTE: This MUST be updated when a new migration is added.
+		LatestMigrationVersion: LatestMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+
+	// Create the dev migration stream for the test environment as well.
+	migStreamDev := sqldb.MigrationStream{
+		MigrateTableName: pgx.DefaultMigrationsTable + "_dev",
+		SQLFileDirectory: "sqlc/migrations_dev",
+		Schemas:          SqlSchemas,
+
+		// LatestMigrationVersion is the latest migration version of the
+		// dev migrations database. This is used to implement downgrade
+		// protection for the daemon.
+		//
+		// NOTE: This MUST be updated when a new dev migration is added.
+		LatestMigrationVersion: LatestDevMigrationVersion,
+
+		MakePostMigrationChecks: func(
+			db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
+			error) {
+
+			return make(map[uint]migrate.PostStepCallback), nil
+		},
+	}
+
+	return []sqldb.MigrationStream{migStream, migStreamDev}
+}
diff --git a/db/migrationstreams/post_migration_callbacks_dev.go b/db/migrationstreams/post_migration_callbacks_dev.go
new file mode 100644
index 000000000..7bfe5dbfb
--- /dev/null
+++ b/db/migrationstreams/post_migration_callbacks_dev.go
@@ -0,0 +1,107 @@
+//go:build dev
+
+package migrationstreams
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"path/filepath"
+	"time"
+
+	"github.com/golang-migrate/migrate/v4"
+	"github.com/golang-migrate/migrate/v4/database"
+	"github.com/lightninglabs/lightning-terminal/accounts"
+	"github.com/lightninglabs/lightning-terminal/db/sqlcmig6"
+	"github.com/lightninglabs/lightning-terminal/firewalldb"
+	"github.com/lightninglabs/lightning-terminal/session"
+	"github.com/lightningnetwork/lnd/clock"
+	"github.com/lightningnetwork/lnd/sqldb/v2"
+)
+
+// MakePostStepCallbacksMig6 returns a post step callback for the migration
+// with the given version, migVersion. The callback runs the kvdb to SQL code
+// migration for all stores once the dev migration with that version has been
+// applied.
+func MakePostStepCallbacksMig6(ctx context.Context, db *sqldb.BaseDB,
+	macPath string, clock clock.Clock,
+	migVersion uint) migrate.PostStepCallback {
+
+	mig6queries := sqlcmig6.NewForType(db, db.BackendType)
+	mig6executor := sqldb.NewTransactionExecutor(
+		db, func(tx *sql.Tx) *sqlcmig6.Queries {
+			return mig6queries.WithTx(tx)
+		},
+	)
+
+	return func(_ *migrate.Migration, _ database.Driver) error {
+		// We ignore the actual driver that's passed in here, since
+		// we use migrate.NewWithInstance() to create the migration
+		// instance from our already instantiated database backend that
+		// is also passed into this function.
+		return mig6executor.ExecTx(
+			ctx, sqldb.NewWriteTx(),
+			func(q6 *sqlcmig6.Queries) error {
+				log.Infof("Running post migration callback "+
+					"for migration version %d", migVersion)
+
+				return kvdbToSqlMigrationCallback(
+					ctx, macPath, db, clock, q6,
+				)
+			}, sqldb.NoOpReset,
+		)
+	}
+}
+
+func kvdbToSqlMigrationCallback(ctx context.Context, macPath string,
+	_ *sqldb.BaseDB, clock clock.Clock, q *sqlcmig6.Queries) error {
+
+	start := time.Now()
+	log.Infof("Starting KVDB to SQL migration for all stores")
+
+	accountStore, err := accounts.NewBoltStore(
+		filepath.Dir(macPath), accounts.DBFilename, clock,
+	)
+	if err != nil {
+		return err
+	}
+
+	err = accounts.MigrateAccountStoreToSQL(ctx, accountStore.DB, q)
+	if err != nil {
+		return fmt.Errorf("error migrating account store to "+
+			"SQL: %v", err)
+	}
+
+	sessionStore, err := session.NewDB(
+		filepath.Dir(macPath), session.DBFilename,
+		clock, accountStore,
+	)
+	if err != nil {
+		return err
+	}
+
+	err = session.MigrateSessionStoreToSQL(ctx, sessionStore.DB, q)
+	if err != nil {
+		return fmt.Errorf("error migrating session store to "+
+			"SQL: %v", err)
+	}
+
+	firewallStore, err := firewalldb.NewBoltDB(
+		filepath.Dir(macPath), firewalldb.DBFilename,
+		sessionStore, accountStore, clock,
+	)
+	if err != nil {
+		return err
+	}
+
+	err = firewalldb.MigrateFirewallDBToSQL(ctx, firewallStore.DB, q)
+	if err != nil {
+		return fmt.Errorf("error migrating firewalldb store "+
+			"to SQL: %v", err)
+	}
+
+	log.Infof("Successfully migrated all KVDB stores to SQL in: %v",
+		time.Since(start))
+
+	return nil
+}
diff --git a/db/migrationstreams/sql_migrations.go b/db/migrationstreams/sql_migrations.go
index 5bcf8ba76..ca135cf4f 100644
--- a/db/migrationstreams/sql_migrations.go
+++ b/db/migrationstreams/sql_migrations.go
@@ -3,14 +3,18 @@
 package migrationstreams
 
 import (
+	"context"
+
 	"github.com/golang-migrate/migrate/v4"
 	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
 	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightningnetwork/lnd/clock"
 	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
-var (
-	LitdMigrationStream = sqldb.MigrationStream{
+func MakeMigrationStreams(_ context.Context, _ string,
+	_ clock.Clock) []sqldb.MigrationStream {
+
+	migStream := sqldb.MigrationStream{
 		MigrateTableName: pgx.DefaultMigrationsTable,
 		SQLFileDirectory: "sqlc/migrations",
 		Schemas:          db.SqlSchemas,
@@ -29,5 +33,6 @@ var (
 
 			return make(map[uint]migrate.PostStepCallback), nil
 		},
 	}
-	LitdMigrationStreams = []sqldb.MigrationStream{LitdMigrationStream}
-)
+
+	return []sqldb.MigrationStream{migStream}
+}
diff --git a/db/migrationstreams/sql_migrations_dev.go b/db/migrationstreams/sql_migrations_dev.go
index bc32b39f0..c2ebeeb10 100644
--- a/db/migrationstreams/sql_migrations_dev.go
+++ b/db/migrationstreams/sql_migrations_dev.go
@@ -3,15 +3,29 @@
 package migrationstreams
 
 import (
+	"context"
+
 	"github.com/golang-migrate/migrate/v4"
 	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
 	"github.com/lightninglabs/lightning-terminal/db"
+	"github.com/lightningnetwork/lnd/clock"
 	"github.com/lightningnetwork/lnd/sqldb/v2"
 )
 
-var (
+const (
+	// KVDBtoSQLMigVersion is the version of the dev migration that
+	// migrates the kvdb stores to the SQL database.
+	//
+	// TODO: When the kvdb to sql migration goes live in prod, this should
+	// be moved to the non-dev db/migrations.go file, and this constant
+	// value should be updated to reflect the real migration number.
+	KVDBtoSQLMigVersion = 1
+)
+
+func MakeMigrationStreams(ctx context.Context, macPath string,
+	clock clock.Clock) []sqldb.MigrationStream {
+
 	// Create the prod migration stream.
-	migStream = sqldb.MigrationStream{
+	migStream := sqldb.MigrationStream{
 		MigrateTableName: pgx.DefaultMigrationsTable,
 		SQLFileDirectory: "sqlc/migrations",
 		Schemas:          db.SqlSchemas,
@@ -32,7 +46,7 @@ var (
 	}
 
 	// Create the dev migration stream.
-	migStreamDev = sqldb.MigrationStream{
+	migStreamDev := sqldb.MigrationStream{
 		MigrateTableName: pgx.DefaultMigrationsTable + "_dev",
 		SQLFileDirectory: "sqlc/migrations_dev",
 		Schemas:          db.SqlSchemas,
@@ -48,8 +62,23 @@ var (
 		db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback,
 		error) {
 
-			return make(map[uint]migrate.PostStepCallback), nil
+			// Any callbacks added to this map will be executed
+			// after the dev migration with the matching uint key
+			// has been applied. If no entry exists for a given
+			// uint, then no callback will be executed for that
+			// migration number. This is useful for adding a code
+			// migration step as a callback that runs right after
+			// the migration with a specific version number has
+			// been applied.
+ res := make(map[uint]migrate.PostStepCallback) + + res[KVDBtoSQLMigVersion] = MakePostStepCallbacksMig6( + ctx, db, macPath, clock, KVDBtoSQLMigVersion, + ) + + return res, nil }, } - LitdMigrationStreams = []sqldb.MigrationStream{migStream, migStreamDev} -) + + return []sqldb.MigrationStream{migStream, migStreamDev} +} diff --git a/db/postgres.go b/db/postgres.go index 546f6311d..2edfe0576 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -8,7 +8,6 @@ import ( postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" - "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" @@ -170,7 +169,7 @@ func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { }) return sqldb.NewTestPostgresDB( - t, sqlFixture, migrationstreams.LitdMigrationStreams, + t, sqlFixture, MakeTestMigrationStreams(), ) } diff --git a/db/sqlc/migrations_dev/000001_dev_test_migration.down.sql b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql similarity index 100% rename from db/sqlc/migrations_dev/000001_dev_test_migration.down.sql rename to db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql diff --git a/db/sqlc/migrations_dev/000001_dev_test_migration.up.sql b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql similarity index 100% rename from db/sqlc/migrations_dev/000001_dev_test_migration.up.sql rename to db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 844833a1e..3f91546af 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -5,7 +5,7 @@ package firewalldb import ( "testing" - "github.com/lightninglabs/lightning-terminal/db/migrationstreams" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -14,7 +14,7 @@ import ( func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return createStore( t, - sqldb.NewTestSqliteDB(t, migrationstreams.LitdMigrationStreams).BaseDB, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, clock, ) } @@ -25,7 +25,7 @@ func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) FirewallDBs { tDb := sqldb.NewTestSqliteDBFromPath( - t, dbPath, migrationstreams.LitdMigrationStreams, + t, dbPath, db.MakeTestMigrationStreams(), ) return createStore(t, tDb.BaseDB, clock) diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 1c3702444..c9dbc5934 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -6,7 +6,7 @@ import ( "errors" "testing" - "github.com/lightninglabs/lightning-terminal/db/migrationstreams" + "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/sqldb/v2" ) @@ -19,7 +19,7 @@ var ErrDBClosed = errors.New("database is closed") func NewTestDB(t *testing.T, clock clock.Clock) Store { return createStore( t, - sqldb.NewTestSqliteDB(t, migrationstreams.LitdMigrationStreams).BaseDB, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, clock, ) } @@ -30,7 +30,7 @@ func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { tDb := sqldb.NewTestSqliteDBFromPath( - t, dbPath, migrationstreams.LitdMigrationStreams, + t, dbPath, db.MakeTestMigrationStreams(), ) return createStore(t, 
tDb.BaseDB, clock) diff --git a/terminal.go b/terminal.go index 7e4d552c7..5f8fa2efb 100644 --- a/terminal.go +++ b/terminal.go @@ -447,7 +447,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { return fmt.Errorf("could not create network directory: %v", err) } - g.stores, err = NewStores(g.cfg, clock.NewDefaultClock()) + g.stores, err = NewStores(ctx, g.cfg, clock.NewDefaultClock()) if err != nil { return fmt.Errorf("could not create stores: %v", err) }
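
The dev migration stream pattern introduced by the last two patches can be
summarized with the following minimal, self-contained sketch. It assumes the
sqldb/v2 and golang-migrate APIs as used throughout this series; the package
name "sketch", the devStream helper, and the callback body are illustrative
placeholders rather than code from the patch set.

package sketch

import (
	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	"github.com/golang-migrate/migrate/v4/database/pgx/v5"
	"github.com/lightninglabs/lightning-terminal/db"
	"github.com/lightningnetwork/lnd/sqldb/v2"
)

// devStream builds a dev migration stream that runs a code migration as a
// post step callback, right after dev SQL migration number 1 is applied.
func devStream() sqldb.MigrationStream {
	return sqldb.MigrationStream{
		// A dedicated "_dev" migrations table keeps the dev stream's
		// version counter separate from the prod stream's table.
		MigrateTableName: pgx.DefaultMigrationsTable + "_dev",
		SQLFileDirectory: "sqlc/migrations_dev",
		Schemas:          db.SqlSchemas,

		// Downgrade protection: this must match the highest migration
		// number in the sqlc/migrations_dev directory.
		LatestMigrationVersion: db.LatestDevMigrationVersion,

		MakePostMigrationChecks: func(_ *sqldb.BaseDB) (
			map[uint]migrate.PostStepCallback, error) {

			res := make(map[uint]migrate.PostStepCallback)

			// The callback fires once, right after dev migration 1
			// is applied; versions without a map entry run no
			// callback at all.
			res[1] = func(_ *migrate.Migration,
				_ database.Driver) error {

				// Placeholder for a code migration, e.g. the
				// kvdb to SQL store migration above.
				return nil
			}

			return res, nil
		},
	}
}

Keeping the dev stream's version counter in its own table is what lets a dev
build apply extra code migrations without advancing the prod stream's version,
so prod downgrade protection is unaffected until the migration graduates out
of the migrations_dev directory.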