diff --git a/accounts/service.go b/accounts/service.go index 102b9ea84..697cf7518 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -361,7 +361,7 @@ func (s *InterceptorService) CreditAccount(ctx context.Context, return nil, ErrAccountServiceDisabled } - // Credit the account in the db. + // Credit the account in the DB. err := s.store.CreditAccount(ctx, accountID, amount) if err != nil { return nil, fmt.Errorf("unable to credit account: %w", err) @@ -386,7 +386,7 @@ func (s *InterceptorService) DebitAccount(ctx context.Context, return nil, ErrAccountServiceDisabled } - // Debit the account in the db. + // Debit the account in the DB. err := s.store.DebitAccount(ctx, accountID, amount) if err != nil { return nil, fmt.Errorf("unable to debit account: %w", err) diff --git a/accounts/service_test.go b/accounts/service_test.go index 1d4388664..f69d0fe93 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -246,7 +246,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was started successfully and // still running though, despite the closing of the - // db store. + // DB store. require.True(t, s.IsRunning()) // Now let's send the invoice update, which should fail. diff --git a/accounts/sql_migration.go b/accounts/sql_migration.go index c36b51c6f..94f88ac28 100644 --- a/accounts/sql_migration.go +++ b/accounts/sql_migration.go @@ -11,8 +11,11 @@ import ( "time" "github.com/davecgh/go-spew/spew" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" "github.com/pmezard/go-difflib/difflib" ) @@ -27,7 +30,7 @@ var ( // the KV database to the SQL database. The migration is done in a single // transaction to ensure that all accounts are migrated or none at all. func MigrateAccountStoreToSQL(ctx context.Context, kvStore kvdb.Backend, - tx SQLQueries) error { + tx SQLMig6Queries) error { log.Infof("Starting migration of the KV accounts store to SQL") @@ -50,7 +53,7 @@ func MigrateAccountStoreToSQL(ctx context.Context, kvStore kvdb.Backend, // to the SQL database. The migration is done in a single transaction to ensure // that all accounts are migrated or none at all. func migrateAccountsToSQL(ctx context.Context, kvStore kvdb.Backend, - tx SQLQueries) error { + tx SQLMig6Queries) error { log.Infof("Starting migration of accounts from KV to SQL") @@ -68,7 +71,7 @@ func migrateAccountsToSQL(ctx context.Context, kvStore kvdb.Backend, kvAccount.ID, err) } - migratedAccount, err := getAndMarshalAccount( + migratedAccount, err := getAndMarshalMig6Account( ctx, tx, migratedAccountID, ) if err != nil { @@ -151,17 +154,79 @@ func getBBoltAccounts(db kvdb.Backend) ([]*OffChainBalanceAccount, error) { return accounts, nil } +// getAndMarshalAccount retrieves the account with the given ID. If the account +// cannot be found, then ErrAccNotFound is returned. +func getAndMarshalMig6Account(ctx context.Context, db SQLMig6Queries, + id int64) (*OffChainBalanceAccount, error) { + + dbAcct, err := db.GetAccount(ctx, id) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrAccNotFound + } else if err != nil { + return nil, err + } + + return marshalDBMig6Account(ctx, db, dbAcct) +} + +func marshalDBMig6Account(ctx context.Context, db SQLMig6Queries, + dbAcct sqlcmig6.Account) (*OffChainBalanceAccount, error) { + + alias, err := AccountIDFromInt64(dbAcct.Alias) + if err != nil { + return nil, err + } + + account := &OffChainBalanceAccount{ + ID: alias, + Type: AccountType(dbAcct.Type), + InitialBalance: lnwire.MilliSatoshi(dbAcct.InitialBalanceMsat), + CurrentBalance: dbAcct.CurrentBalanceMsat, + LastUpdate: dbAcct.LastUpdated.UTC(), + ExpirationDate: dbAcct.Expiration.UTC(), + Invoices: make(AccountInvoices), + Payments: make(AccountPayments), + Label: dbAcct.Label.String, + } + + invoices, err := db.ListAccountInvoices(ctx, dbAcct.ID) + if err != nil { + return nil, err + } + for _, invoice := range invoices { + var hash lntypes.Hash + copy(hash[:], invoice.Hash) + account.Invoices[hash] = struct{}{} + } + + payments, err := db.ListAccountPayments(ctx, dbAcct.ID) + if err != nil { + return nil, err + } + + for _, payment := range payments { + var hash lntypes.Hash + copy(hash[:], payment.Hash) + account.Payments[hash] = &PaymentEntry{ + Status: lnrpc.Payment_PaymentStatus(payment.Status), + FullAmount: lnwire.MilliSatoshi(payment.FullAmountMsat), + } + } + + return account, nil +} + // migrateSingleAccountToSQL runs the migration for a single account from the // KV database to the SQL database. func migrateSingleAccountToSQL(ctx context.Context, - tx SQLQueries, account *OffChainBalanceAccount) (int64, error) { + tx SQLMig6Queries, account *OffChainBalanceAccount) (int64, error) { accountAlias, err := account.ID.ToInt64() if err != nil { return 0, err } - insertAccountParams := sqlc.InsertAccountParams{ + insertAccountParams := sqlcmig6.InsertAccountParams{ Type: int16(account.Type), InitialBalanceMsat: int64(account.InitialBalance), CurrentBalanceMsat: account.CurrentBalance, @@ -180,7 +245,7 @@ func migrateSingleAccountToSQL(ctx context.Context, } for hash := range account.Invoices { - addInvoiceParams := sqlc.AddAccountInvoiceParams{ + addInvoiceParams := sqlcmig6.AddAccountInvoiceParams{ AccountID: sqlId, Hash: hash[:], } @@ -192,7 +257,7 @@ func migrateSingleAccountToSQL(ctx context.Context, } for hash, paymentEntry := range account.Payments { - upsertPaymentParams := sqlc.UpsertAccountPaymentParams{ + upsertPaymentParams := sqlcmig6.UpsertAccountPaymentParams{ AccountID: sqlId, Hash: hash[:], Status: int16(paymentEntry.Status), @@ -211,7 +276,7 @@ func migrateSingleAccountToSQL(ctx context.Context, // migrateAccountsIndicesToSQL runs the migration for the account indices from // the KV database to the SQL database. func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend, - tx SQLQueries) error { + tx SQLMig6Queries) error { log.Infof("Starting migration of accounts indices from KV to SQL") @@ -233,7 +298,7 @@ func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend, settleIndexName, settleIndex) } - setAddIndexParams := sqlc.SetAccountIndexParams{ + setAddIndexParams := sqlcmig6.SetAccountIndexParams{ Name: addIndexName, Value: int64(addIndex), } @@ -243,7 +308,7 @@ func migrateAccountsIndicesToSQL(ctx context.Context, kvStore kvdb.Backend, return err } - setSettleIndexParams := sqlc.SetAccountIndexParams{ + setSettleIndexParams := sqlcmig6.SetAccountIndexParams{ Name: settleIndexName, Value: int64(settleIndex), } diff --git a/accounts/sql_migration_test.go b/accounts/sql_migration_test.go index bfa508df5..e4717531a 100644 --- a/accounts/sql_migration_test.go +++ b/accounts/sql_migration_test.go @@ -2,18 +2,17 @@ package accounts import ( "context" - "database/sql" "fmt" "testing" "time" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "golang.org/x/exp/rand" "pgregory.net/rapid" @@ -36,7 +35,7 @@ func TestAccountStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLMig6QueriesExecutor[SQLMig6Queries]) { testDBStore := NewTestDB(t, clock) @@ -45,13 +44,9 @@ func TestAccountStoreMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLMig6QueriesExecutor(baseDB, queries) } assertMigrationResults := func(t *testing.T, sqlStore *SQLStore, @@ -306,7 +301,7 @@ func TestAccountStoreMigration(t *testing.T) { ) require.NoError(t, err) t.Cleanup(func() { - require.NoError(t, kvStore.db.Close()) + require.NoError(t, kvStore.DB.Close()) }) // Populate the kv store. @@ -339,11 +334,11 @@ func TestAccountStoreMigration(t *testing.T) { // Perform the migration. var opts sqldb.MigrationTxOptions err = txEx.ExecTx(ctx, &opts, - func(tx SQLQueries) error { + func(tx SQLMig6Queries) error { return MigrateAccountStoreToSQL( - ctx, kvStore.db, tx, + ctx, kvStore.DB, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) @@ -445,7 +440,7 @@ func rapidRandomizeAccounts(t *testing.T, kvStore *BoltStore) { acct := makeAccountGen().Draw(t, "account") // Then proceed to insert the account with its invoices and - // payments into the db + // payments into the DB newAcct, err := kvStore.NewAccount( ctx, acct.balance, acct.expiry, acct.label, ) diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go index a419017a8..c8f0282ee 100644 --- a/accounts/store_kvdb.go +++ b/accounts/store_kvdb.go @@ -24,7 +24,7 @@ import ( const ( // DBFilename is the filename within the data directory which contains // the macaroon stores. - DBFilename = "accounts.db" + DBFilename = "accounts.DB" // dbPathPermission is the default permission the account database // directory is created with (if it does not exist). @@ -60,7 +60,7 @@ var ( // BoltStore wraps the bolt DB that stores all accounts and their balances. type BoltStore struct { - db kvdb.Backend + DB kvdb.Backend clock clock.Clock } @@ -101,7 +101,7 @@ func NewBoltStore(dir, fileName string, clock clock.Clock) (*BoltStore, error) { // Return the DB wrapped in a BoltStore object. return &BoltStore{ - db: db, + DB: db, clock: clock, }, nil } @@ -110,7 +110,7 @@ func NewBoltStore(dir, fileName string, clock clock.Clock) (*BoltStore, error) { // // NOTE: This is part of the Store interface. func (s *BoltStore) Close() error { - return s.db.Close() + return s.DB.Close() } // NewAccount creates a new OffChainBalanceAccount with the given balance and a @@ -162,7 +162,7 @@ func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, // Try storing the account in the account database, so we can keep track // of its balance. - err := s.db.Update(func(tx walletdb.ReadWriteTx) error { + err := s.DB.Update(func(tx walletdb.ReadWriteTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -364,7 +364,7 @@ func (s *BoltStore) DeleteAccountPayment(_ context.Context, id AccountID, func (s *BoltStore) updateAccount(id AccountID, updateFn func(*OffChainBalanceAccount) error) error { - return s.db.Update(func(tx kvdb.RwTx) error { + return s.DB.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -451,7 +451,7 @@ func (s *BoltStore) Account(_ context.Context, id AccountID) ( // Try looking up and reading the account by its ID from the local // bolt DB. var accountBinary []byte - err := s.db.View(func(tx kvdb.RTx) error { + err := s.DB.View(func(tx kvdb.RTx) error { bucket := tx.ReadBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -487,7 +487,7 @@ func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, error) { var accounts []*OffChainBalanceAccount - err := s.db.View(func(tx kvdb.RTx) error { + err := s.DB.View(func(tx kvdb.RTx) error { // This function will be called in the ForEach and receive // the key and value of each account in the DB. The key, which // is also the ID is not used because it is also marshaled into @@ -531,7 +531,7 @@ func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, // // NOTE: This is part of the Store interface. func (s *BoltStore) RemoveAccount(_ context.Context, id AccountID) error { - return s.db.Update(func(tx kvdb.RwTx) error { + return s.DB.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -554,7 +554,7 @@ func (s *BoltStore) LastIndexes(_ context.Context) (uint64, uint64, error) { var ( addValue, settleValue []byte ) - err := s.db.View(func(tx kvdb.RTx) error { + err := s.DB.View(func(tx kvdb.RTx) error { bucket := tx.ReadBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound @@ -592,7 +592,7 @@ func (s *BoltStore) StoreLastIndexes(_ context.Context, addIndex, byteOrder.PutUint64(addValue, addIndex) byteOrder.PutUint64(settleValue, settleIndex) - return s.db.Update(func(tx kvdb.RwTx) error { + return s.DB.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound diff --git a/accounts/store_sql.go b/accounts/store_sql.go index 830f16587..13422315c 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -11,11 +11,13 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -33,6 +35,8 @@ const ( // //nolint:lll type SQLQueries interface { + sqldb.BaseQuerier + AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error @@ -53,12 +57,40 @@ type SQLQueries interface { GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error) } -// BatchedSQLQueries is a version of the SQLQueries that's capable -// of batched database operations. +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used +// to interact with accounts related tables. +// +//nolint:lll +type SQLMig6Queries interface { + sqldb.BaseQuerier + + AddAccountInvoice(ctx context.Context, arg sqlcmig6.AddAccountInvoiceParams) error + DeleteAccount(ctx context.Context, id int64) error + DeleteAccountPayment(ctx context.Context, arg sqlcmig6.DeleteAccountPaymentParams) error + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) + GetAccountByLabel(ctx context.Context, label sql.NullString) (sqlcmig6.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccountIndex(ctx context.Context, name string) (int64, error) + GetAccountPayment(ctx context.Context, arg sqlcmig6.GetAccountPaymentParams) (sqlcmig6.AccountPayment, error) + InsertAccount(ctx context.Context, arg sqlcmig6.InsertAccountParams) (int64, error) + ListAccountInvoices(ctx context.Context, id int64) ([]sqlcmig6.AccountInvoice, error) + ListAccountPayments(ctx context.Context, id int64) ([]sqlcmig6.AccountPayment, error) + ListAllAccounts(ctx context.Context) ([]sqlcmig6.Account, error) + SetAccountIndex(ctx context.Context, arg sqlcmig6.SetAccountIndexParams) error + UpdateAccountBalance(ctx context.Context, arg sqlcmig6.UpdateAccountBalanceParams) (int64, error) + UpdateAccountExpiry(ctx context.Context, arg sqlcmig6.UpdateAccountExpiryParams) (int64, error) + UpdateAccountLastUpdate(ctx context.Context, arg sqlcmig6.UpdateAccountLastUpdateParams) (int64, error) + UpsertAccountPayment(ctx context.Context, arg sqlcmig6.UpsertAccountPaymentParams) error + GetAccountInvoice(ctx context.Context, arg sqlcmig6.GetAccountInvoiceParams) (sqlcmig6.AccountInvoice, error) +} + +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -68,19 +100,57 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLMig6Queries +} + +func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLMig6Queries { + return queries.WithTx(tx) }, ) + return &SQLMig6QueriesExecutor[SQLMig6Queries]{ + TransactionExecutor: executor, + SQLMig6Queries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -157,7 +227,7 @@ func (s *SQLStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -299,7 +369,7 @@ func (s *SQLStore) AddAccountInvoice(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, acctID) - }) + }, sqldb.NoOpReset) } func getAccountIDByAlias(ctx context.Context, db SQLQueries, alias AccountID) ( @@ -377,7 +447,7 @@ func (s *SQLStore) UpdateAccountBalanceAndExpiry(ctx context.Context, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // CreditAccount increases the balance of the account with the given alias by @@ -412,7 +482,7 @@ func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DebitAccount decreases the balance of the account with the given alias by the @@ -453,7 +523,7 @@ func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // Account retrieves an account from the SQL store and un-marshals it. If the @@ -475,7 +545,7 @@ func (s *SQLStore) Account(ctx context.Context, alias AccountID) ( account, err = getAndMarshalAccount(ctx, db, id) return err - }) + }, sqldb.NoOpReset) return account, err } @@ -507,7 +577,7 @@ func (s *SQLStore) Accounts(ctx context.Context) ([]*OffChainBalanceAccount, } return nil - }) + }, sqldb.NoOpReset) return accounts, err } @@ -524,7 +594,7 @@ func (s *SQLStore) RemoveAccount(ctx context.Context, alias AccountID) error { } return db.DeleteAccount(ctx, id) - }) + }, sqldb.NoOpReset) } // UpsertAccountPayment updates or inserts a payment entry for the given @@ -634,7 +704,7 @@ func (s *SQLStore) UpsertAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // DeleteAccountPayment removes a payment entry from the account with the given @@ -677,7 +747,7 @@ func (s *SQLStore) DeleteAccountPayment(ctx context.Context, alias AccountID, } return s.markAccountUpdated(ctx, db, id) - }) + }, sqldb.NoOpReset) } // LastIndexes returns the last invoice add and settle index or @@ -704,7 +774,7 @@ func (s *SQLStore) LastIndexes(ctx context.Context) (uint64, uint64, error) { } return err - }) + }, sqldb.NoOpReset) return uint64(addIndex), uint64(settleIndex), err } @@ -729,7 +799,7 @@ func (s *SQLStore) StoreLastIndexes(ctx context.Context, addIndex, Name: settleIndexName, Value: int64(settleIndex), }) - }) + }, sqldb.NoOpReset) } // Close closes the underlying store. diff --git a/accounts/test_kvdb.go b/accounts/test_kvdb.go index 546c1eee7..1d181928c 100644 --- a/accounts/test_kvdb.go +++ b/accounts/test_kvdb.go @@ -28,7 +28,7 @@ func NewTestDBFromPath(t *testing.T, dbPath string, require.NoError(t, err) t.Cleanup(func() { - require.NoError(t, store.db.Close()) + require.NoError(t, store.DB.Close()) }) return store diff --git a/accounts/test_sql.go b/accounts/test_sql.go index 3c1ee7f16..ca2f43d6f 100644 --- a/accounts/test_sql.go +++ b/accounts/test_sql.go @@ -5,15 +5,20 @@ package accounts import ( "testing" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - store := NewSQLStore(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go index 9d899b3e2..b1e8be871 100644 --- a/accounts/test_sqlite.go +++ b/accounts/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,11 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +29,9 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + tDb := sqldb.NewTestSqliteDBFromPath( + t, dbPath, db.MakeTestMigrationStreams(), ) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/config_dev.go b/config_dev.go index 90b8b290f..a2856979e 100644 --- a/config_dev.go +++ b/config_dev.go @@ -3,14 +3,18 @@ package terminal import ( + "context" "fmt" "path/filepath" "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightninglabs/lightning-terminal/firewalldb" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -84,7 +88,9 @@ func defaultDevConfig() *DevConfig { } // NewStores creates a new stores instance based on the chosen database backend. -func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { +func NewStores(ctx context.Context, cfg *Config, + clock clock.Clock) (*stores, error) { + var ( networkDir = filepath.Join(cfg.LitDir, cfg.Network) stores = &stores{ @@ -101,14 +107,39 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { return stores, err } - sqlStore, err := db.NewSqliteStore(cfg.Sqlite) + sqlStore, err := sqldb.NewSqliteStore(&sqldb.SqliteConfig{ + SkipMigrations: cfg.Sqlite.SkipMigrations, + SkipMigrationDbBackup: cfg.Sqlite.SkipMigrationDbBackup, + }, cfg.Sqlite.DatabaseFileName) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Sqlite.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, + migrationstreams.MakeMigrationStreams( + ctx, cfg.MacaroonPath, clock, + ), + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to SQLlite store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore @@ -116,14 +147,44 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { stores.closeFns["sqlite"] = sqlStore.BaseDB.Close case DatabaseBackendPostgres: - sqlStore, err := db.NewPostgresStore(cfg.Postgres) + sqlStore, err := sqldb.NewPostgresStore(&sqldb.PostgresConfig{ + Dsn: cfg.Postgres.DSN(false), + MaxOpenConnections: cfg.Postgres.MaxOpenConnections, + MaxIdleConnections: cfg.Postgres.MaxIdleConnections, + ConnMaxLifetime: cfg.Postgres.ConnMaxLifetime, + ConnMaxIdleTime: cfg.Postgres.ConnMaxIdleTime, + RequireSSL: cfg.Postgres.RequireSSL, + SkipMigrations: cfg.Postgres.SkipMigrations, + }) if err != nil { return stores, err } - acctStore := accounts.NewSQLStore(sqlStore.BaseDB, clock) - sessStore := session.NewSQLStore(sqlStore.BaseDB, clock) - firewallStore := firewalldb.NewSQLDB(sqlStore.BaseDB, clock) + if !cfg.Postgres.SkipMigrations { + err = sqldb.ApplyAllMigrations( + sqlStore, + migrationstreams.MakeMigrationStreams( + ctx, cfg.MacaroonPath, clock, + ), + ) + if err != nil { + return stores, fmt.Errorf("error applying "+ + "migrations to Postgres store: %w", err, + ) + } + } + + queries := sqlc.NewForType(sqlStore, sqlStore.BackendType) + + acctStore := accounts.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + sessStore := session.NewSQLStore( + sqlStore.BaseDB, queries, clock, + ) + firewallStore := firewalldb.NewSQLDB( + sqlStore.BaseDB, queries, clock, + ) stores.accounts = acctStore stores.sessions = sessStore diff --git a/config_prod.go b/config_prod.go index ac6e6d996..1f11507a6 100644 --- a/config_prod.go +++ b/config_prod.go @@ -3,6 +3,7 @@ package terminal import ( + "context" "fmt" "path/filepath" @@ -29,7 +30,9 @@ func (c *DevConfig) Validate(_, _ string) error { // NewStores creates a new instance of the stores struct using the default Bolt // backend since in production, this is currently the only backend supported. -func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { +func NewStores(_ context.Context, cfg *Config, + clock clock.Clock) (*stores, error) { + networkDir := filepath.Join(cfg.LitDir, cfg.Network) stores := &stores{ diff --git a/db/interfaces.go b/db/interfaces.go index ba64520b4..bb39df9ea 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -8,6 +8,7 @@ import ( "time" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" ) var ( @@ -56,7 +57,7 @@ type BatchedTx[Q any] interface { txBody func(Q) error) error // Backend returns the type of the database backend used. - Backend() sqlc.BackendType + Backend() sqldb.BackendType } // Tx represents a database transaction that can be committed or rolled back. @@ -277,7 +278,7 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context, } // Backend returns the type of the database backend used. -func (t *TransactionExecutor[Q]) Backend() sqlc.BackendType { +func (t *TransactionExecutor[Q]) Backend() sqldb.BackendType { return t.BatchedQuerier.Backend() } @@ -301,7 +302,7 @@ func (s *BaseDB) BeginTx(ctx context.Context, opts TxOptions) (*sql.Tx, error) { } // Backend returns the type of the database backend used. -func (s *BaseDB) Backend() sqlc.BackendType { +func (s *BaseDB) Backend() sqldb.BackendType { return s.Queries.Backend() } diff --git a/db/migrations.go b/db/migrations.go index 79d63587e..64d7b7d9b 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -12,8 +12,10 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/pgx/v5" "github.com/golang-migrate/migrate/v4/source/httpfs" "github.com/lightninglabs/taproot-assets/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" ) const ( @@ -23,6 +25,15 @@ const ( // // NOTE: This MUST be updated when a new migration is added. LatestMigrationVersion = 5 + + // LatestDevMigrationVersion is the latest dev migration version of the + // database. This is used to implement downgrade protection for the + // daemon. This represents the latest number used in the migrations_dev + // directory. + // + // NOTE: This MUST be updated when a migration is added or removed, from + // the migrations_dev directory. + LatestDevMigrationVersion = 1 ) // MigrationTarget is a functional option that can be passed to applyMigrations @@ -120,7 +131,7 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, targetVersion MigrationTarget, opts *migrateOptions) error { // With the migrate instance open, we'll create a new migration source - // using the embedded file system stored in sqlSchemas. The library + // using the embedded file system stored in SqlSchemas. The library // we're using can't handle a raw file system interface, so we wrap it // in this intermediate layer. migrateFileServer, err := httpfs.New(http.FS(fs), path) @@ -279,3 +290,54 @@ func (t *replacerFile) Close() error { // instance, so there's nothing to do for us here. return nil } + +// MakeTestMigrationStreams creates the migration streams for the unit test +// environment. +// +// NOTE: This function is not located in the migrationstreams package to avoid +// cyclic dependencies. This test migration stream does not run the kvdb to sql +// migration, as we already have separate unit tests which tests the migration. +func MakeTestMigrationStreams() []sqldb.MigrationStream { + migStream := sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable, + SQLFileDirectory: "sqlc/migrations", + Schemas: SqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for + // the daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion: LatestMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return make(map[uint]migrate.PostStepCallback), nil + }, + } + + // IN DEV CASE: + migStreamDev := sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable + "_dev", + SQLFileDirectory: "sqlc/migrations_dev", + Schemas: SqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // dev migrations database. This is used to implement downgrade + // protection for the daemon. + // + // NOTE: This MUST be updated when a new dev migration is added. + LatestMigrationVersion: LatestDevMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return make(map[uint]migrate.PostStepCallback), nil + }, + } + + return []sqldb.MigrationStream{migStream, migStreamDev} +} diff --git a/db/migrationstreams/log.go b/db/migrationstreams/log.go new file mode 100644 index 000000000..0a3e80731 --- /dev/null +++ b/db/migrationstreams/log.go @@ -0,0 +1,25 @@ +package migrationstreams + +import ( + "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/build" +) + +const Subsystem = "MIGS" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/db/migrationstreams/post_migration_callbacks_dev.go b/db/migrationstreams/post_migration_callbacks_dev.go new file mode 100644 index 000000000..7bfe5dbfb --- /dev/null +++ b/db/migrationstreams/post_migration_callbacks_dev.go @@ -0,0 +1,107 @@ +//go:build dev + +package migrationstreams + +import ( + "context" + "database/sql" + "fmt" + "path/filepath" + "time" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" + "github.com/lightninglabs/lightning-terminal/firewalldb" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +// MakePostStepCallbacksMig6 turns the post migration checks into a map of post +// step callbacks that can be used with the migrate package. The keys of the map +// are the migration versions, and the values are the callbacks that will be +// executed after the migration with the corresponding version is applied. +func MakePostStepCallbacksMig6(ctx context.Context, db *sqldb.BaseDB, + macPath string, clock clock.Clock, + migVersion uint) migrate.PostStepCallback { + + mig6queries := sqlcmig6.NewForType(db, db.BackendType) + mig6executor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) *sqlcmig6.Queries { + return mig6queries.WithTx(tx) + }, + ) + + return func(_ *migrate.Migration, _ database.Driver) error { + // We ignore the actual driver that's being returned here, since + // we use migrate.NewWithInstance() to create the migration + // instance from our already instantiated database backend that + // is also passed into this function. + return mig6executor.ExecTx( + ctx, sqldb.NewWriteTx(), + func(q6 *sqlcmig6.Queries) error { + log.Infof("Running post migration callback "+ + "for migration version %d", migVersion) + + return kvdbToSqlMigrationCallback( + ctx, macPath, db, clock, q6, + ) + }, sqldb.NoOpReset, + ) + } +} + +func kvdbToSqlMigrationCallback(ctx context.Context, macPath string, + _ *sqldb.BaseDB, clock clock.Clock, q *sqlcmig6.Queries) error { + + start := time.Now() + log.Infof("Starting KVDB to SQL migration for all stores") + + accountStore, err := accounts.NewBoltStore( + filepath.Dir(macPath), accounts.DBFilename, clock, + ) + if err != nil { + return err + } + + err = accounts.MigrateAccountStoreToSQL(ctx, accountStore.DB, q) + if err != nil { + return fmt.Errorf("error migrating account store to "+ + "SQL: %v", err) + } + + sessionStore, err := session.NewDB( + filepath.Dir(macPath), session.DBFilename, + clock, accountStore, + ) + if err != nil { + return err + } + + err = session.MigrateSessionStoreToSQL(ctx, sessionStore.DB, q) + if err != nil { + return fmt.Errorf("error migrating session store to "+ + "SQL: %v", err) + } + + firewallStore, err := firewalldb.NewBoltDB( + filepath.Dir(macPath), firewalldb.DBFilename, + sessionStore, accountStore, clock, + ) + if err != nil { + return err + } + + err = firewalldb.MigrateFirewallDBToSQL(ctx, firewallStore.DB, q) + if err != nil { + return fmt.Errorf("error migrating firewalldb store "+ + "to SQL: %v", err) + } + + log.Infof("Succesfully migrated all KVDB stores to SQL in: %v", + time.Since(start)) + + return nil +} diff --git a/db/migrationstreams/sql_migrations.go b/db/migrationstreams/sql_migrations.go new file mode 100644 index 000000000..ca135cf4f --- /dev/null +++ b/db/migrationstreams/sql_migrations.go @@ -0,0 +1,38 @@ +//go:build !dev + +package migrationstreams + +import ( + "context" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +func MakeMigrationStreams(_ context.Context, _ string, + _ clock.Clock) []sqldb.MigrationStream { + + migStream := sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable, + SQLFileDirectory: "sqlc/migrations", + Schemas: db.SqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for + // the daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion: db.LatestMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return make(map[uint]migrate.PostStepCallback), nil + }, + } + + return []sqldb.MigrationStream{migStream} +} diff --git a/db/migrationstreams/sql_migrations_dev.go b/db/migrationstreams/sql_migrations_dev.go new file mode 100644 index 000000000..c2ebeeb10 --- /dev/null +++ b/db/migrationstreams/sql_migrations_dev.go @@ -0,0 +1,84 @@ +//go:build dev + +package migrationstreams + +import ( + "context" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/pgx/v5" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +const ( + // KVDBtoSQLMigVersion is the version of the migration that migrates the + // kvdb to the sql database. + // + // TODO: When this the kvdb to sql migration goes live into prod, this + // should be moved to non dev db/migrations.go file, and this constant + // value should be updated to reflect the real migration number. + KVDBtoSQLMigVersion = 1 +) + +func MakeMigrationStreams(ctx context.Context, macPath string, + clock clock.Clock) []sqldb.MigrationStream { + + // Create the prod migration stream. + migStream := sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable, + SQLFileDirectory: "sqlc/migrations", + Schemas: db.SqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // database. This is used to implement downgrade protection for + // the daemon. + // + // NOTE: This MUST be updated when a new migration is added. + LatestMigrationVersion: db.LatestMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + return make(map[uint]migrate.PostStepCallback), nil + }, + } + + // Create the dev migration stream. + migStreamDev := sqldb.MigrationStream{ + MigrateTableName: pgx.DefaultMigrationsTable + "_dev", + SQLFileDirectory: "sqlc/migrations_dev", + Schemas: db.SqlSchemas, + + // LatestMigrationVersion is the latest migration version of the + // dev migrations database. This is used to implement downgrade + // protection for the daemon. + // + // NOTE: This MUST be updated when a new dev migration is added. + LatestMigrationVersion: db.LatestDevMigrationVersion, + + MakePostMigrationChecks: func( + db *sqldb.BaseDB) (map[uint]migrate.PostStepCallback, + error) { + + // Any Callbacks added to this map will be executed when + // after the dev migration number for the uint key in + // the map has been applied. If no entry exists for a + // given uint, then no callback will be executed for + // that migration number. This is useful for adding a + // code migration step as a callback to be run + // after a specific migration of a given number has been + // applied. + res := make(map[uint]migrate.PostStepCallback) + + res[KVDBtoSQLMigVersion] = MakePostStepCallbacksMig6( + ctx, db, macPath, clock, KVDBtoSQLMigVersion, + ) + + return res, nil + }, + } + + return []sqldb.MigrationStream{migStream, migStreamDev} +} diff --git a/db/postgres.go b/db/postgres.go index 16e41dc09..2edfe0576 100644 --- a/db/postgres.go +++ b/db/postgres.go @@ -9,6 +9,7 @@ import ( postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -119,7 +120,7 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { rawDb.SetConnMaxLifetime(connMaxLifetime) rawDb.SetConnMaxIdleTime(connMaxIdleTime) - queries := sqlc.NewPostgres(rawDb) + queries := sqlc.NewForType(rawDb, sqldb.BackendTypePostgres) s := &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -128,15 +129,6 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { }, } - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(TargetLatest); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - return s, nil } @@ -157,7 +149,7 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, return fmt.Errorf("error creating postgres migration: %w", err) } - postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) + postgresFS := newReplacerFS(SqlSchemas, postgresSchemaReplacements) return applyMigrations( postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target, opts, @@ -166,20 +158,19 @@ func (s *PostgresStore) ExecuteMigrations(target MigrationTarget, // NewTestPostgresDB is a helper function that creates a Postgres database for // testing. -func NewTestPostgresDB(t *testing.T) *PostgresStore { +func NewTestPostgresDB(t *testing.T) *sqldb.PostgresStore { t.Helper() t.Logf("Creating new Postgres DB for testing") - sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true) - store, err := NewPostgresStore(sqlFixture.GetConfig()) - require.NoError(t, err) - + sqlFixture := sqldb.NewTestPgFixture(t, DefaultPostgresFixtureLifetime) t.Cleanup(func() { sqlFixture.TearDown(t) }) - return store + return sqldb.NewTestPostgresDB( + t, sqlFixture, MakeTestMigrationStreams(), + ) } // NewTestPostgresDBWithVersion is a helper function that creates a Postgres diff --git a/db/schemas.go b/db/schemas.go index 1a7a2096f..565fb5615 100644 --- a/db/schemas.go +++ b/db/schemas.go @@ -5,5 +5,5 @@ import ( _ "embed" ) -//go:embed sqlc/migrations/*.*.sql -var sqlSchemas embed.FS +//go:embed sqlc/migration*/*.*.sql +var SqlSchemas embed.FS diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f4bf7f611..af556eae7 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -2,21 +2,8 @@ package sqlc import ( "context" -) - -// BackendType is an enum that represents the type of database backend we're -// using. -type BackendType uint8 - -const ( - // BackendTypeUnknown indicates we're using an unknown backend. - BackendTypeUnknown BackendType = iota - // BackendTypeSqlite indicates we're using a SQLite backend. - BackendTypeSqlite - - // BackendTypePostgres indicates we're using a Postgres backend. - BackendTypePostgres + "github.com/lightningnetwork/lnd/sqldb/v2" ) // wrappedTX is a wrapper around a DBTX that also stores the database backend @@ -24,29 +11,24 @@ const ( type wrappedTX struct { DBTX - backendType BackendType + backendType sqldb.BackendType } // Backend returns the type of database backend we're using. -func (q *Queries) Backend() BackendType { +func (q *Queries) Backend() sqldb.BackendType { wtx, ok := q.db.(*wrappedTX) if !ok { // Shouldn't happen unless a new database backend type is added // but not initialized correctly. - return BackendTypeUnknown + return sqldb.BackendTypeUnknown } return wtx.backendType } -// NewSqlite creates a new Queries instance for a SQLite database. -func NewSqlite(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypeSqlite}} -} - -// NewPostgres creates a new Queries instance for a Postgres database. -func NewPostgres(db DBTX) *Queries { - return &Queries{db: &wrappedTX{db, BackendTypePostgres}} +// NewForType creates a new Queries instance for the given database type. +func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} } // CustomQueries defines a set of custom queries that we define in addition @@ -62,5 +44,5 @@ type CustomQueries interface { arg ListActionsParams) ([]Action, error) // Backend returns the type of the database backend used. - Backend() BackendType + Backend() sqldb.BackendType } diff --git a/db/sqlc/kvstores.sql.go b/db/sqlc/kvstores.sql.go index b2e6632f4..b46719eec 100644 --- a/db/sqlc/kvstores.sql.go +++ b/db/sqlc/kvstores.sql.go @@ -25,7 +25,7 @@ DELETE FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id = $4 + AND group_id = $4 AND feature_id = $5 ` @@ -33,7 +33,7 @@ type DeleteFeatureKVStoreRecordParams struct { Key string RuleID int64 Perm bool - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 } @@ -42,7 +42,7 @@ func (q *Queries) DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeat arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, arg.FeatureID, ) return err @@ -53,7 +53,7 @@ DELETE FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL ` @@ -68,28 +68,28 @@ func (q *Queries) DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGloba return err } -const deleteSessionKVStoreRecord = `-- name: DeleteSessionKVStoreRecord :exec +const deleteGroupKVStoreRecord = `-- name: DeleteGroupKVStoreRecord :exec DELETE FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id = $4 + AND group_id = $4 AND feature_id IS NULL ` -type DeleteSessionKVStoreRecordParams struct { - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 +type DeleteGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 } -func (q *Queries) DeleteSessionKVStoreRecord(ctx context.Context, arg DeleteSessionKVStoreRecordParams) error { - _, err := q.db.ExecContext(ctx, deleteSessionKVStoreRecord, +func (q *Queries) DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGroupKVStoreRecord, arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, ) return err } @@ -113,7 +113,7 @@ FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id = $4 + AND group_id = $4 AND feature_id = $5 ` @@ -121,7 +121,7 @@ type GetFeatureKVStoreRecordParams struct { Key string RuleID int64 Perm bool - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 } @@ -130,7 +130,7 @@ func (q *Queries) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVS arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, arg.FeatureID, ) var value []byte @@ -144,7 +144,7 @@ FROM kvstores WHERE entry_key = $1 AND rule_id = $2 AND perm = $3 - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL ` @@ -161,6 +161,35 @@ func (q *Queries) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVSto return value, err } +const getGroupKVStoreRecord = `-- name: GetGroupKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type GetGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + const getOrInsertFeatureID = `-- name: GetOrInsertFeatureID :one INSERT INTO features (name) VALUES ($1) @@ -202,44 +231,15 @@ func (q *Queries) GetRuleID(ctx context.Context, name string) (int64, error) { return id, err } -const getSessionKVStoreRecord = `-- name: GetSessionKVStoreRecord :one -SELECT value -FROM kvstores -WHERE entry_key = $1 - AND rule_id = $2 - AND perm = $3 - AND session_id = $4 - AND feature_id IS NULL -` - -type GetSessionKVStoreRecordParams struct { - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 -} - -func (q *Queries) GetSessionKVStoreRecord(ctx context.Context, arg GetSessionKVStoreRecordParams) ([]byte, error) { - row := q.db.QueryRowContext(ctx, getSessionKVStoreRecord, - arg.Key, - arg.RuleID, - arg.Perm, - arg.SessionID, - ) - var value []byte - err := row.Scan(&value) - return value, err -} - const insertKVStoreRecord = `-- name: InsertKVStoreRecord :exec -INSERT INTO kvstores (perm, rule_id, session_id, feature_id, entry_key, value) +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) VALUES ($1, $2, $3, $4, $5, $6) ` type InsertKVStoreRecordParams struct { Perm bool RuleID int64 - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 EntryKey string Value []byte @@ -249,7 +249,7 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco _, err := q.db.ExecContext(ctx, insertKVStoreRecord, arg.Perm, arg.RuleID, - arg.SessionID, + arg.GroupID, arg.FeatureID, arg.EntryKey, arg.Value, @@ -257,13 +257,49 @@ func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreReco return err } +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, group_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.GroupID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id = $5 + AND group_id = $5 AND feature_id = $6 ` @@ -272,7 +308,7 @@ type UpdateFeatureKVStoreRecordParams struct { Key string RuleID int64 Perm bool - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 } @@ -282,7 +318,7 @@ func (q *Queries) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeat arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, arg.FeatureID, ) return err @@ -294,7 +330,7 @@ SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL ` @@ -315,31 +351,31 @@ func (q *Queries) UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGloba return err } -const updateSessionKVStoreRecord = `-- name: UpdateSessionKVStoreRecord :exec +const updateGroupKVStoreRecord = `-- name: UpdateGroupKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = $2 AND rule_id = $3 AND perm = $4 - AND session_id = $5 + AND group_id = $5 AND feature_id IS NULL ` -type UpdateSessionKVStoreRecordParams struct { - Value []byte - Key string - RuleID int64 - Perm bool - SessionID sql.NullInt64 +type UpdateGroupKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 } -func (q *Queries) UpdateSessionKVStoreRecord(ctx context.Context, arg UpdateSessionKVStoreRecordParams) error { - _, err := q.db.ExecContext(ctx, updateSessionKVStoreRecord, +func (q *Queries) UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGroupKVStoreRecord, arg.Value, arg.Key, arg.RuleID, arg.Perm, - arg.SessionID, + arg.GroupID, ) return err } diff --git a/db/sqlc/migrations/000003_kvstores.up.sql b/db/sqlc/migrations/000003_kvstores.up.sql index d2f0653a5..e49ed9622 100644 --- a/db/sqlc/migrations/000003_kvstores.up.sql +++ b/db/sqlc/migrations/000003_kvstores.up.sql @@ -21,7 +21,7 @@ CREATE TABLE IF NOT EXISTS features ( CREATE UNIQUE INDEX IF NOT EXISTS features_name_idx ON features (name); -- kvstores houses key-value pairs under various namespaces determined --- by the rule name, session ID, and feature name. +-- by the rule name, group ID, and feature name. CREATE TABLE IF NOT EXISTS kvstores ( -- The auto incrementing primary key. id INTEGER PRIMARY KEY, @@ -35,15 +35,15 @@ CREATE TABLE IF NOT EXISTS kvstores ( -- kv_store. rule_id BIGINT REFERENCES rules(id) NOT NULL, - -- The session ID that this kv_store belongs to. - -- If this is set, then this kv_store is a session-specific + -- The group ID that this kv_store belongs to. + -- If this is set, then this kv_store is a session-group specific -- kv_store for the given rule. - session_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, + group_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, -- The feature name that this kv_store belongs to. -- If this is set, then this kv_store is a feature-specific - -- kvstore under the given session ID and rule name. - -- If this is set, then session_id must also be set. + -- kvstore under the given group ID and rule name. + -- If this is set, then group_id must also be set. feature_id BIGINT REFERENCES features(id), -- The key of the entry. @@ -54,4 +54,4 @@ CREATE TABLE IF NOT EXISTS kvstores ( ); CREATE UNIQUE INDEX IF NOT EXISTS kvstores_lookup_idx - ON kvstores (entry_key, rule_id, perm, session_id, feature_id); + ON kvstores (entry_key, rule_id, perm, group_id, feature_id); diff --git a/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql new file mode 100644 index 000000000..0d246b2dd --- /dev/null +++ b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.down.sql @@ -0,0 +1 @@ +-- Comment to ensure the file created and picked up in the migration stream. \ No newline at end of file diff --git a/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql new file mode 100644 index 000000000..0d246b2dd --- /dev/null +++ b/db/sqlc/migrations_dev/000001_code_migration_kvdb_to_sql.up.sql @@ -0,0 +1 @@ +-- Comment to ensure the file created and picked up in the migration stream. \ No newline at end of file diff --git a/db/sqlc/models.go b/db/sqlc/models.go index 357360c9e..d19e66e10 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -63,7 +63,7 @@ type Kvstore struct { ID int64 Perm bool RuleID int64 - SessionID sql.NullInt64 + GroupID sql.NullInt64 FeatureID sql.NullInt64 EntryKey string Value []byte diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index df89d0898..d76d5e6e3 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -16,7 +16,7 @@ type Querier interface { DeleteAllTempKVStores(ctx context.Context) error DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error - DeleteSessionKVStoreRecord(ctx context.Context, arg DeleteSessionKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error DeleteSessionsWithState(ctx context.Context, state int16) error GetAccount(ctx context.Context, id int64) (Account, error) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) @@ -29,6 +29,7 @@ type Querier interface { GetFeatureID(ctx context.Context, name string) (int64, error) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) @@ -40,7 +41,6 @@ type Querier interface { GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) - GetSessionKVStoreRecord(ctx context.Context, arg GetSessionKVStoreRecordParams) ([]byte, error) GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) @@ -57,6 +57,7 @@ type Querier interface { ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) ListSessions(ctx context.Context) ([]Session, error) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) @@ -70,7 +71,7 @@ type Querier interface { UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error - UpdateSessionKVStoreRecord(ctx context.Context, arg UpdateSessionKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error } diff --git a/db/sqlc/queries/kvstores.sql b/db/sqlc/queries/kvstores.sql index 7963e46a4..6acc27468 100644 --- a/db/sqlc/queries/kvstores.sql +++ b/db/sqlc/queries/kvstores.sql @@ -21,29 +21,33 @@ FROM features WHERE name = sqlc.arg('name'); -- name: InsertKVStoreRecord :exec -INSERT INTO kvstores (perm, rule_id, session_id, feature_id, entry_key, value) +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) VALUES ($1, $2, $3, $4, $5, $6); -- name: DeleteAllTempKVStores :exec DELETE FROM kvstores WHERE perm = false; +-- name: ListAllKVStoresRecords :many +SELECT * +FROM kvstores; + -- name: GetGlobalKVStoreRecord :one SELECT value FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: GetSessionKVStoreRecord :one +-- name: GetGroupKVStoreRecord :one SELECT value FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: GetFeatureKVStoreRecord :one @@ -52,7 +56,7 @@ FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); -- name: DeleteGlobalKVStoreRecord :exec @@ -60,15 +64,15 @@ DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: DeleteSessionKVStoreRecord :exec +-- name: DeleteGroupKVStoreRecord :exec DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: DeleteFeatureKVStoreRecord :exec @@ -76,7 +80,7 @@ DELETE FROM kvstores WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); -- name: UpdateGlobalKVStoreRecord :exec @@ -85,16 +89,16 @@ SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id IS NULL + AND group_id IS NULL AND feature_id IS NULL; --- name: UpdateSessionKVStoreRecord :exec +-- name: UpdateGroupKVStoreRecord :exec UPDATE kvstores SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id IS NULL; -- name: UpdateFeatureKVStoreRecord :exec @@ -103,5 +107,5 @@ SET value = $1 WHERE entry_key = sqlc.arg('key') AND rule_id = sqlc.arg('rule_id') AND perm = sqlc.arg('perm') - AND session_id = sqlc.arg('session_id') + AND group_id = sqlc.arg('group_id') AND feature_id = sqlc.arg('feature_id'); diff --git a/db/sqlcmig6/accounts.sql.go b/db/sqlcmig6/accounts.sql.go new file mode 100644 index 000000000..479c82b36 --- /dev/null +++ b/db/sqlcmig6/accounts.sql.go @@ -0,0 +1,390 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const addAccountInvoice = `-- name: AddAccountInvoice :exec +INSERT INTO account_invoices (account_id, hash) +VALUES ($1, $2) +` + +type AddAccountInvoiceParams struct { + AccountID int64 + Hash []byte +} + +func (q *Queries) AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error { + _, err := q.db.ExecContext(ctx, addAccountInvoice, arg.AccountID, arg.Hash) + return err +} + +const deleteAccount = `-- name: DeleteAccount :exec +DELETE FROM accounts +WHERE id = $1 +` + +func (q *Queries) DeleteAccount(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteAccount, id) + return err +} + +const deleteAccountPayment = `-- name: DeleteAccountPayment :exec +DELETE FROM account_payments +WHERE hash = $1 +AND account_id = $2 +` + +type DeleteAccountPaymentParams struct { + Hash []byte + AccountID int64 +} + +func (q *Queries) DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error { + _, err := q.db.ExecContext(ctx, deleteAccountPayment, arg.Hash, arg.AccountID) + return err +} + +const getAccount = `-- name: GetAccount :one +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +WHERE id = $1 +` + +func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccount, id) + var i Account + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ) + return i, err +} + +const getAccountByLabel = `-- name: GetAccountByLabel :one +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +WHERE label = $1 +` + +func (q *Queries) GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccountByLabel, label) + var i Account + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ) + return i, err +} + +const getAccountIDByAlias = `-- name: GetAccountIDByAlias :one +SELECT id +FROM accounts +WHERE alias = $1 +` + +func (q *Queries) GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIDByAlias, alias) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getAccountIndex = `-- name: GetAccountIndex :one +SELECT value +FROM account_indices +WHERE name = $1 +` + +func (q *Queries) GetAccountIndex(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getAccountIndex, name) + var value int64 + err := row.Scan(&value) + return value, err +} + +const getAccountInvoice = `-- name: GetAccountInvoice :one +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 + AND hash = $2 +` + +type GetAccountInvoiceParams struct { + AccountID int64 + Hash []byte +} + +func (q *Queries) GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) { + row := q.db.QueryRowContext(ctx, getAccountInvoice, arg.AccountID, arg.Hash) + var i AccountInvoice + err := row.Scan(&i.AccountID, &i.Hash) + return i, err +} + +const getAccountPayment = `-- name: GetAccountPayment :one +SELECT account_id, hash, status, full_amount_msat FROM account_payments +WHERE hash = $1 +AND account_id = $2 +` + +type GetAccountPaymentParams struct { + Hash []byte + AccountID int64 +} + +func (q *Queries) GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) { + row := q.db.QueryRowContext(ctx, getAccountPayment, arg.Hash, arg.AccountID) + var i AccountPayment + err := row.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ) + return i, err +} + +const insertAccount = `-- name: InsertAccount :one +INSERT INTO accounts (type, initial_balance_msat, current_balance_msat, last_updated, label, alias, expiration) +VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id +` + +type InsertAccountParams struct { + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Label sql.NullString + Alias int64 + Expiration time.Time +} + +func (q *Queries) InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAccount, + arg.Type, + arg.InitialBalanceMsat, + arg.CurrentBalanceMsat, + arg.LastUpdated, + arg.Label, + arg.Alias, + arg.Expiration, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const listAccountInvoices = `-- name: ListAccountInvoices :many +SELECT account_id, hash +FROM account_invoices +WHERE account_id = $1 +` + +func (q *Queries) ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) { + rows, err := q.db.QueryContext(ctx, listAccountInvoices, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountInvoice + for rows.Next() { + var i AccountInvoice + if err := rows.Scan(&i.AccountID, &i.Hash); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAccountPayments = `-- name: ListAccountPayments :many +SELECT account_id, hash, status, full_amount_msat +FROM account_payments +WHERE account_id = $1 +` + +func (q *Queries) ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) { + rows, err := q.db.QueryContext(ctx, listAccountPayments, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AccountPayment + for rows.Next() { + var i AccountPayment + if err := rows.Scan( + &i.AccountID, + &i.Hash, + &i.Status, + &i.FullAmountMsat, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listAllAccounts = `-- name: ListAllAccounts :many +SELECT id, alias, label, type, initial_balance_msat, current_balance_msat, last_updated, expiration +FROM accounts +ORDER BY id +` + +func (q *Queries) ListAllAccounts(ctx context.Context) ([]Account, error) { + rows, err := q.db.QueryContext(ctx, listAllAccounts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Account + for rows.Next() { + var i Account + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.Type, + &i.InitialBalanceMsat, + &i.CurrentBalanceMsat, + &i.LastUpdated, + &i.Expiration, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setAccountIndex = `-- name: SetAccountIndex :exec +INSERT INTO account_indices (name, value) +VALUES ($1, $2) + ON CONFLICT (name) +DO UPDATE SET value = $2 +` + +type SetAccountIndexParams struct { + Name string + Value int64 +} + +func (q *Queries) SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error { + _, err := q.db.ExecContext(ctx, setAccountIndex, arg.Name, arg.Value) + return err +} + +const updateAccountBalance = `-- name: UpdateAccountBalance :one +UPDATE accounts +SET current_balance_msat = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountBalanceParams struct { + CurrentBalanceMsat int64 + ID int64 +} + +func (q *Queries) UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountBalance, arg.CurrentBalanceMsat, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountExpiry = `-- name: UpdateAccountExpiry :one +UPDATE accounts +SET expiration = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountExpiryParams struct { + Expiration time.Time + ID int64 +} + +func (q *Queries) UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountExpiry, arg.Expiration, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const updateAccountLastUpdate = `-- name: UpdateAccountLastUpdate :one +UPDATE accounts +SET last_updated = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateAccountLastUpdateParams struct { + LastUpdated time.Time + ID int64 +} + +func (q *Queries) UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) { + row := q.db.QueryRowContext(ctx, updateAccountLastUpdate, arg.LastUpdated, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const upsertAccountPayment = `-- name: UpsertAccountPayment :exec +INSERT INTO account_payments (account_id, hash, status, full_amount_msat) +VALUES ($1, $2, $3, $4) +ON CONFLICT (account_id, hash) +DO UPDATE SET status = $3, full_amount_msat = $4 +` + +type UpsertAccountPaymentParams struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} + +func (q *Queries) UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error { + _, err := q.db.ExecContext(ctx, upsertAccountPayment, + arg.AccountID, + arg.Hash, + arg.Status, + arg.FullAmountMsat, + ) + return err +} diff --git a/db/sqlcmig6/actions.sql.go b/db/sqlcmig6/actions.sql.go new file mode 100644 index 000000000..a39d51e5d --- /dev/null +++ b/db/sqlcmig6/actions.sql.go @@ -0,0 +1,73 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const insertAction = `-- name: InsertAction :one +INSERT INTO actions ( + session_id, account_id, macaroon_identifier, actor_name, feature_name, action_trigger, + intent, structured_json_data, rpc_method, rpc_params_json, created_at, + action_state, error_reason +) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13 +) RETURNING id +` + +type InsertActionParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +func (q *Queries) InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAction, + arg.SessionID, + arg.AccountID, + arg.MacaroonIdentifier, + arg.ActorName, + arg.FeatureName, + arg.ActionTrigger, + arg.Intent, + arg.StructuredJsonData, + arg.RpcMethod, + arg.RpcParamsJson, + arg.CreatedAt, + arg.ActionState, + arg.ErrorReason, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const setActionState = `-- name: SetActionState :exec +UPDATE actions +SET action_state = $1, + error_reason = $2 +WHERE id = $3 +` + +type SetActionStateParams struct { + ActionState int16 + ErrorReason sql.NullString + ID int64 +} + +func (q *Queries) SetActionState(ctx context.Context, arg SetActionStateParams) error { + _, err := q.db.ExecContext(ctx, setActionState, arg.ActionState, arg.ErrorReason, arg.ID) + return err +} diff --git a/db/sqlcmig6/actions_custom.go b/db/sqlcmig6/actions_custom.go new file mode 100644 index 000000000..f01772d51 --- /dev/null +++ b/db/sqlcmig6/actions_custom.go @@ -0,0 +1,210 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "strconv" + "strings" +) + +// ActionQueryParams defines the parameters for querying actions. +type ActionQueryParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + FeatureName sql.NullString + ActorName sql.NullString + RpcMethod sql.NullString + State sql.NullInt16 + EndTime sql.NullTime + StartTime sql.NullTime + GroupID sql.NullInt64 +} + +// ListActionsParams defines the parameters for listing actions, including +// the ActionQueryParams for filtering and a Pagination struct for +// pagination. The Reversed field indicates whether the results should be +// returned in reverse order based on the created_at timestamp. +type ListActionsParams struct { + ActionQueryParams + Reversed bool + *Pagination +} + +// Pagination defines the pagination parameters for listing actions. +type Pagination struct { + NumOffset int32 + NumLimit int32 +} + +// ListActions retrieves a list of actions based on the provided +// ListActionsParams. +func (q *Queries) ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) { + + query, args := buildListActionsQuery(arg) + rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Action + for rows.Next() { + var i Action + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.AccountID, + &i.MacaroonIdentifier, + &i.ActorName, + &i.FeatureName, + &i.ActionTrigger, + &i.Intent, + &i.StructuredJsonData, + &i.RpcMethod, + &i.RpcParamsJson, + &i.CreatedAt, + &i.ActionState, + &i.ErrorReason, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +// CountActions returns the number of actions that match the provided +// ActionQueryParams. +func (q *Queries) CountActions(ctx context.Context, + arg ActionQueryParams) (int64, error) { + + query, args := buildActionsQuery(arg, true) + row := q.db.QueryRowContext(ctx, query, args...) + + var count int64 + err := row.Scan(&count) + + return count, err +} + +// buildActionsQuery constructs a SQL query to retrieve actions based on the +// provided parameters. We do this manually so that if, for example, we have +// a sessionID we are filtering by, then this appears in the query as: +// `WHERE a.session_id = ?` which will properly make use of the underlying +// index. If we were instead to use a single SQLC query, it would include many +// WHERE clauses like: +// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)". +// This would use the index if run against postres but not when run against +// sqlite. +// +// The 'count' param indicates whether the query should return a count of +// actions that match the criteria or the actions themselves. +func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) { + var ( + conditions []string + args []any + ) + + if params.SessionID.Valid { + conditions = append(conditions, "a.session_id = ?") + args = append(args, params.SessionID.Int64) + } + if params.AccountID.Valid { + conditions = append(conditions, "a.account_id = ?") + args = append(args, params.AccountID.Int64) + } + if params.FeatureName.Valid { + conditions = append(conditions, "a.feature_name = ?") + args = append(args, params.FeatureName.String) + } + if params.ActorName.Valid { + conditions = append(conditions, "a.actor_name = ?") + args = append(args, params.ActorName.String) + } + if params.RpcMethod.Valid { + conditions = append(conditions, "a.rpc_method = ?") + args = append(args, params.RpcMethod.String) + } + if params.State.Valid { + conditions = append(conditions, "a.action_state = ?") + args = append(args, params.State.Int16) + } + if params.EndTime.Valid { + conditions = append(conditions, "a.created_at <= ?") + args = append(args, params.EndTime.Time) + } + if params.StartTime.Valid { + conditions = append(conditions, "a.created_at >= ?") + args = append(args, params.StartTime.Time) + } + if params.GroupID.Valid { + conditions = append(conditions, ` + EXISTS ( + SELECT 1 + FROM sessions s + WHERE s.id = a.session_id AND s.group_id = ? + )`) + args = append(args, params.GroupID.Int64) + } + + query := "SELECT a.* FROM actions a" + if count { + query = "SELECT COUNT(*) FROM actions a" + } + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + + return query, args +} + +// buildListActionsQuery constructs a SQL query to retrieve a list of actions +// based on the provided parameters. It builds upon the `buildActionsQuery` +// function, adding pagination and ordering based on the reversed parameter. +func buildListActionsQuery(params ListActionsParams) (string, []interface{}) { + query, args := buildActionsQuery(params.ActionQueryParams, false) + + // Determine order direction. + order := "ASC" + if params.Reversed { + order = "DESC" + } + query += " ORDER BY a.created_at " + order + + // Maybe paginate. + if params.Pagination != nil { + query += " LIMIT ? OFFSET ?" + args = append(args, params.NumLimit, params.NumOffset) + } + + return query, args +} + +// fillPlaceHolders replaces all '?' placeholders in the SQL query with +// positional placeholders like $1, $2, etc. This is necessary for +// compatibility with Postgres. +func fillPlaceHolders(query string) string { + var ( + sb strings.Builder + argNum = 1 + ) + + for i := range len(query) { + if query[i] != '?' { + sb.WriteByte(query[i]) + continue + } + + sb.WriteString("$") + sb.WriteString(strconv.Itoa(argNum)) + argNum++ + } + + return sb.String() +} diff --git a/db/sqlcmig6/db.go b/db/sqlcmig6/db.go new file mode 100644 index 000000000..82ff72dd8 --- /dev/null +++ b/db/sqlcmig6/db.go @@ -0,0 +1,27 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/db/sqlcmig6/db_custom.go b/db/sqlcmig6/db_custom.go new file mode 100644 index 000000000..e128cb290 --- /dev/null +++ b/db/sqlcmig6/db_custom.go @@ -0,0 +1,48 @@ +package sqlcmig6 + +import ( + "context" + + "github.com/lightningnetwork/lnd/sqldb/v2" +) + +// wrappedTX is a wrapper around a DBTX that also stores the database backend +// type. +type wrappedTX struct { + DBTX + + backendType sqldb.BackendType +} + +// Backend returns the type of database backend we're using. +func (q *Queries) Backend() sqldb.BackendType { + wtx, ok := q.db.(*wrappedTX) + if !ok { + // Shouldn't happen unless a new database backend type is added + // but not initialized correctly. + return sqldb.BackendTypeUnknown + } + + return wtx.backendType +} + +// NewForType creates a new Queries instance for the given database type. +func NewForType(db DBTX, typ sqldb.BackendType) *Queries { + return &Queries{db: &wrappedTX{db, typ}} +} + +// CustomQueries defines a set of custom queries that we define in addition +// to the ones generated by sqlc. +type CustomQueries interface { + // CountActions returns the number of actions that match the provided + // ActionQueryParams. + CountActions(ctx context.Context, arg ActionQueryParams) (int64, error) + + // ListActions retrieves a list of actions based on the provided + // ListActionsParams. + ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) + + // Backend returns the type of the database backend used. + Backend() sqldb.BackendType +} diff --git a/db/sqlcmig6/kvstores.sql.go b/db/sqlcmig6/kvstores.sql.go new file mode 100644 index 000000000..3c076c5a4 --- /dev/null +++ b/db/sqlcmig6/kvstores.sql.go @@ -0,0 +1,376 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +const deleteAllTempKVStores = `-- name: DeleteAllTempKVStores :exec +DELETE FROM kvstores +WHERE perm = false +` + +func (q *Queries) DeleteAllTempKVStores(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, deleteAllTempKVStores) + return err +} + +const deleteFeatureKVStoreRecord = `-- name: DeleteFeatureKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id = $5 +` + +type DeleteFeatureKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteFeatureKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + return err +} + +const deleteGlobalKVStoreRecord = `-- name: DeleteGlobalKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id IS NULL + AND feature_id IS NULL +` + +type DeleteGlobalKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGlobalKVStoreRecord, arg.Key, arg.RuleID, arg.Perm) + return err +} + +const deleteGroupKVStoreRecord = `-- name: DeleteGroupKVStoreRecord :exec +DELETE FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type DeleteGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, deleteGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + return err +} + +const getFeatureID = `-- name: GetFeatureID :one +SELECT id +FROM features +WHERE name = $1 +` + +func (q *Queries) GetFeatureID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getFeatureID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getFeatureKVStoreRecord = `-- name: GetFeatureKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id = $5 +` + +type GetFeatureKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getFeatureKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getGlobalKVStoreRecord = `-- name: GetGlobalKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id IS NULL + AND feature_id IS NULL +` + +type GetGlobalKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGlobalKVStoreRecord, arg.Key, arg.RuleID, arg.Perm) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getGroupKVStoreRecord = `-- name: GetGroupKVStoreRecord :one +SELECT value +FROM kvstores +WHERE entry_key = $1 + AND rule_id = $2 + AND perm = $3 + AND group_id = $4 + AND feature_id IS NULL +` + +type GetGroupKVStoreRecordParams struct { + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getGroupKVStoreRecord, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + var value []byte + err := row.Scan(&value) + return value, err +} + +const getOrInsertFeatureID = `-- name: GetOrInsertFeatureID :one +INSERT INTO features (name) +VALUES ($1) +ON CONFLICT(name) DO UPDATE SET name = excluded.name +RETURNING id +` + +func (q *Queries) GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getOrInsertFeatureID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getOrInsertRuleID = `-- name: GetOrInsertRuleID :one +INSERT INTO rules (name) +VALUES ($1) +ON CONFLICT(name) DO UPDATE SET name = excluded.name +RETURNING id +` + +func (q *Queries) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getOrInsertRuleID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getRuleID = `-- name: GetRuleID :one +SELECT id +FROM rules +WHERE name = $1 +` + +func (q *Queries) GetRuleID(ctx context.Context, name string) (int64, error) { + row := q.db.QueryRowContext(ctx, getRuleID, name) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertKVStoreRecord = `-- name: InsertKVStoreRecord :exec +INSERT INTO kvstores (perm, rule_id, group_id, feature_id, entry_key, value) +VALUES ($1, $2, $3, $4, $5, $6) +` + +type InsertKVStoreRecordParams struct { + Perm bool + RuleID int64 + GroupID sql.NullInt64 + FeatureID sql.NullInt64 + EntryKey string + Value []byte +} + +func (q *Queries) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, insertKVStoreRecord, + arg.Perm, + arg.RuleID, + arg.GroupID, + arg.FeatureID, + arg.EntryKey, + arg.Value, + ) + return err +} + +const listAllKVStoresRecords = `-- name: ListAllKVStoresRecords :many +SELECT id, perm, rule_id, group_id, feature_id, entry_key, value +FROM kvstores +` + +func (q *Queries) ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) { + rows, err := q.db.QueryContext(ctx, listAllKVStoresRecords) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Kvstore + for rows.Next() { + var i Kvstore + if err := rows.Scan( + &i.ID, + &i.Perm, + &i.RuleID, + &i.GroupID, + &i.FeatureID, + &i.EntryKey, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateFeatureKVStoreRecord = `-- name: UpdateFeatureKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id = $5 + AND feature_id = $6 +` + +type UpdateFeatureKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 + FeatureID sql.NullInt64 +} + +func (q *Queries) UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateFeatureKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + arg.FeatureID, + ) + return err +} + +const updateGlobalKVStoreRecord = `-- name: UpdateGlobalKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id IS NULL + AND feature_id IS NULL +` + +type UpdateGlobalKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool +} + +func (q *Queries) UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGlobalKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + ) + return err +} + +const updateGroupKVStoreRecord = `-- name: UpdateGroupKVStoreRecord :exec +UPDATE kvstores +SET value = $1 +WHERE entry_key = $2 + AND rule_id = $3 + AND perm = $4 + AND group_id = $5 + AND feature_id IS NULL +` + +type UpdateGroupKVStoreRecordParams struct { + Value []byte + Key string + RuleID int64 + Perm bool + GroupID sql.NullInt64 +} + +func (q *Queries) UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error { + _, err := q.db.ExecContext(ctx, updateGroupKVStoreRecord, + arg.Value, + arg.Key, + arg.RuleID, + arg.Perm, + arg.GroupID, + ) + return err +} diff --git a/db/sqlcmig6/models.go b/db/sqlcmig6/models.go new file mode 100644 index 000000000..9e57c28eb --- /dev/null +++ b/db/sqlcmig6/models.go @@ -0,0 +1,124 @@ +package sqlcmig6 + +import ( + "database/sql" + "time" +) + +type Account struct { + ID int64 + Alias int64 + Label sql.NullString + Type int16 + InitialBalanceMsat int64 + CurrentBalanceMsat int64 + LastUpdated time.Time + Expiration time.Time +} + +type AccountIndex struct { + Name string + Value int64 +} + +type AccountInvoice struct { + AccountID int64 + Hash []byte +} + +type AccountPayment struct { + AccountID int64 + Hash []byte + Status int16 + FullAmountMsat int64 +} + +type Action struct { + ID int64 + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +type Feature struct { + ID int64 + Name string +} + +type Kvstore struct { + ID int64 + Perm bool + RuleID int64 + GroupID sql.NullInt64 + FeatureID sql.NullInt64 + EntryKey string + Value []byte +} + +type PrivacyPair struct { + GroupID int64 + RealVal string + PseudoVal string +} + +type Rule struct { + ID int64 + Name string +} + +type Session struct { + ID int64 + Alias []byte + Label string + State int16 + Type int16 + Expiry time.Time + CreatedAt time.Time + RevokedAt sql.NullTime + ServerAddress string + DevServer bool + MacaroonRootKey int64 + PairingSecret []byte + LocalPrivateKey []byte + LocalPublicKey []byte + RemotePublicKey []byte + Privacy bool + AccountID sql.NullInt64 + GroupID sql.NullInt64 +} + +type SessionFeatureConfig struct { + SessionID int64 + FeatureName string + Config []byte +} + +type SessionMacaroonCaveat struct { + ID int64 + SessionID int64 + CaveatID []byte + VerificationID []byte + Location sql.NullString +} + +type SessionMacaroonPermission struct { + ID int64 + SessionID int64 + Entity string + Action string +} + +type SessionPrivacyFlag struct { + SessionID int64 + Flag int32 +} diff --git a/db/sqlcmig6/privacy_paris.sql.go b/db/sqlcmig6/privacy_paris.sql.go new file mode 100644 index 000000000..20af09a2f --- /dev/null +++ b/db/sqlcmig6/privacy_paris.sql.go @@ -0,0 +1,91 @@ +package sqlcmig6 + +import ( + "context" +) + +const getAllPrivacyPairs = `-- name: GetAllPrivacyPairs :many +SELECT real_val, pseudo_val +FROM privacy_pairs +WHERE group_id = $1 +` + +type GetAllPrivacyPairsRow struct { + RealVal string + PseudoVal string +} + +func (q *Queries) GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) { + rows, err := q.db.QueryContext(ctx, getAllPrivacyPairs, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllPrivacyPairsRow + for rows.Next() { + var i GetAllPrivacyPairsRow + if err := rows.Scan(&i.RealVal, &i.PseudoVal); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getPseudoForReal = `-- name: GetPseudoForReal :one +SELECT pseudo_val +FROM privacy_pairs +WHERE group_id = $1 AND real_val = $2 +` + +type GetPseudoForRealParams struct { + GroupID int64 + RealVal string +} + +func (q *Queries) GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) { + row := q.db.QueryRowContext(ctx, getPseudoForReal, arg.GroupID, arg.RealVal) + var pseudo_val string + err := row.Scan(&pseudo_val) + return pseudo_val, err +} + +const getRealForPseudo = `-- name: GetRealForPseudo :one +SELECT real_val +FROM privacy_pairs +WHERE group_id = $1 AND pseudo_val = $2 +` + +type GetRealForPseudoParams struct { + GroupID int64 + PseudoVal string +} + +func (q *Queries) GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) { + row := q.db.QueryRowContext(ctx, getRealForPseudo, arg.GroupID, arg.PseudoVal) + var real_val string + err := row.Scan(&real_val) + return real_val, err +} + +const insertPrivacyPair = `-- name: InsertPrivacyPair :exec +INSERT INTO privacy_pairs (group_id, real_val, pseudo_val) +VALUES ($1, $2, $3) +` + +type InsertPrivacyPairParams struct { + GroupID int64 + RealVal string + PseudoVal string +} + +func (q *Queries) InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error { + _, err := q.db.ExecContext(ctx, insertPrivacyPair, arg.GroupID, arg.RealVal, arg.PseudoVal) + return err +} diff --git a/db/sqlcmig6/querier.go b/db/sqlcmig6/querier.go new file mode 100644 index 000000000..57e229b5f --- /dev/null +++ b/db/sqlcmig6/querier.go @@ -0,0 +1,75 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" +) + +type Querier interface { + AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error + DeleteAccount(ctx context.Context, id int64) error + DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error + DeleteAllTempKVStores(ctx context.Context) error + DeleteFeatureKVStoreRecord(ctx context.Context, arg DeleteFeatureKVStoreRecordParams) error + DeleteGlobalKVStoreRecord(ctx context.Context, arg DeleteGlobalKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg DeleteGroupKVStoreRecordParams) error + DeleteSessionsWithState(ctx context.Context, state int16) error + GetAccount(ctx context.Context, id int64) (Account, error) + GetAccountByLabel(ctx context.Context, label sql.NullString) (Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccountIndex(ctx context.Context, name string) (int64, error) + GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) + GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) + GetFeatureID(ctx context.Context, name string) (int64, error) + GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) + GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg GetGroupKVStoreRecordParams) ([]byte, error) + GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) + GetOrInsertRuleID(ctx context.Context, name string) (int64, error) + GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) + GetRuleID(ctx context.Context, name string) (int64, error) + GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) + GetSessionByAlias(ctx context.Context, alias []byte) (Session, error) + GetSessionByID(ctx context.Context, id int64) (Session, error) + GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) + GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) + GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) + GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) + GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) + GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) + GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) + InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) + InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) + InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error + InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error + InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) + InsertSessionFeatureConfig(ctx context.Context, arg InsertSessionFeatureConfigParams) error + InsertSessionMacaroonCaveat(ctx context.Context, arg InsertSessionMacaroonCaveatParams) error + InsertSessionMacaroonPermission(ctx context.Context, arg InsertSessionMacaroonPermissionParams) error + InsertSessionPrivacyFlag(ctx context.Context, arg InsertSessionPrivacyFlagParams) error + ListAccountInvoices(ctx context.Context, accountID int64) ([]AccountInvoice, error) + ListAccountPayments(ctx context.Context, accountID int64) ([]AccountPayment, error) + ListAllAccounts(ctx context.Context) ([]Account, error) + ListAllKVStoresRecords(ctx context.Context) ([]Kvstore, error) + ListSessions(ctx context.Context) ([]Session, error) + ListSessionsByState(ctx context.Context, state int16) ([]Session, error) + ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) + SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error + SetActionState(ctx context.Context, arg SetActionStateParams) error + SetSessionGroupID(ctx context.Context, arg SetSessionGroupIDParams) error + SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error + SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error + UpdateAccountBalance(ctx context.Context, arg UpdateAccountBalanceParams) (int64, error) + UpdateAccountExpiry(ctx context.Context, arg UpdateAccountExpiryParams) (int64, error) + UpdateAccountLastUpdate(ctx context.Context, arg UpdateAccountLastUpdateParams) (int64, error) + UpdateFeatureKVStoreRecord(ctx context.Context, arg UpdateFeatureKVStoreRecordParams) error + UpdateGlobalKVStoreRecord(ctx context.Context, arg UpdateGlobalKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg UpdateGroupKVStoreRecordParams) error + UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error + UpsertAccountPayment(ctx context.Context, arg UpsertAccountPaymentParams) error +} + +var _ Querier = (*Queries)(nil) diff --git a/db/sqlcmig6/sessions.sql.go b/db/sqlcmig6/sessions.sql.go new file mode 100644 index 000000000..bc492043e --- /dev/null +++ b/db/sqlcmig6/sessions.sql.go @@ -0,0 +1,675 @@ +package sqlcmig6 + +import ( + "context" + "database/sql" + "time" +) + +const deleteSessionsWithState = `-- name: DeleteSessionsWithState :exec +DELETE FROM sessions +WHERE state = $1 +` + +func (q *Queries) DeleteSessionsWithState(ctx context.Context, state int16) error { + _, err := q.db.ExecContext(ctx, deleteSessionsWithState, state) + return err +} + +const getAliasBySessionID = `-- name: GetAliasBySessionID :one +SELECT alias FROM sessions +WHERE id = $1 +` + +func (q *Queries) GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) { + row := q.db.QueryRowContext(ctx, getAliasBySessionID, id) + var alias []byte + err := row.Scan(&alias) + return alias, err +} + +const getSessionAliasesInGroup = `-- name: GetSessionAliasesInGroup :many +SELECT alias FROM sessions +WHERE group_id = $1 +` + +func (q *Queries) GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) { + rows, err := q.db.QueryContext(ctx, getSessionAliasesInGroup, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items [][]byte + for rows.Next() { + var alias []byte + if err := rows.Scan(&alias); err != nil { + return nil, err + } + items = append(items, alias) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionByAlias = `-- name: GetSessionByAlias :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE alias = $1 +` + +func (q *Queries) GetSessionByAlias(ctx context.Context, alias []byte) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByAlias, alias) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionByID = `-- name: GetSessionByID :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE id = $1 +` + +func (q *Queries) GetSessionByID(ctx context.Context, id int64) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByID, id) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionByLocalPublicKey = `-- name: GetSessionByLocalPublicKey :one +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE local_public_key = $1 +` + +func (q *Queries) GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (Session, error) { + row := q.db.QueryRowContext(ctx, getSessionByLocalPublicKey, localPublicKey) + var i Session + err := row.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ) + return i, err +} + +const getSessionFeatureConfigs = `-- name: GetSessionFeatureConfigs :many +SELECT session_id, feature_name, config FROM session_feature_configs +WHERE session_id = $1 +` + +func (q *Queries) GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]SessionFeatureConfig, error) { + rows, err := q.db.QueryContext(ctx, getSessionFeatureConfigs, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionFeatureConfig + for rows.Next() { + var i SessionFeatureConfig + if err := rows.Scan(&i.SessionID, &i.FeatureName, &i.Config); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionIDByAlias = `-- name: GetSessionIDByAlias :one +SELECT id FROM sessions +WHERE alias = $1 +` + +func (q *Queries) GetSessionIDByAlias(ctx context.Context, alias []byte) (int64, error) { + row := q.db.QueryRowContext(ctx, getSessionIDByAlias, alias) + var id int64 + err := row.Scan(&id) + return id, err +} + +const getSessionMacaroonCaveats = `-- name: GetSessionMacaroonCaveats :many +SELECT id, session_id, caveat_id, verification_id, location FROM session_macaroon_caveats +WHERE session_id = $1 +` + +func (q *Queries) GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]SessionMacaroonCaveat, error) { + rows, err := q.db.QueryContext(ctx, getSessionMacaroonCaveats, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionMacaroonCaveat + for rows.Next() { + var i SessionMacaroonCaveat + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.CaveatID, + &i.VerificationID, + &i.Location, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionMacaroonPermissions = `-- name: GetSessionMacaroonPermissions :many +SELECT id, session_id, entity, action FROM session_macaroon_permissions +WHERE session_id = $1 +` + +func (q *Queries) GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]SessionMacaroonPermission, error) { + rows, err := q.db.QueryContext(ctx, getSessionMacaroonPermissions, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionMacaroonPermission + for rows.Next() { + var i SessionMacaroonPermission + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Entity, + &i.Action, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionPrivacyFlags = `-- name: GetSessionPrivacyFlags :many +SELECT session_id, flag FROM session_privacy_flags +WHERE session_id = $1 +` + +func (q *Queries) GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) { + rows, err := q.db.QueryContext(ctx, getSessionPrivacyFlags, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SessionPrivacyFlag + for rows.Next() { + var i SessionPrivacyFlag + if err := rows.Scan(&i.SessionID, &i.Flag); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionsInGroup = `-- name: GetSessionsInGroup :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE group_id = $1 +` + +func (q *Queries) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, getSessionsInGroup, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertSession = `-- name: InsertSession :one +INSERT INTO sessions ( + alias, label, state, type, expiry, created_at, + server_address, dev_server, macaroon_root_key, pairing_secret, + local_private_key, local_public_key, remote_public_key, privacy, group_id, account_id +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, $10, $11, $12, + $13, $14, $15, $16 +) RETURNING id +` + +type InsertSessionParams struct { + Alias []byte + Label string + State int16 + Type int16 + Expiry time.Time + CreatedAt time.Time + ServerAddress string + DevServer bool + MacaroonRootKey int64 + PairingSecret []byte + LocalPrivateKey []byte + LocalPublicKey []byte + RemotePublicKey []byte + Privacy bool + GroupID sql.NullInt64 + AccountID sql.NullInt64 +} + +func (q *Queries) InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertSession, + arg.Alias, + arg.Label, + arg.State, + arg.Type, + arg.Expiry, + arg.CreatedAt, + arg.ServerAddress, + arg.DevServer, + arg.MacaroonRootKey, + arg.PairingSecret, + arg.LocalPrivateKey, + arg.LocalPublicKey, + arg.RemotePublicKey, + arg.Privacy, + arg.GroupID, + arg.AccountID, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertSessionFeatureConfig = `-- name: InsertSessionFeatureConfig :exec +INSERT INTO session_feature_configs ( + session_id, feature_name, config +) VALUES ( + $1, $2, $3 +) +` + +type InsertSessionFeatureConfigParams struct { + SessionID int64 + FeatureName string + Config []byte +} + +func (q *Queries) InsertSessionFeatureConfig(ctx context.Context, arg InsertSessionFeatureConfigParams) error { + _, err := q.db.ExecContext(ctx, insertSessionFeatureConfig, arg.SessionID, arg.FeatureName, arg.Config) + return err +} + +const insertSessionMacaroonCaveat = `-- name: InsertSessionMacaroonCaveat :exec +INSERT INTO session_macaroon_caveats ( + session_id, caveat_id, verification_id, location +) VALUES ( + $1, $2, $3, $4 +) +` + +type InsertSessionMacaroonCaveatParams struct { + SessionID int64 + CaveatID []byte + VerificationID []byte + Location sql.NullString +} + +func (q *Queries) InsertSessionMacaroonCaveat(ctx context.Context, arg InsertSessionMacaroonCaveatParams) error { + _, err := q.db.ExecContext(ctx, insertSessionMacaroonCaveat, + arg.SessionID, + arg.CaveatID, + arg.VerificationID, + arg.Location, + ) + return err +} + +const insertSessionMacaroonPermission = `-- name: InsertSessionMacaroonPermission :exec +INSERT INTO session_macaroon_permissions ( + session_id, entity, action +) VALUES ( + $1, $2, $3 +) +` + +type InsertSessionMacaroonPermissionParams struct { + SessionID int64 + Entity string + Action string +} + +func (q *Queries) InsertSessionMacaroonPermission(ctx context.Context, arg InsertSessionMacaroonPermissionParams) error { + _, err := q.db.ExecContext(ctx, insertSessionMacaroonPermission, arg.SessionID, arg.Entity, arg.Action) + return err +} + +const insertSessionPrivacyFlag = `-- name: InsertSessionPrivacyFlag :exec +INSERT INTO session_privacy_flags ( + session_id, flag +) VALUES ( + $1, $2 +) +` + +type InsertSessionPrivacyFlagParams struct { + SessionID int64 + Flag int32 +} + +func (q *Queries) InsertSessionPrivacyFlag(ctx context.Context, arg InsertSessionPrivacyFlagParams) error { + _, err := q.db.ExecContext(ctx, insertSessionPrivacyFlag, arg.SessionID, arg.Flag) + return err +} + +const listSessions = `-- name: ListSessions :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +ORDER BY created_at +` + +func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessions) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSessionsByState = `-- name: ListSessionsByState :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE state = $1 +ORDER BY created_at +` + +func (q *Queries) ListSessionsByState(ctx context.Context, state int16) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessionsByState, state) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSessionsByType = `-- name: ListSessionsByType :many +SELECT id, alias, label, state, type, expiry, created_at, revoked_at, server_address, dev_server, macaroon_root_key, pairing_secret, local_private_key, local_public_key, remote_public_key, privacy, account_id, group_id FROM sessions +WHERE type = $1 +ORDER BY created_at +` + +func (q *Queries) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessionsByType, type_) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Session + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Alias, + &i.Label, + &i.State, + &i.Type, + &i.Expiry, + &i.CreatedAt, + &i.RevokedAt, + &i.ServerAddress, + &i.DevServer, + &i.MacaroonRootKey, + &i.PairingSecret, + &i.LocalPrivateKey, + &i.LocalPublicKey, + &i.RemotePublicKey, + &i.Privacy, + &i.AccountID, + &i.GroupID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setSessionGroupID = `-- name: SetSessionGroupID :exec +UPDATE sessions +SET group_id = $1 +WHERE id = $2 +` + +type SetSessionGroupIDParams struct { + GroupID sql.NullInt64 + ID int64 +} + +func (q *Queries) SetSessionGroupID(ctx context.Context, arg SetSessionGroupIDParams) error { + _, err := q.db.ExecContext(ctx, setSessionGroupID, arg.GroupID, arg.ID) + return err +} + +const setSessionRemotePublicKey = `-- name: SetSessionRemotePublicKey :exec +UPDATE sessions +SET remote_public_key = $1 +WHERE id = $2 +` + +type SetSessionRemotePublicKeyParams struct { + RemotePublicKey []byte + ID int64 +} + +func (q *Queries) SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error { + _, err := q.db.ExecContext(ctx, setSessionRemotePublicKey, arg.RemotePublicKey, arg.ID) + return err +} + +const setSessionRevokedAt = `-- name: SetSessionRevokedAt :exec +UPDATE sessions +SET revoked_at = $1 +WHERE id = $2 +` + +type SetSessionRevokedAtParams struct { + RevokedAt sql.NullTime + ID int64 +} + +func (q *Queries) SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error { + _, err := q.db.ExecContext(ctx, setSessionRevokedAt, arg.RevokedAt, arg.ID) + return err +} + +const updateSessionState = `-- name: UpdateSessionState :exec +UPDATE sessions +SET state = $1 +WHERE id = $2 +` + +type UpdateSessionStateParams struct { + State int16 + ID int64 +} + +func (q *Queries) UpdateSessionState(ctx context.Context, arg UpdateSessionStateParams) error { + _, err := q.db.ExecContext(ctx, updateSessionState, arg.State, arg.ID) + return err +} diff --git a/db/sqlite.go b/db/sqlite.go index 803362fa8..a65271548 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -11,6 +11,7 @@ import ( "github.com/golang-migrate/migrate/v4" sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" _ "modernc.org/sqlite" // Register relevant drivers. ) @@ -132,7 +133,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(defaultConnMaxLifetime) - queries := sqlc.NewSqlite(db) + queries := sqlc.NewForType(db, sqldb.BackendTypeSqlite) s := &SqliteStore{ cfg: cfg, BaseDB: &BaseDB{ @@ -140,16 +141,7 @@ func NewSqliteStore(cfg *SqliteConfig) (*SqliteStore, error) { Queries: queries, }, } - - // Now that the database is open, populate the database with our set of - // schemas based on our embedded in-memory file system. - if !cfg.SkipMigrations { - if err := s.ExecuteMigrations(s.backupAndMigrate); err != nil { - return nil, fmt.Errorf("error executing migrations: "+ - "%w", err) - } - } - + return s, nil } @@ -243,7 +235,7 @@ func (s *SqliteStore) ExecuteMigrations(target MigrationTarget, return fmt.Errorf("error creating sqlite migration: %w", err) } - sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) + sqliteFS := newReplacerFS(SqlSchemas, sqliteSchemaReplacements) return applyMigrations( sqliteFS, driver, "sqlc/migrations", "sqlite", target, opts, ) diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go index 75c9d0a6d..eb0294af7 100644 --- a/firewalldb/actions_sql.go +++ b/firewalldb/actions_sql.go @@ -10,9 +10,10 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLAccountQueries is a subset of the sqlc.Queries interface that can be used @@ -22,6 +23,13 @@ type SQLAccountQueries interface { GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) } +// SQLMig6AccountQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with the accounts table. +type SQLMig6AccountQueries interface { + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) +} + // SQLActionQueries is a subset of the sqlc.Queries interface that can be used // to interact with action related tables. // @@ -36,6 +44,20 @@ type SQLActionQueries interface { CountActions(ctx context.Context, arg sqlc.ActionQueryParams) (int64, error) } +// SQLMig6ActionQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with action related tables. +// +//nolint:lll +type SQLMig6ActionQueries interface { + SQLSessionQueries + SQLMig6AccountQueries + + InsertAction(ctx context.Context, arg sqlcmig6.InsertActionParams) (int64, error) + SetActionState(ctx context.Context, arg sqlcmig6.SetActionStateParams) error + ListActions(ctx context.Context, arg sqlcmig6.ListActionsParams) ([]sqlcmig6.Action, error) + CountActions(ctx context.Context, arg sqlcmig6.ActionQueryParams) (int64, error) +} + // sqlActionLocator helps us find an action in the SQL DB. type sqlActionLocator struct { // id is the DB level ID of the action. @@ -167,7 +189,7 @@ func (s *SQLDB) AddAction(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -202,7 +224,7 @@ func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, Valid: errReason != "", }, }) - }) + }, sqldb.NoOpReset) } // ListActions returns a list of Actions. The query IndexOffset and MaxNum @@ -350,7 +372,7 @@ func (s *SQLDB) ListActions(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) return actions, lastIndex, uint64(totalCount), err } diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index c27e53e96..69990c1da 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -28,9 +28,6 @@ func TestActionStorage(t *testing.T) { sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB) db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Assert that attempting to add an action for a session that does not // exist returns an error. @@ -198,9 +195,6 @@ func TestListActions(t *testing.T) { sessDB := session.NewTestDB(t, clock) db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // Add 2 sessions that we can reference. sess1, err := sessDB.NewSession( @@ -466,9 +460,6 @@ func TestListGroupActions(t *testing.T) { } db := NewTestDBWithSessions(t, sessDB, clock) - t.Cleanup(func() { - _ = db.Close() - }) // There should not be any actions in group 1 yet. al, _, _, err := db.ListActions(ctx, nil, WithActionGroupID(group1)) diff --git a/firewalldb/db.go b/firewalldb/db.go index b8d9ed06f..a8349a538 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -14,29 +14,21 @@ var ( ErrNoSuchKeyFound = fmt.Errorf("no such key found") ) -// firewallDBs is an interface that groups the RulesDB and PrivacyMapper -// interfaces. -type firewallDBs interface { - RulesDB - PrivacyMapper - ActionDB -} - // DB manages the firewall rules database. type DB struct { started sync.Once stopped sync.Once - firewallDBs + FirewallDBs cancel fn.Option[context.CancelFunc] } // NewDB creates a new firewall database. For now, it only contains the // underlying rules' and privacy mapper databases. -func NewDB(dbs firewallDBs) *DB { +func NewDB(dbs FirewallDBs) *DB { return &DB{ - firewallDBs: dbs, + FirewallDBs: dbs, } } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 5ee729e91..c2955bdc6 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -134,3 +134,11 @@ type ActionDB interface { // and feature name. GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB } + +// FirewallDBs is an interface that groups the RulesDB, PrivacyMapper and +// ActionDB interfaces. +type FirewallDBs interface { + RulesDB + PrivacyMapper + ActionDB +} diff --git a/firewalldb/kvstores_kvdb.go b/firewalldb/kvstores_kvdb.go index 51721d475..78676e3ed 100644 --- a/firewalldb/kvstores_kvdb.go +++ b/firewalldb/kvstores_kvdb.go @@ -16,13 +16,13 @@ the temporary store changes instead of just keeping an in-memory store is that we can then guarantee atomicity if changes are made to both the permanent and temporary stores. -rules -> perm -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} +"rules" -> "perm" -> -> "global" -> {k:v} + -> "session-kv-store" -> -> {k:v} + -> "feature-kv-stores" -> -> {k:v} - -> temp -> rule-name -> global -> {k:v} - -> sessions -> group ID -> session-kv-store -> {k:v} - -> feature-kv-stores -> feature-name -> {k:v} + -> "temp" -> -> "global" -> {k:v} + -> "session-kv-store" -> -> {k:v} + -> "feature-kv-stores" -> -> {k:v} */ var ( diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index 0c3df2ddb..f815742ac 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -9,8 +9,10 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be @@ -22,14 +24,41 @@ type SQLKVStoreQueries interface { DeleteFeatureKVStoreRecord(ctx context.Context, arg sqlc.DeleteFeatureKVStoreRecordParams) error DeleteGlobalKVStoreRecord(ctx context.Context, arg sqlc.DeleteGlobalKVStoreRecordParams) error - DeleteSessionKVStoreRecord(ctx context.Context, arg sqlc.DeleteSessionKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg sqlc.DeleteGroupKVStoreRecordParams) error GetFeatureKVStoreRecord(ctx context.Context, arg sqlc.GetFeatureKVStoreRecordParams) ([]byte, error) GetGlobalKVStoreRecord(ctx context.Context, arg sqlc.GetGlobalKVStoreRecordParams) ([]byte, error) - GetSessionKVStoreRecord(ctx context.Context, arg sqlc.GetSessionKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg sqlc.GetGroupKVStoreRecordParams) ([]byte, error) UpdateFeatureKVStoreRecord(ctx context.Context, arg sqlc.UpdateFeatureKVStoreRecordParams) error UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlc.UpdateGlobalKVStoreRecordParams) error - UpdateSessionKVStoreRecord(ctx context.Context, arg sqlc.UpdateSessionKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg sqlc.UpdateGroupKVStoreRecordParams) error InsertKVStoreRecord(ctx context.Context, arg sqlc.InsertKVStoreRecordParams) error + ListAllKVStoresRecords(ctx context.Context) ([]sqlc.Kvstore, error) + DeleteAllTempKVStores(ctx context.Context) error + GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) + GetOrInsertRuleID(ctx context.Context, name string) (int64, error) + GetFeatureID(ctx context.Context, name string) (int64, error) + GetRuleID(ctx context.Context, name string) (int64, error) +} + +// SQLMig6KVStoreQueries is a subset of the sqlcmig6.Queries interface that can +// be used to interact with the kvstore tables. +//  +// +//nolint:lll +type SQLMig6KVStoreQueries interface { + SQLSessionQueries + + DeleteFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteFeatureKVStoreRecordParams) error + DeleteGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteGlobalKVStoreRecordParams) error + DeleteGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.DeleteGroupKVStoreRecordParams) error + GetFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.GetFeatureKVStoreRecordParams) ([]byte, error) + GetGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.GetGlobalKVStoreRecordParams) ([]byte, error) + GetGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.GetGroupKVStoreRecordParams) ([]byte, error) + UpdateFeatureKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateFeatureKVStoreRecordParams) error + UpdateGlobalKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateGlobalKVStoreRecordParams) error + UpdateGroupKVStoreRecord(ctx context.Context, arg sqlcmig6.UpdateGroupKVStoreRecordParams) error + InsertKVStoreRecord(ctx context.Context, arg sqlcmig6.InsertKVStoreRecordParams) error + ListAllKVStoresRecords(ctx context.Context) ([]sqlcmig6.Kvstore, error) DeleteAllTempKVStores(ctx context.Context) error GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) @@ -45,7 +74,7 @@ func (s *SQLDB) DeleteTempKVStores(ctx context.Context) error { return s.db.ExecTx(ctx, &writeTxOpts, func(tx SQLQueries) error { return tx.DeleteAllTempKVStores(ctx) - }) + }, sqldb.NoOpReset) } // GetKVStores constructs a new rules.KVStores in a namespace defined by the @@ -198,7 +227,7 @@ func (s *sqlKVStore) Get(ctx context.Context, key string) ([]byte, error) { // // NOTE: part of the KVStore interface. func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, false) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, false) if err != nil { return err } @@ -219,7 +248,7 @@ func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { Value: value, Perm: s.params.perm, RuleID: ruleID, - SessionID: sessionID, + GroupID: groupID, FeatureID: featureID, }, ) @@ -233,26 +262,26 @@ func (s *sqlKVStore) Set(ctx context.Context, key string, value []byte) error { // Otherwise, the key exists but the value needs to be updated. switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.UpdateFeatureKVStoreRecord( ctx, sqlc.UpdateFeatureKVStoreRecordParams{ Key: key, Value: value, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.UpdateSessionKVStoreRecord( - ctx, sqlc.UpdateSessionKVStoreRecordParams{ - Key: key, - Value: value, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.UpdateGroupKVStoreRecord( + ctx, sqlc.UpdateGroupKVStoreRecordParams{ + Key: key, + Value: value, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -278,7 +307,7 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { // Note: we pass in true here for "read-only" since because this is a // Delete, if the record does not exist, we don't need to create one. // But no need to error out if it doesn't exist. - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, true) if errors.Is(err, sql.ErrNoRows) || errors.Is(err, session.ErrUnknownGroup) { @@ -288,24 +317,24 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { } switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.DeleteFeatureKVStoreRecord( ctx, sqlc.DeleteFeatureKVStoreRecordParams{ Key: key, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.DeleteSessionKVStoreRecord( - ctx, sqlc.DeleteSessionKVStoreRecordParams{ - Key: key, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.DeleteGroupKVStoreRecord( + ctx, sqlc.DeleteGroupKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -326,30 +355,30 @@ func (s *sqlKVStore) Del(ctx context.Context, key string) error { // get fetches the value under the given key from the underlying kv store given // the namespace fields. func (s *sqlKVStore) get(ctx context.Context, key string) ([]byte, error) { - ruleID, sessionID, featureID, err := s.genNamespaceFields(ctx, true) + ruleID, groupID, featureID, err := s.genNamespaceFields(ctx, true) if err != nil { return nil, err } switch { - case sessionID.Valid && featureID.Valid: + case groupID.Valid && featureID.Valid: return s.queries.GetFeatureKVStoreRecord( ctx, sqlc.GetFeatureKVStoreRecordParams{ Key: key, Perm: s.params.perm, - SessionID: sessionID, + GroupID: groupID, RuleID: ruleID, FeatureID: featureID, }, ) - case sessionID.Valid: - return s.queries.GetSessionKVStoreRecord( - ctx, sqlc.GetSessionKVStoreRecordParams{ - Key: key, - Perm: s.params.perm, - SessionID: sessionID, - RuleID: ruleID, + case groupID.Valid: + return s.queries.GetGroupKVStoreRecord( + ctx, sqlc.GetGroupKVStoreRecordParams{ + Key: key, + Perm: s.params.perm, + GroupID: groupID, + RuleID: ruleID, }, ) @@ -373,7 +402,7 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, readOnly bool) (int64, sql.NullInt64, sql.NullInt64, error) { var ( - sessionID sql.NullInt64 + groupID sql.NullInt64 featureID sql.NullInt64 ruleID int64 err error @@ -382,8 +411,8 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, // If a group ID is specified, then we first check that this group ID // is a known session alias. s.params.groupID.WhenSome(func(id session.ID) { - var groupID int64 - groupID, err = s.queries.GetSessionIDByAlias(ctx, id[:]) + var dbGroupID int64 + dbGroupID, err = s.queries.GetSessionIDByAlias(ctx, id[:]) if errors.Is(err, sql.ErrNoRows) { err = session.ErrUnknownGroup @@ -392,20 +421,20 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, return } - sessionID = sql.NullInt64{ - Int64: groupID, + groupID = sql.NullInt64{ + Int64: dbGroupID, Valid: true, } }) if err != nil { - return ruleID, sessionID, featureID, err + return ruleID, groupID, featureID, err } // We only insert a new rule name into the DB if this is a write call. if readOnly { ruleID, err = s.queries.GetRuleID(ctx, s.params.ruleName) if err != nil { - return 0, sessionID, featureID, + return 0, groupID, featureID, fmt.Errorf("unable to get rule ID: %w", err) } } else { @@ -413,7 +442,7 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, ctx, s.params.ruleName, ) if err != nil { - return 0, sessionID, featureID, + return 0, groupID, featureID, fmt.Errorf("unable to get or insert rule "+ "ID: %w", err) } @@ -441,5 +470,5 @@ func (s *sqlKVStore) genNamespaceFields(ctx context.Context, } }) - return ruleID, sessionID, featureID, err + return ruleID, groupID, featureID, err } diff --git a/firewalldb/privacy_mapper_sql.go b/firewalldb/privacy_mapper_sql.go index 8a4863a6c..ff7f70f94 100644 --- a/firewalldb/privacy_mapper_sql.go +++ b/firewalldb/privacy_mapper_sql.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightninglabs/lightning-terminal/session" ) @@ -22,6 +23,19 @@ type SQLPrivacyPairQueries interface { GetRealForPseudo(ctx context.Context, arg sqlc.GetRealForPseudoParams) (string, error) } +// SQLMig6PrivacyPairQueries is a subset of the sqlcmig6.Queries interface that +// can be used to interact with the privacy map table. +// +//nolint:lll +type SQLMig6PrivacyPairQueries interface { + SQLSessionQueries + + InsertPrivacyPair(ctx context.Context, arg sqlcmig6.InsertPrivacyPairParams) error + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]sqlcmig6.GetAllPrivacyPairsRow, error) + GetPseudoForReal(ctx context.Context, arg sqlcmig6.GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg sqlcmig6.GetRealForPseudoParams) (string, error) +} + // PrivacyDB constructs a PrivacyMapDB that will be indexed under the given // group ID key. // diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go new file mode 100644 index 000000000..f898993e3 --- /dev/null +++ b/firewalldb/sql_migration.go @@ -0,0 +1,492 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb" + "go.etcd.io/bbolt" +) + +// kvEntry represents a single KV entry inserted into the BoltDB. +type kvEntry struct { + perm bool + ruleName string + key string + value []byte + + // groupAlias is the legacy session group alias that the entry is + // associated with. For global entries, this will be fn.None[[]byte]. + groupAlias fn.Option[[]byte] + + // featureName is the name of the feature that the entry is associated + // with. If the entry is not feature specific, this will be + // fn.None[string]. + featureName fn.Option[string] +} + +// sqlKvEntry represents a single KV entry inserted into the SQL DB, containing +// the same fields as the kvEntry, but with additional fields that represent the +// SQL IDs of the rule, session group, and feature. +type sqlKvEntry struct { + *kvEntry + + ruleID int64 + + // groupID is the sql session group ID that the entry is associated + // with. For global entries, this will be Valid=false. + groupID sql.NullInt64 + + // featureID is the sql feature ID that the entry is associated with. + // This is only set if the entry is feature specific, and will be + // Valid=false for other types entries. If this is set, then groupID + // will also be set. + featureID sql.NullInt64 +} + +// namespacedKey returns a string representation of the kvEntry purely used for +// logging purposes. +func (e *kvEntry) namespacedKey() string { + ns := fmt.Sprintf("perm: %t, rule: %s", e.perm, e.ruleName) + + e.groupAlias.WhenSome(func(alias []byte) { + ns += fmt.Sprintf(", group: %s", alias) + }) + + e.featureName.WhenSome(func(feature string) { + ns += fmt.Sprintf(", feature: %s", feature) + }) + + ns += fmt.Sprintf(", key: %s", e.key) + + return ns +} + +// MigrateFirewallDBToSQL runs the migration of the firwalldb stores from the +// bbolt database to a SQL database. The migration is done in a single +// transaction to ensure that all rows in the stores are migrated or none at +// all. +// +// Note that this migration currently only migrates the kvstores, but will be +// extended in the future to also migrate the privacy mapper and action stores. +// +// NOTE: As sessions may contain linked sessions and accounts, the sessions and +// accounts sql migration MUST be run prior to this migration. +func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLMig6Queries) error { + + log.Infof("Starting migration of the rules DB to SQL") + + err := migrateKVStoresDBToSQL(ctx, kvStore, sqlTx) + if err != nil { + return err + } + + log.Infof("The rules DB has been migrated from KV to SQL.") + + // TODO(viktor): Add migration for the privacy mapper and the action + // stores. + + return nil +} + +// migrateKVStoresDBToSQL runs the migration of all KV stores from the KV +// database to the SQL database. The function also asserts that the +// migrated values match the original values in the KV store. +func migrateKVStoresDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLMig6Queries) error { + + log.Infof("Starting migration of the KV stores to SQL") + + var pairs []*kvEntry + + // 1) Collect all key-value pairs from the KV store. + err := kvStore.View(func(tx *bbolt.Tx) error { + var err error + pairs, err = collectAllPairs(tx) + return err + }) + if err != nil { + return fmt.Errorf("collecting all kv pairs failed: %w", err) + } + + var insertedPairs []*sqlKvEntry + + // 2) Insert all collected key-value pairs into the SQL database. + for _, entry := range pairs { + insertedPair, err := insertPair(ctx, sqlTx, entry) + if err != nil { + return fmt.Errorf("inserting kv pair %v failed: %w", + entry.key, err) + } + + insertedPairs = append(insertedPairs, insertedPair) + } + + // 3) Validate the migrated values against the original values. + for _, insertedPair := range insertedPairs { + // Fetch the appropriate SQL entry's value. + migratedValue, err := getSQLValue(ctx, sqlTx, insertedPair) + if err != nil { + return fmt.Errorf("getting SQL value for key %s "+ + "failed: %w", insertedPair.namespacedKey(), err) + } + + // Compare the value of the migrated entry with the original + // value from the KV store. + // NOTE: if the insert a []byte{} value into the sqldb as the + // entry value, and then retrieve it, the value will be + // returned as nil. The bytes.Equal will pass in that case, + // and therefore such cases won't error out. The kvdb instance + // can store []byte{} values. + if !bytes.Equal(migratedValue, insertedPair.value) { + return fmt.Errorf("migrated value for key %s "+ + "does not match original value: "+ + "migrated %x, original %x", + insertedPair.namespacedKey(), migratedValue, + insertedPair.value) + } + } + + log.Infof("Migration of the KV stores to SQL completed. Total number "+ + "of rows migrated: %d", len(pairs)) + + return nil +} + +// collectAllPairs collects all key-value pairs from the KV store, and returns +// them as a slice of kvEntry structs. The function expects the KV store to be +// stuctured as described in the comment in the firewalldb/kvstores_kvdb.go +// file. Any other structure will result in an error. +// Note that this function and the subsequent functions are intentionally +// designed to iterate over all buckets and values that exist in the KV store. +// That ensures that we find all stores and values that exist in the KV store, +// and can be sure that the kv store actually follows the expected structure. +func collectAllPairs(tx *bbolt.Tx) ([]*kvEntry, error) { + var entries []*kvEntry + for _, perm := range []bool{true, false} { + mainBucket, err := getMainBucket(tx, false, perm) + if err != nil { + return nil, err + } + + if mainBucket == nil { + // If the mainBucket doesn't exist, there are no entries + // to migrate under that bucket, therefore we don't + // error, and just proceed to not migrate any entries + // under that bucket. + continue + } + + // Loop over each rule-name bucket. + err = mainBucket.ForEach(func(rule, v []byte) error { + if v != nil { + return errors.New("expected only " + + "buckets under main bucket") + } + + ruleBucket := mainBucket.Bucket(rule) + if ruleBucket == nil { + return fmt.Errorf("rule bucket %s not found", + rule) + } + + pairs, err := collectRulePairs( + ruleBucket, perm, string(rule), + ) + if err != nil { + return err + } + + entries = append(entries, pairs...) + + return nil + }) + if err != nil { + return nil, err + } + } + + return entries, nil +} + +// collectRulePairs processes a single rule bucket, which should contain the +// global and session-kv-store key buckets. +func collectRulePairs(bkt *bbolt.Bucket, perm bool, rule string) ([]*kvEntry, + error) { + + var params []*kvEntry + + err := verifyBktKeys( + bkt, true, globalKVStoreBucketKey, sessKVStoreBucketKey, + ) + if err != nil { + return params, fmt.Errorf("verifying rule bucket %s keys "+ + "failed: %w", rule, err) + } + + if globalBkt := bkt.Bucket(globalKVStoreBucketKey); globalBkt != nil { + p, err := collectKVPairs( + globalBkt, true, perm, rule, + fn.None[[]byte](), fn.None[string](), + ) + if err != nil { + return nil, fmt.Errorf("collecting global kv pairs "+ + "failed: %w", err) + } + + params = append(params, p...) + } + + if sessBkt := bkt.Bucket(sessKVStoreBucketKey); sessBkt != nil { + err := sessBkt.ForEach(func(groupAlias, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets "+ + "under %s bucket", sessKVStoreBucketKey) + } + + groupBucket := sessBkt.Bucket(groupAlias) + if groupBucket == nil { + return fmt.Errorf("group bucket for group "+ + "alias %s not found", groupAlias) + } + + kvPairs, err := collectKVPairs( + groupBucket, false, perm, rule, + fn.Some(groupAlias), fn.None[string](), + ) + if err != nil { + return fmt.Errorf("collecting group kv "+ + "pairs failed: %w", err) + } + + params = append(params, kvPairs...) + + err = verifyBktKeys( + groupBucket, false, featureKVStoreBucketKey, + ) + if err != nil { + return fmt.Errorf("verification of group "+ + "bucket %s keys failed: %w", groupAlias, + err) + } + + ftBkt := groupBucket.Bucket(featureKVStoreBucketKey) + if ftBkt == nil { + return nil + } + + return ftBkt.ForEach(func(ftName, v []byte) error { + if v != nil { + return fmt.Errorf("expected only "+ + "buckets under %s bucket", + featureKVStoreBucketKey) + } + + // The feature name should exist, as per the + // verification above. + featureBucket := ftBkt.Bucket(ftName) + if featureBucket == nil { + return fmt.Errorf("feature bucket "+ + "%s not found", ftName) + } + + featurePairs, err := collectKVPairs( + featureBucket, true, perm, rule, + fn.Some(groupAlias), + fn.Some(string(ftName)), + ) + if err != nil { + return fmt.Errorf("collecting "+ + "feature kv pairs failed: %w", + err) + } + + params = append(params, featurePairs...) + + return nil + }) + }) + if err != nil { + return nil, fmt.Errorf("collecting session kv pairs "+ + "failed: %w", err) + } + } + + return params, nil +} + +// collectKVPairs collects all key-value pairs from the given bucket, and +// returns them as a slice of kvEntry structs. If the errorOnBuckets parameter +// is set to true, then the function will return an error if the bucket +// contains any sub-buckets. Note that when the errorOnBuckets parameter is +// set to false, the function will not collect any key-value pairs from the +// sub-buckets, and will just ignore them. +func collectKVPairs(bkt *bbolt.Bucket, errorOnBuckets, perm bool, + ruleName string, groupAlias fn.Option[[]byte], + featureName fn.Option[string]) ([]*kvEntry, error) { + + var params []*kvEntry + + return params, bkt.ForEach(func(key, value []byte) error { + // If the value is nil, then this is a bucket, which we + // don't want to process here, as we only want to collect + // the key-value pairs, not the buckets. If we should + // error on buckets, then we return an error here. + if value == nil { + if errorOnBuckets { + return fmt.Errorf("unexpected bucket %s found "+ + "in when collecting kv pairs", key) + } + + return nil + } + + params = append(params, &kvEntry{ + perm: perm, + ruleName: ruleName, + key: string(key), + featureName: featureName, + groupAlias: groupAlias, + value: value, + }) + + return nil + }) +} + +// insertPair inserts a single key-value pair into the SQL database. +func insertPair(ctx context.Context, tx SQLMig6Queries, + entry *kvEntry) (*sqlKvEntry, error) { + + ruleID, err := tx.GetOrInsertRuleID(ctx, entry.ruleName) + if err != nil { + return nil, err + } + + p := sqlcmig6.InsertKVStoreRecordParams{ + Perm: entry.perm, + RuleID: ruleID, + EntryKey: entry.key, + Value: entry.value, + } + + entry.groupAlias.WhenSome(func(alias []byte) { + var groupID int64 + groupID, err = tx.GetSessionIDByAlias(ctx, alias) + if err != nil { + err = fmt.Errorf("getting group id by alias %x "+ + "failed: %w", alias, err) + return + } + + p.GroupID = sqldb.SQLInt64(groupID) + }) + if err != nil { + return nil, err + } + + entry.featureName.WhenSome(func(feature string) { + var featureID int64 + featureID, err = tx.GetOrInsertFeatureID(ctx, feature) + if err != nil { + err = fmt.Errorf("getting/inserting feature id for %s "+ + "failed: %w", feature, err) + return + } + + p.FeatureID = sqldb.SQLInt64(featureID) + }) + if err != nil { + return nil, err + } + + err = tx.InsertKVStoreRecord(ctx, p) + if err != nil { + return nil, err + } + + return &sqlKvEntry{ + kvEntry: entry, + ruleID: p.RuleID, + groupID: p.GroupID, + featureID: p.FeatureID, + }, nil +} + +// getSQLValue retrieves the key value for the given kvEntry from the SQL +// database. +func getSQLValue(ctx context.Context, tx SQLMig6Queries, + entry *sqlKvEntry) ([]byte, error) { + + switch { + case entry.featureID.Valid && entry.groupID.Valid: + return tx.GetFeatureKVStoreRecord( + ctx, sqlcmig6.GetFeatureKVStoreRecordParams{ + Perm: entry.perm, + RuleID: entry.ruleID, + GroupID: entry.groupID, + FeatureID: entry.featureID, + Key: entry.key, + }, + ) + case entry.groupID.Valid: + return tx.GetGroupKVStoreRecord( + ctx, sqlcmig6.GetGroupKVStoreRecordParams{ + Perm: entry.perm, + RuleID: entry.ruleID, + GroupID: entry.groupID, + Key: entry.key, + }, + ) + case !entry.featureID.Valid && !entry.groupID.Valid: + return tx.GetGlobalKVStoreRecord( + ctx, sqlcmig6.GetGlobalKVStoreRecordParams{ + Perm: entry.perm, + RuleID: entry.ruleID, + Key: entry.key, + }, + ) + default: + return nil, fmt.Errorf("invalid combination of feature and "+ + "session ID: featureID valid: %v, groupID valid: %v", + entry.featureID.Valid, entry.groupID.Valid) + } +} + +// verifyBktKeys checks that the given bucket only contains buckets with the +// passed keys, and optionally also key-value pairs. If the errorOnKeyValues +// parameter is set to true, the function will error if it finds key-value pairs +// in the bucket. +func verifyBktKeys(bkt *bbolt.Bucket, errorOnKeyValues bool, + keys ...[]byte) error { + + return bkt.ForEach(func(key, v []byte) error { + if v != nil { + // If we allow key-values, then we can just continue + // to the next key. Else we need to error out, as we + // only expect buckets under the passed bucket. + if errorOnKeyValues { + return fmt.Errorf("unexpected key-value pair "+ + "found: key=%s, value=%x", key, v) + } + + return nil + } + + for _, expectedKey := range keys { + if bytes.Equal(key, expectedKey) { + // If this is an expected key, we can continue + // to the next key. + return nil + } + } + + return fmt.Errorf("unexpected key found: %s", key) + }) +} diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go new file mode 100644 index 000000000..ab8ea01da --- /dev/null +++ b/firewalldb/sql_migration_test.go @@ -0,0 +1,664 @@ +package firewalldb + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" +) + +const ( + testRuleName = "test-rule" + testRuleName2 = "test-rule-2" + testFeatureName = "test-feature" + testFeatureName2 = "test-feature-2" + testEntryKey = "test-entry-key" + testEntryKey2 = "test-entry-key-2" + testEntryKey3 = "test-entry-key-3" + testEntryKey4 = "test-entry-key-4" +) + +var ( + testEntryValue = []byte{1, 2, 3} +) + +// TestFirewallDBMigration tests the migration of firewalldb from a bolt +// backend to a SQL database. Note that this test does not attempt to be a +// complete migration test. +// This test only tests the migration of the KV stores currently, but will +// be extended in the future to also test the migration of the privacy mapper +// and the actions store in the future. +func TestFirewallDBMigration(t *testing.T) { + t.Parallel() + + ctx := context.Background() + clock := clock.NewTestClock(time.Now()) + + // When using build tags that creates a kvdb store for NewTestDB, we + // skip this test as it is only applicable for postgres and sqlite tags. + store := NewTestDB(t, clock) + if _, ok := store.(*BoltDB); ok { + t.Skipf("Skipping Firewall DB migration test for kvdb build") + } + + makeSQLDB := func(t *testing.T, + sStore session.Store) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + testDBStore := NewTestDBWithSessions(t, sStore, clock) + + store, ok := testDBStore.(*SQLDB) + require.True(t, ok) + + baseDB := store.BaseDB + + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) + + return NewSQLMig6QueriesExecutor(baseDB, queries) + } + + // The assertMigrationResults function will currently assert that + // the migrated kv stores entries in the SQLDB match the original kv + // stores entries in the BoltDB. + assertMigrationResults := func(t *testing.T, store SQLMig6Queries, + kvEntries []*kvEntry) { + + var ( + ruleIDs = make(map[string]int64) + groupIDs = make(map[string]int64) + featureIDs = make(map[string]int64) + err error + ) + + getRuleID := func(ruleName string) int64 { + ruleID, ok := ruleIDs[ruleName] + if !ok { + ruleID, err = store.GetRuleID(ctx, ruleName) + require.NoError(t, err) + + ruleIDs[ruleName] = ruleID + } + + return ruleID + } + + getGroupID := func(groupAlias []byte) int64 { + groupID, ok := groupIDs[string(groupAlias)] + if !ok { + groupID, err = store.GetSessionIDByAlias( + ctx, groupAlias, + ) + require.NoError(t, err) + + groupIDs[string(groupAlias)] = groupID + } + + return groupID + } + + getFeatureID := func(featureName string) int64 { + featureID, ok := featureIDs[featureName] + if !ok { + featureID, err = store.GetFeatureID( + ctx, featureName, + ) + require.NoError(t, err) + + featureIDs[featureName] = featureID + } + + return featureID + } + + // First we extract all migrated kv entries from the SQLDB, + // in order to be able to compare them to the original kv + // entries, to ensure that the migration was successful. + sqlKvEntries, err := store.ListAllKVStoresRecords(ctx) + require.NoError(t, err) + require.Equal(t, len(kvEntries), len(sqlKvEntries)) + + // We then iterate over the original kv entries that were + // migrated from the BoltDB to the SQLDB, and assert that they + // match the migrated SQL kv entries. + // NOTE: when fetching kv entries that were inserted into the + // sql store with the entry value []byte{}, a nil value is + // returned. Therefore, require.Equal would error on such cases, + // while bytes.Equal would not. Therefore, the comparison below + // uses bytes.Equal to compare the values. + for _, entry := range kvEntries { + ruleID := getRuleID(entry.ruleName) + + if entry.groupAlias.IsNone() { + sqlVal, err := store.GetGlobalKVStoreRecord( + ctx, + sqlcmig6.GetGlobalKVStoreRecordParams{ + Key: entry.key, + Perm: entry.perm, + RuleID: ruleID, + }, + ) + require.NoError(t, err) + // See docs for the loop above on why + // bytes.Equal is used here. + require.True( + t, bytes.Equal(entry.value, sqlVal), + ) + } else if entry.featureName.IsNone() { + groupAlias := entry.groupAlias.UnwrapOrFail(t) + groupID := getGroupID(groupAlias[:]) + + v, err := store.GetGroupKVStoreRecord( + ctx, + sqlcmig6.GetGroupKVStoreRecordParams{ + Key: entry.key, + Perm: entry.perm, + RuleID: ruleID, + GroupID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + // See docs for the loop above on why + // bytes.Equal is used here. + require.True( + t, bytes.Equal(entry.value, v), + ) + } else { + groupAlias := entry.groupAlias.UnwrapOrFail(t) + groupID := getGroupID(groupAlias[:]) + featureID := getFeatureID( + entry.featureName.UnwrapOrFail(t), + ) + + sqlVal, err := store.GetFeatureKVStoreRecord( + ctx, + sqlcmig6.GetFeatureKVStoreRecordParams{ + Key: entry.key, + Perm: entry.perm, + RuleID: ruleID, + GroupID: sql.NullInt64{ + Int64: groupID, + Valid: true, + }, + FeatureID: sql.NullInt64{ + Int64: featureID, + Valid: true, + }, + }, + ) + require.NoError(t, err) + // See docs for the loop above on why + // bytes.Equal is used here. + require.True( + t, bytes.Equal(entry.value, sqlVal), + ) + } + } + } + + // The tests slice contains all the tests that we will run for the + // migration of the firewalldb from a BoltDB to a SQLDB. + // Note that the tests currently only test the migration of the KV + // stores, but will be extended in the future to also test the migration + // of the privacy mapper and the actions store. + tests := []struct { + name string + populateDB func(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []*kvEntry + }{ + { + name: "empty", + populateDB: func(t *testing.T, ctx context.Context, + boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + // Don't populate the DB. + return make([]*kvEntry, 0) + }, + }, + { + name: "global entries", + populateDB: globalEntries, + }, + { + name: "session specific entries", + populateDB: sessionSpecificEntries, + }, + { + name: "feature specific entries", + populateDB: featureSpecificEntries, + }, + { + name: "all entry combinations", + populateDB: allEntryCombinations, + }, + { + name: "random entries", + populateDB: randomKVEntries, + }, + } + + for _, test := range tests { + tc := test + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // First let's create a sessions store to link to in + // the kvstores DB. In order to create the sessions + // store though, we also need to create an accounts + // store, that we link to the sessions store. + // Note that both of these stores will be sql stores due + // to the build tags enabled when running this test, + // which means we can also pass the sessions store to + // the sql version of the kv stores that we'll create + // in test, without also needing to migrate it. + accountStore := accounts.NewTestDB(t, clock) + sessionsStore := session.NewTestDBWithAccounts( + t, clock, accountStore, + ) + + // Create a new firewall store to populate with test + // data. + firewallStore, err := NewBoltDB( + t.TempDir(), DBFilename, sessionsStore, + accountStore, clock, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, firewallStore.Close()) + }) + + // Populate the kv store. + entries := test.populateDB( + t, ctx, firewallStore, sessionsStore, + ) + + // Create the SQL store that we will migrate the data + // to. + txEx := makeSQLDB(t, sessionsStore) + + // Perform the migration. + var opts sqldb.MigrationTxOptions + err = txEx.ExecTx(ctx, &opts, + func(tx SQLMig6Queries) error { + err = MigrateFirewallDBToSQL( + ctx, firewallStore.DB, tx, + ) + if err != nil { + return err + } + + // Assert migration results. + assertMigrationResults(t, tx, entries) + + return nil + }, sqldb.NoOpReset, + ) + require.NoError(t, err) + }) + } +} + +// globalEntries populates the kv store with one global entry for the temp +// store, and one for the perm store. +func globalEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + _ session.Store) []*kvEntry { + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, fn.None[[]byte](), + fn.None[string](), testEntryKey, testEntryValue, + ) +} + +// sessionSpecificEntries populates the kv store with one session specific +// entry for the local temp store, and one session specific entry for the perm +// local store. +func sessionSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + groupAlias := getNewSessionAlias(t, ctx, sessionStore) + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, groupAlias, fn.None[string](), + testEntryKey, testEntryValue, + ) +} + +// featureSpecificEntries populates the kv store with one feature specific +// entry for the local temp store, and one feature specific entry for the perm +// local store. +func featureSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + groupAlias := getNewSessionAlias(t, ctx, sessionStore) + + return insertTempAndPermEntry( + t, ctx, boltDB, testRuleName, groupAlias, + fn.Some(testFeatureName), testEntryKey, testEntryValue, + ) +} + +// allEntryCombinations adds all types of different entries at all possible +// levels of the kvstores, including multple entries with the same +// ruleName, groupAlias and featureName. The test aims to cover all possible +// combinations of entries in the kvstores, including nil and empty entry +// values. That therefore ensures that the migrations don't overwrite or miss +// any entries when the entry set is more complex than just a single entry at +// each level. +func allEntryCombinations(t *testing.T, ctx context.Context, boltDB *BoltDB, + sessionStore session.Store) []*kvEntry { + + var result []*kvEntry + add := func(entry []*kvEntry) { + result = append(result, entry...) + } + + // First lets create standard entries at all levels, which represents + // the entries added by other tests. + add(globalEntries(t, ctx, boltDB, sessionStore)) + add(sessionSpecificEntries(t, ctx, boltDB, sessionStore)) + add(featureSpecificEntries(t, ctx, boltDB, sessionStore)) + + groupAlias := getNewSessionAlias(t, ctx, sessionStore) + + // Now lets add a few more entries at with different rule names and + // features, just to ensure that we cover entries in different rule and + // feature tables. + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey, testEntryValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey, testEntryValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey, testEntryValue, + )) + // Let's also create an entry with a different feature name that's still + // referencing the same group ID as the previous entry. + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName2), testEntryKey, testEntryValue, + )) + + // Finally, lets add a few entries with nil and empty values set for the + // actual key value, at all different levels, to ensure that tests don't + // break if the value is nil or empty. + var ( + nilValue []byte = nil + nilSliceValue = []byte(nil) + emptyValue = []byte{} + ) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, fn.None[[]byte](), + fn.None[string](), testEntryKey4, emptyValue, + )) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.None[string](), testEntryKey4, emptyValue, + )) + + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey2, nilValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey3, nilSliceValue, + )) + add(insertTempAndPermEntry( + t, ctx, boltDB, testRuleName2, groupAlias, + fn.Some(testFeatureName), testEntryKey4, emptyValue, + )) + + return result +} + +func getNewSessionAlias(t *testing.T, ctx context.Context, + sessionStore session.Store) fn.Option[[]byte] { + + sess, err := sessionStore.NewSession( + ctx, "test", session.TypeAutopilot, + time.Unix(1000, 0), "something", + ) + require.NoError(t, err) + + return fn.Some(sess.GroupID[:]) +} + +// insertTempAndPermEntry populates the kv store with one entry for the temp +// store, and one entry for the perm store. Both of the entries will be inserted +// with the same groupAlias, ruleName, entryKey and entryValue. +func insertTempAndPermEntry(t *testing.T, ctx context.Context, + boltDB *BoltDB, ruleName string, groupAlias fn.Option[[]byte], + featureNameOpt fn.Option[string], entryKey string, + entryValue []byte) []*kvEntry { + + tempKvEntry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAlias, + featureName: featureNameOpt, + key: entryKey, + value: entryValue, + perm: false, + } + + insertKvEntry(t, ctx, boltDB, tempKvEntry) + + permKvEntry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAlias, + featureName: featureNameOpt, + key: entryKey, + value: entryValue, + perm: true, + } + + insertKvEntry(t, ctx, boltDB, permKvEntry) + + return []*kvEntry{tempKvEntry, permKvEntry} +} + +// insertKvEntry populates the kv store with passed entry, and asserts that the +// entry is inserted correctly. +func insertKvEntry(t *testing.T, ctx context.Context, + boltDB *BoltDB, entry *kvEntry) { + + if entry.groupAlias.IsNone() && entry.featureName.IsSome() { + t.Fatalf("cannot set both global and feature specific at the " + + "same time") + } + + // We get the kv stores that the entry will be inserted into. Note that + // we set an empty group ID if the entry is global, as the group ID + // will not be used when fetching the actual kv store that's used for + // global entries. + groupID := [4]byte{} + if entry.groupAlias.IsSome() { + copy(groupID[:], entry.groupAlias.UnwrapOrFail(t)) + } + + kvStores := boltDB.GetKVStores( + entry.ruleName, groupID, entry.featureName.UnwrapOr(""), + ) + + err := kvStores.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + + store := tx.Global() + + switch { + case entry.groupAlias.IsNone() && !entry.perm: + store = tx.GlobalTemp() + case entry.groupAlias.IsSome() && !entry.perm: + store = tx.LocalTemp() + case entry.groupAlias.IsSome() && entry.perm: + store = tx.Local() + } + + return store.Set(ctx, entry.key, entry.value) + }) + require.NoError(t, err) +} + +// randomKVEntries populates the kv store with random kv entries that span +// across all possible combinations of different levels of entries in the kv +// store. All values and different bucket names are randomly generated. +func randomKVEntries(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) []*kvEntry { + + var ( + // We set the number of entries to insert to 1000, as that + // should be enough to cover as many different + // combinations of entries as possible, while still being + // fast enough to run in a reasonable time. + numberOfEntries = 1000 + insertedEntries = make([]*kvEntry, 0) + ruleName = "initial-rule" + groupAlias []byte + featureName = "initial-feature" + ) + + // Create a random session that we can reference for the initial group + // ID. + sess, err := sessionStore.NewSession( + ctx, "initial-session", session.Type(1), time.Unix(1000, 0), + "serverAddr.test", + ) + require.NoError(t, err) + + groupAlias = sess.GroupID[:] + + // Generate random entries. Note that many entries will use the same + // rule name, group ID and feature name, to simulate the real world + // usage of the kv stores as much as possible. + for i := 0; i < numberOfEntries; i++ { + // On average, we will generate a new rule which will be used + // for the kv store entry 10% of the time. + if rand.Intn(10) == 0 { + ruleName = fmt.Sprintf( + "rule-%s-%d", randomString(rand.Intn(30)+1), i, + ) + } + + // On average, we use the global store 25% of the time. + global := rand.Intn(4) == 0 + + // We'll use the perm store 50% of the time. + perm := rand.Intn(2) == 0 + + // For the non-global entries, we will generate a new group + // alias 25% of the time. + if !global && rand.Intn(4) == 0 { + newSess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), + randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupAlias = newSess.GroupID[:] + } + + featureNameOpt := fn.None[string]() + + // For 50% of the non-global entries, we insert a feature + // specific entry. The other 50% will be session specific + // entries. + if !global && rand.Intn(2) == 0 { + // 25% of the time, we will generate a new feature name. + if rand.Intn(4) == 0 { + featureName = fmt.Sprintf( + "feature-%s-%d", + randomString(rand.Intn(30)+1), i, + ) + } + + featureNameOpt = fn.Some(featureName) + } + + groupAliasOpt := fn.None[[]byte]() + if !global { + // If the entry is not global, we set the group ID + // to the latest session's group ID. + groupAliasOpt = fn.Some(groupAlias[:]) + } + + entry := &kvEntry{ + ruleName: ruleName, + groupAlias: groupAliasOpt, + featureName: featureNameOpt, + key: fmt.Sprintf("key-%d", i), + perm: perm, + } + + // When setting a value for the entry, 25% of the time, we will + // set a nil or empty value. + if rand.Intn(4) == 0 { + // in 50% of these cases, we will set the value to nil, + // and in the other 50% we will set it to an empty + // value + if rand.Intn(2) == 0 { + entry.value = nil + } else { + entry.value = []byte{} + } + } else { + // Else generate a random value for all entries, + entry.value = []byte(randomString(rand.Intn(100) + 1)) + } + + // Insert the entry into the kv store. + insertKvEntry(t, ctx, boltDB, entry) + + // Add the entry to the list of inserted entries. + insertedEntries = append(insertedEntries, entry) + } + + return insertedEntries +} + +// randomString generates a random string of the passed length n. +func randomString(n int) string { + letterBytes := "abcdefghijklmnopqrstuvwxyz" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index f17010f2c..dcf201f1c 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -5,7 +5,10 @@ import ( "database/sql" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // SQLSessionQueries is a subset of the sqlc.Queries interface that can be used @@ -18,17 +21,30 @@ type SQLSessionQueries interface { // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { + sqldb.BaseQuerier + SQLKVStoreQueries SQLPrivacyPairQueries SQLActionQueries } -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used +// to interact with various firewalldb tables. +type SQLMig6Queries interface { + sqldb.BaseQuerier + + SQLMig6KVStoreQueries + SQLMig6PrivacyPairQueries + SQLMig6ActionQueries +} + +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLDB represents a storage backend. @@ -38,11 +54,51 @@ type SQLDB struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLMig6Queries +} + +func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLMig6Queries { + return queries.WithTx(tx) + }, + ) + return &SQLMig6QueriesExecutor[SQLMig6Queries]{ + TransactionExecutor: executor, + SQLMig6Queries: queries, + } +} + // A compile-time assertion to ensure that SQLDB implements the RulesDB // interface. var _ RulesDB = (*SQLDB)(nil) @@ -53,12 +109,10 @@ var _ ActionDB = (*SQLDB)(nil) // NewSQLDB creates a new SQLStore instance given an open SQLQueries // storage backend. -func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) - }, - ) +func NewSQLDB(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLDB { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLDB{ db: executor, @@ -88,7 +142,7 @@ func (e *sqlExecutor[T]) Update(ctx context.Context, var txOpts db.QueriesTxOptions return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } // View opens a database read transaction and executes the function f with the @@ -104,5 +158,5 @@ func (e *sqlExecutor[T]) View(ctx context.Context, return e.db.ExecTx(ctx, &txOpts, func(queries SQLQueries) error { return fn(ctx, e.wrapTx(queries)) - }) + }, sqldb.NoOpReset) } diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 6f7a49aa3..c3cd4533a 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -6,34 +6,37 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/require" ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *BoltDB { +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { return NewTestDBFromPath(t, t.TempDir(), clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *BoltDB { +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + return newDBFromPathWithSessions(t, dbPath, nil, nil, clock) } // NewTestDBWithSessions creates a new test BoltDB Store with access to an // existing sessions DB. -func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, - clock clock.Clock) *BoltDB { +func NewTestDBWithSessions(t *testing.T, sessStore session.Store, + clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } // NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access // to an existing sessions DB and accounts DB. -func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *BoltDB { +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore session.Store, + acctStore AccountsDB, clock clock.Clock) FirewallDBs { return newDBFromPathWithSessions( t, t.TempDir(), sessStore, acctStore, clock, @@ -41,7 +44,8 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, } func newDBFromPathWithSessions(t *testing.T, dbPath string, - sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { + sessStore session.Store, acctStore AccountsDB, + clock clock.Clock) FirewallDBs { store, err := NewBoltDB(dbPath, DBFilename, sessStore, acctStore, clock) require.NoError(t, err) diff --git a/firewalldb/test_postgres.go b/firewalldb/test_postgres.go index f5777e4cb..732b19b4a 100644 --- a/firewalldb/test_postgres.go +++ b/firewalldb/test_postgres.go @@ -10,12 +10,12 @@ import ( ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestPostgresDB(t).BaseDB, clock) +func NewTestDBFromPath(t *testing.T, _ string, clock clock.Clock) FirewallDBs { + return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock) } diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 03dcfbebf..b7e3d9052 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -6,27 +6,29 @@ import ( "testing" "time" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) // NewTestDBWithSessions creates a new test SQLDB Store with access to an // existing sessions DB. func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, - clock clock.Clock) *SQLDB { - + clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } // NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access // to an existing sessions DB and accounts DB. func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, - acctStore AccountsDB, clock clock.Clock) *SQLDB { + acctStore AccountsDB, clock clock.Clock) FirewallDBs { sessions, ok := sessionStore.(*session.SQLStore) require.True(t, ok) @@ -36,7 +38,7 @@ func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, require.Equal(t, accounts.BaseDB, sessions.BaseDB) - return NewSQLDB(sessions.BaseDB, clock) + return createStore(t, sessions.BaseDB, clock) } func assertEqualActions(t *testing.T, expected, got *Action) { @@ -52,3 +54,16 @@ func assertEqualActions(t *testing.T, expected, got *Action) { expected.AttemptedAt = expectedAttemptedAt got.AttemptedAt = actualAttemptedAt } + +// createStore is a helper function that creates a new SQLDB and ensure that +// it is closed when during the test cleanup. +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, clock clock.Clock) *SQLDB { + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLDB(sqlDB, queries, clock) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/firewalldb/test_sqlite.go b/firewalldb/test_sqlite.go index 5496cb205..3f91546af 100644 --- a/firewalldb/test_sqlite.go +++ b/firewalldb/test_sqlite.go @@ -7,17 +7,26 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // NewTestDB is a helper function that creates an BBolt database for testing. -func NewTestDB(t *testing.T, clock clock.Clock) *SQLDB { - return NewSQLDB(db.NewTestSqliteDB(t).BaseDB, clock) +func NewTestDB(t *testing.T, clock clock.Clock) FirewallDBs { + return createStore( + t, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new BoltStore with a // connection to an existing BBolt database for testing. -func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) *SQLDB { - return NewSQLDB( - db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) FirewallDBs { + + tDb := sqldb.NewTestSqliteDBFromPath( + t, dbPath, db.MakeTestMigrationStreams(), ) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/go.mod b/go.mod index c0a0c05c5..cdcb07912 100644 --- a/go.mod +++ b/go.mod @@ -34,14 +34,15 @@ require ( github.com/lightninglabs/pool/poolrpc v1.0.1 github.com/lightninglabs/taproot-assets v0.6.0-rc3 github.com/lightninglabs/taproot-assets/taprpc v1.0.6 - github.com/lightningnetwork/lnd v0.19.1-beta.rc1 + github.com/lightningnetwork/lnd v0.19.1-beta github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 github.com/lightningnetwork/lnd/fn v1.2.3 github.com/lightningnetwork/lnd/fn/v2 v2.0.8 github.com/lightningnetwork/lnd/kvdb v1.4.16 - github.com/lightningnetwork/lnd/sqldb v1.0.9 - github.com/lightningnetwork/lnd/tlv v1.3.1 + github.com/lightningnetwork/lnd/sqldb v1.0.10 + github.com/lightningnetwork/lnd/sqldb/v2 v2.0.0-00010101000000-000000000000 + github.com/lightningnetwork/lnd/tlv v1.3.2 github.com/lightningnetwork/lnd/tor v1.1.6 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f github.com/mwitkow/grpc-proxy v0.0.0-20230212185441-f345521cb9c9 @@ -247,3 +248,8 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d // it is a replace in the tapd repository, it doesn't get propagated here // automatically, so we need to add it manually. replace github.com/golang-migrate/migrate/v4 => github.com/lightninglabs/migrate/v4 v4.18.2-9023d66a-fork-pr-2 + +replace github.com/lightningnetwork/lnd => github.com/ViktorTigerstrom/lnd v0.0.0-20250710121612-a88fb038013b + +// TODO: replace this with your own local fork +replace github.com/lightningnetwork/lnd/sqldb/v2 => ../../lnd_forked/lnd/sqldb diff --git a/go.sum b/go.sum index 36adbb9e0..df0a58695 100644 --- a/go.sum +++ b/go.sum @@ -616,6 +616,8 @@ github.com/NebulousLabs/go-upnp v0.0.0-20180202185039-29b680b06c82/go.mod h1:Gbu github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/ViktorTigerstrom/lnd v0.0.0-20250710121612-a88fb038013b h1:2UbeVKTtFvxxtHTzHOmQJvBeQ87eTdYMrAs4O8ZuIi8= +github.com/ViktorTigerstrom/lnd v0.0.0-20250710121612-a88fb038013b/go.mod h1:FLpPqYTU7KKJ87mNsqp/DsJwmpenK0zhuLBO+beWCl4= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344 h1:cDVUiFo+npB0ZASqnw4q90ylaVAbnYyx0JYqK4YcGok= github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344/go.mod h1:9pIqrY6SXNL8vjRQE5Hd/OL5GyK/9MrGUWs87z/eFfk= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= @@ -1178,8 +1180,6 @@ github.com/lightninglabs/taproot-assets/taprpc v1.0.6 h1:h8tf4y7U5/3A9WNAs7HBTL8 github.com/lightninglabs/taproot-assets/taprpc v1.0.6/go.mod h1:vOM2Ap2wYhEZjiJU7bNNg+e5tDxkvRAuyXwf/KQ4tgo= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb h1:yfM05S8DXKhuCBp5qSMZdtSwvJ+GFzl94KbXMNB1JDY= github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb/go.mod h1:c0kvRShutpj3l6B9WtTsNTBUtjSmjZXbJd9ZBRQOSKI= -github.com/lightningnetwork/lnd v0.19.1-beta.rc1 h1:VV7xpS1g7OpDhlGYIa50Ac4I7TDZhvHjZ2Mmf+L4a7Y= -github.com/lightningnetwork/lnd v0.19.1-beta.rc1/go.mod h1:iHZ/FHFK00BqV6qgDkZZfqWE3LGtgE0U5KdO5WrM+eQ= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= @@ -1194,12 +1194,12 @@ github.com/lightningnetwork/lnd/kvdb v1.4.16 h1:9BZgWdDfjmHRHLS97cz39bVuBAqMc4/p github.com/lightningnetwork/lnd/kvdb v1.4.16/go.mod h1:HW+bvwkxNaopkz3oIgBV6NEnV4jCEZCACFUcNg4xSjM= github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI= github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4= -github.com/lightningnetwork/lnd/sqldb v1.0.9 h1:7OHi+Hui823mB/U9NzCdlZTAGSVdDCbjp33+6d/Q+G0= -github.com/lightningnetwork/lnd/sqldb v1.0.9/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4= +github.com/lightningnetwork/lnd/sqldb v1.0.10 h1:ZLV7TGwjnKupVfCd+DJ43MAc9BKVSFCnvhpSPGKdN3M= +github.com/lightningnetwork/lnd/sqldb v1.0.10/go.mod h1:c/vWoQfcxu6FAfHzGajkIQi7CEIeIZFhhH4DYh1BJpc= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= -github.com/lightningnetwork/lnd/tlv v1.3.1 h1:o7CZg06y+rJZfUMAo0WzBLr0pgBWCzrt0f9gpujYUzk= -github.com/lightningnetwork/lnd/tlv v1.3.1/go.mod h1:pJuiBj1ecr1WWLOtcZ+2+hu9Ey25aJWFIsjmAoPPnmc= +github.com/lightningnetwork/lnd/tlv v1.3.2 h1:MO4FCk7F4k5xPMqVZF6Nb/kOpxlwPrUQpYjmyKny5s0= +github.com/lightningnetwork/lnd/tlv v1.3.2/go.mod h1:pJuiBj1ecr1WWLOtcZ+2+hu9Ey25aJWFIsjmAoPPnmc= github.com/lightningnetwork/lnd/tor v1.1.6 h1:WHUumk7WgU6BUFsqHuqszI9P6nfhMeIG+rjJBlVE6OE= github.com/lightningnetwork/lnd/tor v1.1.6/go.mod h1:qSRB8llhAK+a6kaTPWOLLXSZc6Hg8ZC0mq1sUQ/8JfI= github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw= diff --git a/log.go b/log.go index b803f72d2..0535a819a 100644 --- a/log.go +++ b/log.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/autopilotserver" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/migrationstreams" "github.com/lightninglabs/lightning-terminal/firewall" "github.com/lightninglabs/lightning-terminal/firewalldb" mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware" @@ -91,6 +92,10 @@ func SetupLoggers(root *build.SubLoggerManager, intercept signal.Interceptor) { root, subservers.Subsystem, intercept, subservers.UseLogger, ) lnd.AddSubLogger(root, db.Subsystem, intercept, db.UseLogger) + lnd.AddSubLogger( + root, migrationstreams.Subsystem, intercept, + migrationstreams.UseLogger, + ) // Add daemon loggers to lnd's root logger. faraday.SetupLoggers(root, intercept) diff --git a/scripts/gen_sqlc_docker.sh b/scripts/gen_sqlc_docker.sh index 16db97f2c..3d93f37ff 100755 --- a/scripts/gen_sqlc_docker.sh +++ b/scripts/gen_sqlc_docker.sh @@ -5,7 +5,7 @@ set -e # restore_files is a function to restore original schema files. restore_files() { echo "Restoring SQLite bigint patch..." - for file in db/sqlc/migrations/*.up.sql.bak; do + for file in db/sqlc/{migrations,migrations_dev}/*.up.sql.bak; do mv "$file" "${file%.bak}" done } @@ -30,7 +30,7 @@ GOMODCACHE=$(go env GOMODCACHE) # source schema SQL files to use "BIGINT PRIMARY KEY" instead of "INTEGER # PRIMARY KEY". echo "Applying SQLite bigint patch..." -for file in db/sqlc/migrations/*.up.sql; do +for file in db/sqlc/{migrations,migrations_dev}/*.up.sql; do echo "Patching $file" sed -i.bak -E 's/INTEGER PRIMARY KEY/BIGINT PRIMARY KEY/g' "$file" done diff --git a/session/sql_migration.go b/session/sql_migration.go index 428cc0fce..b1caebeb4 100644 --- a/session/sql_migration.go +++ b/session/sql_migration.go @@ -9,12 +9,17 @@ import ( "reflect" "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" + "github.com/lightninglabs/lightning-node-connect/mailbox" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/sqldb" "github.com/pmezard/go-difflib/difflib" "go.etcd.io/bbolt" + "gopkg.in/macaroon-bakery.v2/bakery" + "gopkg.in/macaroon.v2" ) var ( @@ -31,7 +36,7 @@ var ( // NOTE: As sessions may contain linked accounts, the accounts sql migration // MUST be run prior to this migration. func MigrateSessionStoreToSQL(ctx context.Context, kvStore *bbolt.DB, - tx SQLQueries) error { + tx SQLMig6Queries) error { log.Infof("Starting migration of the KV sessions store to SQL") @@ -118,7 +123,7 @@ func getBBoltSessions(db *bbolt.DB) ([]*Session, error) { // from the KV database to the SQL database, and validates that the migrated // sessions match the original sessions. func migrateSessionsToSQLAndValidate(ctx context.Context, - tx SQLQueries, kvSessions []*Session) error { + tx SQLMig6Queries, kvSessions []*Session) error { for _, kvSession := range kvSessions { err := migrateSingleSessionToSQL(ctx, tx, kvSession) @@ -127,18 +132,9 @@ func migrateSessionsToSQLAndValidate(ctx context.Context, kvSession.ID, err) } - // Validate that the session was correctly migrated and matches - // the original session in the kv store. - sqlSess, err := tx.GetSessionByAlias(ctx, kvSession.ID[:]) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - err = ErrSessionNotFound - } - return fmt.Errorf("unable to get migrated session "+ - "from sql store: %w", err) - } - - migratedSession, err := unmarshalSession(ctx, tx, sqlSess) + migratedSession, err := getAndUnmarshalSession( + ctx, tx, kvSession.ID[:], + ) if err != nil { return fmt.Errorf("unable to unmarshal migrated "+ "session: %w", err) @@ -172,12 +168,206 @@ func migrateSessionsToSQLAndValidate(ctx context.Context, return nil } +func getAndUnmarshalSession(ctx context.Context, + tx SQLMig6Queries, legacyID []byte) (*Session, error) { + + // Validate that the session was correctly migrated and matches + // the original session in the kv store. + sqlSess, err := tx.GetSessionByAlias(ctx, legacyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + err = ErrSessionNotFound + } + + return nil, fmt.Errorf("unable to get migrated session "+ + "from sql store: %w", err) + } + + migratedSession, err := unmarshalMig6Session(ctx, tx, sqlSess) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal migrated "+ + "session: %w", err) + } + + return migratedSession, nil + +} + +func unmarshalMig6Session(ctx context.Context, db SQLMig6Queries, + dbSess sqlcmig6.Session) (*Session, error) { + + var legacyGroupID ID + if dbSess.GroupID.Valid { + groupID, err := db.GetAliasBySessionID( + ctx, dbSess.GroupID.Int64, + ) + if err != nil { + return nil, fmt.Errorf("unable to get legacy group "+ + "Alias: %v", err) + } + + legacyGroupID, err = IDFromBytes(groupID) + if err != nil { + return nil, fmt.Errorf("unable to get legacy Alias: %v", + err) + } + } + + var acctAlias fn.Option[accounts.AccountID] + if dbSess.AccountID.Valid { + account, err := db.GetAccount(ctx, dbSess.AccountID.Int64) + if err != nil { + return nil, fmt.Errorf("unable to get account: %v", err) + } + + accountAlias, err := accounts.AccountIDFromInt64(account.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get account ID: %v", err) + } + acctAlias = fn.Some(accountAlias) + } + + legacyID, err := IDFromBytes(dbSess.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get legacy Alias: %v", err) + } + + var revokedAt time.Time + if dbSess.RevokedAt.Valid { + revokedAt = dbSess.RevokedAt.Time + } + + localPriv, localPub := btcec.PrivKeyFromBytes(dbSess.LocalPrivateKey) + + var remotePub *btcec.PublicKey + if len(dbSess.RemotePublicKey) != 0 { + remotePub, err = btcec.ParsePubKey(dbSess.RemotePublicKey) + if err != nil { + return nil, fmt.Errorf("unable to parse remote "+ + "public key: %v", err) + } + } + + // Get the macaroon permissions if they exist. + perms, err := db.GetSessionMacaroonPermissions(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get macaroon "+ + "permissions: %v", err) + } + + // Get the macaroon caveats if they exist. + caveats, err := db.GetSessionMacaroonCaveats(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get macaroon "+ + "caveats: %v", err) + } + + var macRecipe *MacaroonRecipe + if perms != nil || caveats != nil { + macRecipe = &MacaroonRecipe{ + Permissions: unmarshalMig6MacPerms(perms), + Caveats: unmarshalMig6MacCaveats(caveats), + } + } + + // Get the feature configs if they exist. + featureConfigs, err := db.GetSessionFeatureConfigs(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get feature configs: %v", err) + } + + var featureCfgs *FeaturesConfig + if featureConfigs != nil { + featureCfgs = unmarshalMig6FeatureConfigs(featureConfigs) + } + + // Get the privacy flags if they exist. + privacyFlags, err := db.GetSessionPrivacyFlags(ctx, dbSess.ID) + if err != nil { + return nil, fmt.Errorf("unable to get privacy flags: %v", err) + } + + var privFlags PrivacyFlags + if privacyFlags != nil { + privFlags = unmarshalMig6PrivacyFlags(privacyFlags) + } + + var pairingSecret [mailbox.NumPassphraseEntropyBytes]byte + copy(pairingSecret[:], dbSess.PairingSecret) + + return &Session{ + ID: legacyID, + Label: dbSess.Label, + State: State(dbSess.State), + Type: Type(dbSess.Type), + Expiry: dbSess.Expiry, + CreatedAt: dbSess.CreatedAt, + RevokedAt: revokedAt, + ServerAddr: dbSess.ServerAddress, + DevServer: dbSess.DevServer, + MacaroonRootKey: uint64(dbSess.MacaroonRootKey), + PairingSecret: pairingSecret, + LocalPrivateKey: localPriv, + LocalPublicKey: localPub, + RemotePublicKey: remotePub, + WithPrivacyMapper: dbSess.Privacy, + GroupID: legacyGroupID, + PrivacyFlags: privFlags, + MacaroonRecipe: macRecipe, + FeatureConfig: featureCfgs, + AccountID: acctAlias, + }, nil +} + +func unmarshalMig6MacPerms(dbPerms []sqlcmig6.SessionMacaroonPermission) []bakery.Op { + ops := make([]bakery.Op, len(dbPerms)) + for i, dbPerm := range dbPerms { + ops[i] = bakery.Op{ + Entity: dbPerm.Entity, + Action: dbPerm.Action, + } + } + + return ops +} + +func unmarshalMig6MacCaveats(dbCaveats []sqlcmig6.SessionMacaroonCaveat) []macaroon.Caveat { + caveats := make([]macaroon.Caveat, len(dbCaveats)) + for i, dbCaveat := range dbCaveats { + caveats[i] = macaroon.Caveat{ + Id: dbCaveat.CaveatID, + VerificationId: dbCaveat.VerificationID, + Location: dbCaveat.Location.String, + } + } + + return caveats +} + +func unmarshalMig6FeatureConfigs(dbConfigs []sqlcmig6.SessionFeatureConfig) *FeaturesConfig { + configs := make(FeaturesConfig, len(dbConfigs)) + for _, dbConfig := range dbConfigs { + configs[dbConfig.FeatureName] = dbConfig.Config + } + + return &configs +} + +func unmarshalMig6PrivacyFlags(dbFlags []sqlcmig6.SessionPrivacyFlag) PrivacyFlags { + flags := make(PrivacyFlags, len(dbFlags)) + for i, dbFlag := range dbFlags { + flags[i] = PrivacyFlag(dbFlag.Flag) + } + + return flags +} + // migrateSingleSessionToSQL runs the migration for a single session from the // KV database to the SQL database. Note that if the session links to an // account, the linked accounts store MUST have been migrated before that // session is migrated. func migrateSingleSessionToSQL(ctx context.Context, - tx SQLQueries, session *Session) error { + tx SQLMig6Queries, session *Session) error { var ( acctID sql.NullInt64 @@ -213,7 +403,7 @@ func migrateSingleSessionToSQL(ctx context.Context, } // Proceed to insert the session into the sql db. - sqlId, err := tx.InsertSession(ctx, sqlc.InsertSessionParams{ + sqlId, err := tx.InsertSession(ctx, sqlcmig6.InsertSessionParams{ Alias: session.ID[:], Label: session.Label, State: int16(session.State), @@ -239,7 +429,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // has been created. if !session.RevokedAt.IsZero() { err = tx.SetSessionRevokedAt( - ctx, sqlc.SetSessionRevokedAtParams{ + ctx, sqlcmig6.SetSessionRevokedAtParams{ ID: sqlId, RevokedAt: sqldb.SQLTime( session.RevokedAt.UTC(), @@ -265,7 +455,7 @@ func migrateSingleSessionToSQL(ctx context.Context, } // Now lets set the group ID for the session. - err = tx.SetSessionGroupID(ctx, sqlc.SetSessionGroupIDParams{ + err = tx.SetSessionGroupID(ctx, sqlcmig6.SetSessionGroupIDParams{ ID: sqlId, GroupID: sqldb.SQLInt64(groupID), }) @@ -279,7 +469,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // We start by inserting the macaroon permissions. for _, sessionPerm := range session.MacaroonRecipe.Permissions { err = tx.InsertSessionMacaroonPermission( - ctx, sqlc.InsertSessionMacaroonPermissionParams{ + ctx, sqlcmig6.InsertSessionMacaroonPermissionParams{ SessionID: sqlId, Entity: sessionPerm.Entity, Action: sessionPerm.Action, @@ -293,7 +483,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // Next we insert the macaroon caveats. for _, caveat := range session.MacaroonRecipe.Caveats { err = tx.InsertSessionMacaroonCaveat( - ctx, sqlc.InsertSessionMacaroonCaveatParams{ + ctx, sqlcmig6.InsertSessionMacaroonCaveatParams{ SessionID: sqlId, CaveatID: caveat.Id, VerificationID: caveat.VerificationId, @@ -312,7 +502,7 @@ func migrateSingleSessionToSQL(ctx context.Context, if session.FeatureConfig != nil { for featureName, config := range *session.FeatureConfig { err = tx.InsertSessionFeatureConfig( - ctx, sqlc.InsertSessionFeatureConfigParams{ + ctx, sqlcmig6.InsertSessionFeatureConfigParams{ SessionID: sqlId, FeatureName: featureName, Config: config, @@ -327,7 +517,7 @@ func migrateSingleSessionToSQL(ctx context.Context, // Finally we insert the privacy flags. for _, privacyFlag := range session.PrivacyFlags { err = tx.InsertSessionPrivacyFlag( - ctx, sqlc.InsertSessionPrivacyFlagParams{ + ctx, sqlcmig6.InsertSessionPrivacyFlagParams{ SessionID: sqlId, Flag: int32(privacyFlag), }, diff --git a/session/sql_migration_test.go b/session/sql_migration_test.go index dfe495628..9ac75f30b 100644 --- a/session/sql_migration_test.go +++ b/session/sql_migration_test.go @@ -2,17 +2,16 @@ package session import ( "context" - "database/sql" "fmt" "testing" "time" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" - "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" "go.etcd.io/bbolt" "golang.org/x/exp/rand" @@ -38,7 +37,7 @@ func TestSessionsStoreMigration(t *testing.T) { } makeSQLDB := func(t *testing.T, acctStore accounts.Store) (*SQLStore, - *db.TransactionExecutor[SQLQueries]) { + *SQLMig6QueriesExecutor[SQLMig6Queries]) { // Create a sql store with a linked account store. testDBStore := NewTestDBWithAccounts(t, clock, acctStore) @@ -48,13 +47,9 @@ func TestSessionsStoreMigration(t *testing.T) { baseDB := store.BaseDB - genericExecutor := db.NewTransactionExecutor( - baseDB, func(tx *sql.Tx) SQLQueries { - return baseDB.WithTx(tx) - }, - ) + queries := sqlcmig6.NewForType(baseDB, baseDB.BackendType) - return store, genericExecutor + return store, NewSQLMig6QueriesExecutor(baseDB, queries) } // assertMigrationResults asserts that the sql store contains the @@ -371,11 +366,11 @@ func TestSessionsStoreMigration(t *testing.T) { var opts sqldb.MigrationTxOptions err = txEx.ExecTx( - ctx, &opts, func(tx SQLQueries) error { + ctx, &opts, func(tx SQLMig6Queries) error { return MigrateSessionStoreToSQL( ctx, kvStore.DB, tx, ) - }, + }, sqldb.NoOpReset, ) require.NoError(t, err) diff --git a/session/sql_store.go b/session/sql_store.go index b1d366fe7..a169dff9b 100644 --- a/session/sql_store.go +++ b/session/sql_store.go @@ -12,8 +12,10 @@ import ( "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/db/sqlcmig6" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb/v2" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" ) @@ -21,6 +23,8 @@ import ( // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with session related tables. type SQLQueries interface { + sqldb.BaseQuerier + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) GetSessionByID(ctx context.Context, id int64) (sqlc.Session, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlc.Session, error) @@ -49,14 +53,48 @@ type SQLQueries interface { GetAccount(ctx context.Context, id int64) (sqlc.Account, error) } +// SQLMig6Queries is a subset of the sqlcmig6.Queries interface that can be used to +// interact with session related tables. +type SQLMig6Queries interface { + sqldb.BaseQuerier + + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) + GetSessionByID(ctx context.Context, id int64) (sqlcmig6.Session, error) + GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlcmig6.Session, error) + GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) + GetSessionByAlias(ctx context.Context, legacyID []byte) (sqlcmig6.Session, error) + GetSessionByLocalPublicKey(ctx context.Context, localPublicKey []byte) (sqlcmig6.Session, error) + GetSessionFeatureConfigs(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionFeatureConfig, error) + GetSessionMacaroonCaveats(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionMacaroonCaveat, error) + GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) + GetSessionMacaroonPermissions(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionMacaroonPermission, error) + GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]sqlcmig6.SessionPrivacyFlag, error) + InsertSessionFeatureConfig(ctx context.Context, arg sqlcmig6.InsertSessionFeatureConfigParams) error + SetSessionRevokedAt(ctx context.Context, arg sqlcmig6.SetSessionRevokedAtParams) error + InsertSessionMacaroonCaveat(ctx context.Context, arg sqlcmig6.InsertSessionMacaroonCaveatParams) error + InsertSessionMacaroonPermission(ctx context.Context, arg sqlcmig6.InsertSessionMacaroonPermissionParams) error + InsertSessionPrivacyFlag(ctx context.Context, arg sqlcmig6.InsertSessionPrivacyFlagParams) error + InsertSession(ctx context.Context, arg sqlcmig6.InsertSessionParams) (int64, error) + ListSessions(ctx context.Context) ([]sqlcmig6.Session, error) + ListSessionsByType(ctx context.Context, sessionType int16) ([]sqlcmig6.Session, error) + ListSessionsByState(ctx context.Context, state int16) ([]sqlcmig6.Session, error) + SetSessionRemotePublicKey(ctx context.Context, arg sqlcmig6.SetSessionRemotePublicKeyParams) error + SetSessionGroupID(ctx context.Context, arg sqlcmig6.SetSessionGroupIDParams) error + UpdateSessionState(ctx context.Context, arg sqlcmig6.UpdateSessionStateParams) error + DeleteSessionsWithState(ctx context.Context, state int16) error + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccount(ctx context.Context, id int64) (sqlcmig6.Account, error) +} + var _ Store = (*SQLStore)(nil) -// BatchedSQLQueries is a version of the SQLQueries that's capable of batched -// database operations. +// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx +// interface, allowing for multiple queries to be executed in single SQL +// transaction. type BatchedSQLQueries interface { SQLQueries - db.BatchedTx[SQLQueries] + sqldb.BatchedTx[SQLQueries] } // SQLStore represents a storage backend. @@ -66,19 +104,57 @@ type SQLStore struct { db BatchedSQLQueries // BaseDB represents the underlying database connection. - *db.BaseDB + *sqldb.BaseDB clock clock.Clock } -// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries -// storage backend. -func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - executor := db.NewTransactionExecutor( - sqlDB, func(tx *sql.Tx) SQLQueries { - return sqlDB.WithTx(tx) +type SQLQueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLQueries +} + +func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLQueries { + return queries.WithTx(tx) + }, + ) + return &SQLQueriesExecutor[SQLQueries]{ + TransactionExecutor: executor, + SQLQueries: queries, + } +} + +type SQLMig6QueriesExecutor[T sqldb.BaseQuerier] struct { + *sqldb.TransactionExecutor[T] + + SQLMig6Queries +} + +func NewSQLMig6QueriesExecutor(baseDB *sqldb.BaseDB, + queries *sqlcmig6.Queries) *SQLMig6QueriesExecutor[SQLMig6Queries] { + + executor := sqldb.NewTransactionExecutor( + baseDB, func(tx *sql.Tx) SQLMig6Queries { + return queries.WithTx(tx) }, ) + return &SQLMig6QueriesExecutor[SQLMig6Queries]{ + TransactionExecutor: executor, + SQLMig6Queries: queries, + } +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries, + clock clock.Clock) *SQLStore { + + executor := NewSQLQueriesExecutor(sqlDB, queries) return &SQLStore{ db: executor, @@ -281,7 +357,7 @@ func (s *SQLStore) NewSession(ctx context.Context, label string, typ Type, } return nil - }) + }, sqldb.NoOpReset) if err != nil { mappedSQLErr := db.MapSQLError(err) var uniqueConstraintErr *db.ErrSqlUniqueConstraintViolation @@ -325,7 +401,7 @@ func (s *SQLStore) ListSessionsByType(ctx context.Context, t Type) ([]*Session, } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -358,7 +434,7 @@ func (s *SQLStore) ListSessionsByState(ctx context.Context, state State) ( } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -417,7 +493,7 @@ func (s *SQLStore) ShiftState(ctx context.Context, alias ID, dest State) error { State: int16(dest), }, ) - }) + }, sqldb.NoOpReset) } // DeleteReservedSessions deletes all sessions that are in the StateReserved @@ -428,7 +504,7 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error { var writeTxOpts db.QueriesTxOptions return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { return db.DeleteSessionsWithState(ctx, int16(StateReserved)) - }) + }, sqldb.NoOpReset) } // GetSessionByLocalPub fetches the session with the given local pub key. @@ -458,7 +534,7 @@ func (s *SQLStore) GetSessionByLocalPub(ctx context.Context, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } @@ -491,7 +567,7 @@ func (s *SQLStore) ListAllSessions(ctx context.Context) ([]*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sessions, err } @@ -521,7 +597,7 @@ func (s *SQLStore) UpdateSessionRemotePubKey(ctx context.Context, alias ID, RemotePublicKey: remoteKey, }, ) - }) + }, sqldb.NoOpReset) } // getSqlUnusedAliasAndKeyPair can be used to generate a new, unused, local @@ -576,7 +652,7 @@ func (s *SQLStore) GetSession(ctx context.Context, alias ID) (*Session, error) { } return nil - }) + }, sqldb.NoOpReset) return sess, err } @@ -617,7 +693,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) { legacyGroupID, err = IDFromBytes(legacyGroupIDB) return err - }) + }, sqldb.NoOpReset) if err != nil { return ID{}, err } @@ -666,7 +742,7 @@ func (s *SQLStore) GetSessionIDs(ctx context.Context, legacyGroupID ID) ([]ID, } return nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, err } diff --git a/session/test_sql.go b/session/test_sql.go index a83186069..5623c8207 100644 --- a/session/test_sql.go +++ b/session/test_sql.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/lightninglabs/lightning-terminal/accounts" - "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" "github.com/stretchr/testify/require" ) @@ -22,8 +23,12 @@ func NewTestDBWithAccounts(t *testing.T, clock clock.Clock, // createStore is a helper function that creates a new SQLStore and ensure that // it is closed when during the test cleanup. -func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { - store := NewSQLStore(sqlDB, clock) +func createStore(t *testing.T, sqlDB *sqldb.BaseDB, + clock clock.Clock) *SQLStore { + + queries := sqlc.NewForType(sqlDB, sqlDB.BackendType) + + store := NewSQLStore(sqlDB, queries, clock) t.Cleanup(func() { require.NoError(t, store.Close()) }) diff --git a/session/test_sqlite.go b/session/test_sqlite.go index 0ceb0e046..c9dbc5934 100644 --- a/session/test_sqlite.go +++ b/session/test_sqlite.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lightning-terminal/db" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/sqldb/v2" ) // ErrDBClosed is an error that is returned when a database operation is @@ -16,7 +17,11 @@ var ErrDBClosed = errors.New("database is closed") // NewTestDB is a helper function that creates an SQLStore database for testing. func NewTestDB(t *testing.T, clock clock.Clock) Store { - return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock) + return createStore( + t, + sqldb.NewTestSqliteDB(t, db.MakeTestMigrationStreams()).BaseDB, + clock, + ) } // NewTestDBFromPath is a helper function that creates a new SQLStore with a @@ -24,7 +29,9 @@ func NewTestDB(t *testing.T, clock clock.Clock) Store { func NewTestDBFromPath(t *testing.T, dbPath string, clock clock.Clock) Store { - return createStore( - t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + tDb := sqldb.NewTestSqliteDBFromPath( + t, dbPath, db.MakeTestMigrationStreams(), ) + + return createStore(t, tDb.BaseDB, clock) } diff --git a/terminal.go b/terminal.go index 7e4d552c7..5f8fa2efb 100644 --- a/terminal.go +++ b/terminal.go @@ -447,7 +447,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { return fmt.Errorf("could not create network directory: %v", err) } - g.stores, err = NewStores(g.cfg, clock.NewDefaultClock()) + g.stores, err = NewStores(ctx, g.cfg, clock.NewDefaultClock()) if err != nil { return fmt.Errorf("could not create stores: %v", err) }