Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions internal/tigerfs/db/constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/timescale/tigerfs/internal/tigerfs/logging"
"go.uber.org/zap"
)
Expand All @@ -19,14 +18,14 @@ type Constraint struct {

// ValidateConstraints validates column values against table constraints
// Returns an error if any constraints would be violated
func ValidateConstraints(ctx context.Context, pool *pgxpool.Pool, schema, table string, values map[string]interface{}) error {
func ValidateConstraints(ctx context.Context, dbtx DBTX, schema, table string, values map[string]interface{}) error {
logging.Debug("Validating constraints",
zap.String("schema", schema),
zap.String("table", table),
zap.Int("column_count", len(values)))

// Get all columns for the table to check NOT NULL constraints
columns, err := getColumnsForConstraintCheck(ctx, pool, schema, table)
columns, err := getColumnsForConstraintCheck(ctx, dbtx, schema, table)
if err != nil {
return fmt.Errorf("failed to get columns: %w", err)
}
Expand All @@ -48,7 +47,7 @@ func ValidateConstraints(ctx context.Context, pool *pgxpool.Pool, schema, table

// Check UNIQUE constraints
// For each column being updated, check if it has a unique constraint
uniqueConstraints, err := getUniqueConstraints(ctx, pool, schema, table)
uniqueConstraints, err := getUniqueConstraints(ctx, dbtx, schema, table)
if err != nil {
return fmt.Errorf("failed to get unique constraints: %w", err)
}
Expand All @@ -58,7 +57,7 @@ func ValidateConstraints(ctx context.Context, pool *pgxpool.Pool, schema, table
for _, colName := range constraint.Columns {
if value, ok := values[colName]; ok {
// This column is being updated, check for duplicates
if err := checkUniqueConstraint(ctx, pool, schema, table, colName, value); err != nil {
if err := checkUniqueConstraint(ctx, dbtx, schema, table, colName, value); err != nil {
logging.Debug("UNIQUE constraint violation",
zap.String("schema", schema),
zap.String("table", table),
Expand All @@ -81,7 +80,7 @@ func ValidateConstraints(ctx context.Context, pool *pgxpool.Pool, schema, table
}

// getColumnsForConstraintCheck queries column metadata for constraint checking
func getColumnsForConstraintCheck(ctx context.Context, pool *pgxpool.Pool, schema, table string) ([]Column, error) {
func getColumnsForConstraintCheck(ctx context.Context, dbtx DBTX, schema, table string) ([]Column, error) {
query := `
SELECT
column_name,
Expand All @@ -93,7 +92,7 @@ func getColumnsForConstraintCheck(ctx context.Context, pool *pgxpool.Pool, schem
ORDER BY ordinal_position
`

rows, err := pool.Query(ctx, query, schema, table)
rows, err := dbtx.Query(ctx, query, schema, table)
if err != nil {
return nil, fmt.Errorf("failed to query columns: %w", err)
}
Expand Down Expand Up @@ -126,7 +125,7 @@ func getColumnsForConstraintCheck(ctx context.Context, pool *pgxpool.Pool, schem
}

// getUniqueConstraints retrieves UNIQUE constraints for a table
func getUniqueConstraints(ctx context.Context, pool *pgxpool.Pool, schema, table string) ([]Constraint, error) {
func getUniqueConstraints(ctx context.Context, dbtx DBTX, schema, table string) ([]Constraint, error) {
query := `
SELECT
tc.constraint_name,
Expand All @@ -142,7 +141,7 @@ func getUniqueConstraints(ctx context.Context, pool *pgxpool.Pool, schema, table
GROUP BY tc.constraint_name
`

rows, err := pool.Query(ctx, query, schema, table)
rows, err := dbtx.Query(ctx, query, schema, table)
if err != nil {
return nil, fmt.Errorf("failed to query unique constraints: %w", err)
}
Expand All @@ -169,7 +168,7 @@ func getUniqueConstraints(ctx context.Context, pool *pgxpool.Pool, schema, table
}

// checkUniqueConstraint checks if a value violates a unique constraint
func checkUniqueConstraint(ctx context.Context, pool *pgxpool.Pool, schema, table, column string, value interface{}) error {
func checkUniqueConstraint(ctx context.Context, dbtx DBTX, schema, table, column string, value interface{}) error {
// Check if this value already exists in the table
query := fmt.Sprintf(`
SELECT EXISTS(
Expand All @@ -178,7 +177,7 @@ func checkUniqueConstraint(ctx context.Context, pool *pgxpool.Pool, schema, tabl
`, qt(schema, table), qi(column))

var exists bool
err := pool.QueryRow(ctx, query, value).Scan(&exists)
err := dbtx.QueryRow(ctx, query, value).Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check unique constraint: %w", err)
}
Expand Down
28 changes: 28 additions & 0 deletions internal/tigerfs/db/dbtx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package db

import (
"context"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)

// DBTX is the common query interface satisfied by *pgxpool.Pool and pgx.Tx.
// Package-level functions accept DBTX so they can operate against either
// a raw pool connection or a transaction with SET LOCAL session variables.
//
// This interface intentionally excludes SendBatch, CopyFrom, and Begin.
// Operations that need those capabilities (bulk import, DDL validation)
// manage their own pgx.Tx directly.
type DBTX interface {
	// Exec runs a statement that returns no rows (INSERT/UPDATE/DDL).
	Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
	// Query runs a statement returning zero or more rows; the caller
	// must Close (or fully drain) the returned pgx.Rows.
	Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
	// QueryRow runs a statement expected to return at most one row;
	// errors are deferred until Row.Scan is called.
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}

// Compile-time assertions: both the connection pool and a transaction
// must implement DBTX, so an interface drift fails the build immediately.
var _ DBTX = (*pgxpool.Pool)(nil)
var _ DBTX = (pgx.Tx)(nil)
20 changes: 20 additions & 0 deletions internal/tigerfs/db/dbtx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package db

import (
"testing"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

// TestDBTX_PoolSatisfiesInterface verifies at compile time that
// *pgxpool.Pool satisfies the DBTX interface. The assignment below
// will not compile if the pool's method set drifts from DBTX.
func TestDBTX_PoolSatisfiesInterface(t *testing.T) {
	var dbtx DBTX = (*pgxpool.Pool)(nil)
	_ = dbtx
}

// TestDBTX_TxSatisfiesInterface verifies at compile time that
// pgx.Tx satisfies the DBTX interface. The assignment below will
// not compile if the transaction interface drifts from DBTX.
func TestDBTX_TxSatisfiesInterface(t *testing.T) {
	var dbtx DBTX = (pgx.Tx)(nil)
	_ = dbtx
}
13 changes: 6 additions & 7 deletions internal/tigerfs/db/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"strings"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/timescale/tigerfs/internal/tigerfs/logging"
"go.uber.org/zap"
)
Expand All @@ -23,7 +22,7 @@ import (
// - columns: Column names in database order
// - rows: Row values as [][]interface{}
// - error: Any database error
func GetAllRows(ctx context.Context, pool *pgxpool.Pool, schema, table string, limit int) ([]string, [][]interface{}, error) {
func GetAllRows(ctx context.Context, dbtx DBTX, schema, table string, limit int) ([]string, [][]interface{}, error) {
logging.Debug("Getting all rows for export",
zap.String("schema", schema),
zap.String("table", table),
Expand All @@ -34,7 +33,7 @@ func GetAllRows(ctx context.Context, pool *pgxpool.Pool, schema, table string, l
qt(schema, table),
)

rows, err := pool.Query(ctx, query, limit)
rows, err := dbtx.Query(ctx, query, limit)
if err != nil {
return nil, nil, fmt.Errorf("failed to query rows: %w", err)
}
Expand Down Expand Up @@ -80,7 +79,7 @@ func (c *Client) GetAllRows(ctx context.Context, schema, table string, limit int
// GetFirstNRowsWithData returns the first N rows ordered by primary key ascending.
// Returns full row data, not just primary keys.
// Used for bulk export with .first/N/ pagination.
func GetFirstNRowsWithData(ctx context.Context, pool *pgxpool.Pool, schema, table string, pkColumns []string, limit int) ([]string, [][]interface{}, error) {
func GetFirstNRowsWithData(ctx context.Context, dbtx DBTX, schema, table string, pkColumns []string, limit int) ([]string, [][]interface{}, error) {
logging.Debug("Getting first N rows with data",
zap.String("schema", schema),
zap.String("table", table),
Expand All @@ -92,7 +91,7 @@ func GetFirstNRowsWithData(ctx context.Context, pool *pgxpool.Pool, schema, tabl
qt(schema, table), pkOrderByList(pkColumns, "ASC"),
)

rows, err := pool.Query(ctx, query, limit)
rows, err := dbtx.Query(ctx, query, limit)
if err != nil {
return nil, nil, fmt.Errorf("failed to query first N rows: %w", err)
}
Expand Down Expand Up @@ -138,7 +137,7 @@ func (c *Client) GetFirstNRowsWithData(ctx context.Context, schema, table string
// GetLastNRowsWithData returns the last N rows ordered by primary key descending.
// Returns full row data, not just primary keys.
// Used for bulk export with .last/N/ pagination.
func GetLastNRowsWithData(ctx context.Context, pool *pgxpool.Pool, schema, table string, pkColumns []string, limit int) ([]string, [][]interface{}, error) {
func GetLastNRowsWithData(ctx context.Context, dbtx DBTX, schema, table string, pkColumns []string, limit int) ([]string, [][]interface{}, error) {
logging.Debug("Getting last N rows with data",
zap.String("schema", schema),
zap.String("table", table),
Expand All @@ -150,7 +149,7 @@ func GetLastNRowsWithData(ctx context.Context, pool *pgxpool.Pool, schema, table
qt(schema, table), pkOrderByList(pkColumns, "DESC"),
)

rows, err := pool.Query(ctx, query, limit)
rows, err := dbtx.Query(ctx, query, limit)
if err != nil {
return nil, nil, fmt.Errorf("failed to query last N rows: %w", err)
}
Expand Down
Loading