Skip to content

Commit 4a933e1

Browse files
authored
Merge pull request #149 from xataio/pgdumprestore-snapshot-generator
Add pg_dump/pg_restore schema snapshot generator
2 parents c907db8 + fddb3c6 commit 4a933e1

File tree

3 files changed

+266
-1
lines changed

3 files changed

+266
-1
lines changed

pkg/snapshot/generator/postgres/data/pg_snapshot_generator.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func NewSnapshotGenerator(ctx context.Context, cfg *Config, processRow snapshot.
7070
func WithLogger(logger loglib.Logger) Option {
7171
return func(sg *SnapshotGenerator) {
7272
sg.logger = loglib.NewLogger(logger).WithFields(loglib.Fields{
73-
loglib.ModuleField: "postgres_snapshot_generator",
73+
loglib.ModuleField: "postgres_data_snapshot_generator",
7474
})
7575
}
7676
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package pgdumprestore
4+
5+
import (
6+
"context"
7+
"fmt"
8+
9+
pglib "github.com/xataio/pgstream/internal/postgres"
10+
loglib "github.com/xataio/pgstream/pkg/log"
11+
"github.com/xataio/pgstream/pkg/snapshot"
12+
)
13+
14+
// SnapshotGenerator generates postgres schema snapshots using pg_dump and
15+
// pg_restore
16+
type SnapshotGenerator struct {
17+
sourceURL string
18+
targetURL string
19+
pgDumpFn pgdumpFn
20+
pgRestoreFn pgrestoreFn
21+
targetConn pglib.Querier
22+
logger loglib.Logger
23+
}
24+
25+
// Config holds the connection URLs required by the schema snapshot generator.
type Config struct {
	// SourcePGURL is the URL of the database the schema is dumped from.
	SourcePGURL string
	// TargetPGURL is the URL of the database the schema is restored into.
	TargetPGURL string
}
29+
30+
type (
31+
pgdumpFn func(pglib.PGDumpOptions) ([]byte, error)
32+
pgrestoreFn func(pglib.PGRestoreOptions, []byte) (string, error)
33+
)
34+
35+
// Option customises a SnapshotGenerator; options are applied by the
// constructor after the defaults have been set.
type Option func(sg *SnapshotGenerator)
36+
37+
// publicSchema is the schema name that is never created explicitly on the
// target before restoring.
const publicSchema = "public"
38+
39+
// NewSnapshotGenerator will return a postgres schema snapshot generator that
40+
// uses pg_dump and pg_restore to sync the schema of two postgres databases
41+
func NewSnapshotGenerator(ctx context.Context, c *Config, opts ...Option) (*SnapshotGenerator, error) {
42+
targetConn, err := pglib.NewConnPool(ctx, c.TargetPGURL)
43+
if err != nil {
44+
return nil, err
45+
}
46+
sg := &SnapshotGenerator{
47+
sourceURL: c.SourcePGURL,
48+
targetURL: c.TargetPGURL,
49+
pgDumpFn: pglib.RunPGDump,
50+
pgRestoreFn: pglib.RunPGRestore,
51+
targetConn: targetConn,
52+
logger: loglib.NewNoopLogger(),
53+
}
54+
55+
for _, opt := range opts {
56+
opt(sg)
57+
}
58+
59+
return sg, nil
60+
}
61+
62+
func WithLogger(logger loglib.Logger) Option {
63+
return func(sg *SnapshotGenerator) {
64+
sg.logger = loglib.NewLogger(logger).WithFields(loglib.Fields{
65+
loglib.ModuleField: "postgres_schema_snapshot_generator",
66+
})
67+
}
68+
}
69+
70+
func (s *SnapshotGenerator) CreateSnapshot(ctx context.Context, ss *snapshot.Snapshot) error {
71+
s.logger.Info("creating schema snapshot", loglib.Fields{"schema": ss.SchemaName, "tables": ss.TableNames})
72+
dump, err := s.pgDumpFn(s.pgdumpOptions(ss))
73+
if err != nil {
74+
return err
75+
}
76+
77+
// if we use table filtering in the pg_dump command, the schema creation
78+
// will not be dumped, so it needs to be created explicitly (except for
79+
// public schema)
80+
if len(ss.TableNames) > 0 && ss.SchemaName != publicSchema {
81+
_, err = s.targetConn.Exec(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", ss.SchemaName))
82+
if err != nil {
83+
return err
84+
}
85+
}
86+
87+
_, err = s.pgRestoreFn(s.pgrestoreOptions(), dump)
88+
if err != nil {
89+
return err
90+
}
91+
92+
return nil
93+
}
94+
95+
func (s *SnapshotGenerator) Close() error {
96+
return s.targetConn.Close(context.Background())
97+
}
98+
99+
func (s *SnapshotGenerator) pgdumpOptions(ss *snapshot.Snapshot) pglib.PGDumpOptions {
100+
opts := pglib.PGDumpOptions{
101+
ConnectionString: s.sourceURL,
102+
Format: "c",
103+
SchemaOnly: true,
104+
Schemas: []string{ss.SchemaName},
105+
}
106+
107+
for _, table := range ss.TableNames {
108+
opts.Tables = append(opts.Tables, ss.SchemaName+"."+table)
109+
}
110+
111+
return opts
112+
}
113+
114+
func (s *SnapshotGenerator) pgrestoreOptions() pglib.PGRestoreOptions {
115+
return pglib.PGRestoreOptions{
116+
ConnectionString: s.targetURL,
117+
SchemaOnly: true,
118+
}
119+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package pgdumprestore
4+
5+
import (
6+
"context"
7+
"errors"
8+
"testing"
9+
10+
"github.com/stretchr/testify/require"
11+
pglib "github.com/xataio/pgstream/internal/postgres"
12+
"github.com/xataio/pgstream/internal/postgres/mocks"
13+
"github.com/xataio/pgstream/pkg/log"
14+
"github.com/xataio/pgstream/pkg/snapshot"
15+
)
16+
17+
func TestSnapshotGenerator_CreateSnapshot(t *testing.T) {
18+
t.Parallel()
19+
20+
testDump := []byte("test dump")
21+
testSchema := "test_schema"
22+
testTable := "test_table"
23+
errTest := errors.New("oh noes")
24+
25+
tests := []struct {
26+
name string
27+
snapshot *snapshot.Snapshot
28+
conn pglib.Querier
29+
pgdumpFn pgdumpFn
30+
pgrestoreFn pgrestoreFn
31+
32+
wantErr error
33+
}{
34+
{
35+
name: "ok",
36+
snapshot: &snapshot.Snapshot{
37+
SchemaName: testSchema,
38+
TableNames: []string{testTable},
39+
},
40+
conn: &mocks.Querier{
41+
ExecFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.CommandTag, error) {
42+
require.Equal(t, "CREATE SCHEMA IF NOT EXISTS "+testSchema, query)
43+
return pglib.CommandTag{}, nil
44+
},
45+
},
46+
pgdumpFn: func(po pglib.PGDumpOptions) ([]byte, error) {
47+
require.Equal(t, pglib.PGDumpOptions{
48+
ConnectionString: "source-url",
49+
Format: "c",
50+
SchemaOnly: true,
51+
Schemas: []string{testSchema},
52+
Tables: []string{testSchema + "." + testTable},
53+
}, po)
54+
return testDump, nil
55+
},
56+
pgrestoreFn: func(po pglib.PGRestoreOptions, dump []byte) (string, error) {
57+
require.Equal(t, pglib.PGRestoreOptions{
58+
ConnectionString: "target-url",
59+
SchemaOnly: true,
60+
}, po)
61+
require.Equal(t, testDump, dump)
62+
return "", nil
63+
},
64+
65+
wantErr: nil,
66+
},
67+
{
68+
name: "error - performing pgdump",
69+
snapshot: &snapshot.Snapshot{
70+
SchemaName: testSchema,
71+
TableNames: []string{testTable},
72+
},
73+
conn: &mocks.Querier{
74+
ExecFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.CommandTag, error) {
75+
return pglib.CommandTag{}, errors.New("ExecFn: should not be called")
76+
},
77+
},
78+
pgdumpFn: func(po pglib.PGDumpOptions) ([]byte, error) {
79+
return nil, errTest
80+
},
81+
pgrestoreFn: func(po pglib.PGRestoreOptions, dump []byte) (string, error) {
82+
return "", errors.New("pgrestoreFn: should not be called")
83+
},
84+
85+
wantErr: errTest,
86+
},
87+
{
88+
name: "error - performing pgrestore",
89+
snapshot: &snapshot.Snapshot{
90+
SchemaName: publicSchema,
91+
TableNames: []string{testTable},
92+
},
93+
conn: &mocks.Querier{
94+
ExecFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.CommandTag, error) {
95+
return pglib.CommandTag{}, errors.New("ExecFn: should not be called")
96+
},
97+
},
98+
pgdumpFn: func(po pglib.PGDumpOptions) ([]byte, error) {
99+
return testDump, nil
100+
},
101+
pgrestoreFn: func(po pglib.PGRestoreOptions, dump []byte) (string, error) {
102+
return "", errTest
103+
},
104+
105+
wantErr: errTest,
106+
},
107+
{
108+
name: "error - creating schema",
109+
snapshot: &snapshot.Snapshot{
110+
SchemaName: testSchema,
111+
TableNames: []string{testTable},
112+
},
113+
conn: &mocks.Querier{
114+
ExecFn: func(ctx context.Context, i uint, query string, args ...any) (pglib.CommandTag, error) {
115+
return pglib.CommandTag{}, errTest
116+
},
117+
},
118+
pgdumpFn: func(po pglib.PGDumpOptions) ([]byte, error) {
119+
return testDump, nil
120+
},
121+
pgrestoreFn: func(po pglib.PGRestoreOptions, dump []byte) (string, error) {
122+
return "", errors.New("pgrestoreFn: should not be called")
123+
},
124+
125+
wantErr: errTest,
126+
},
127+
}
128+
129+
for _, tc := range tests {
130+
t.Run(tc.name, func(t *testing.T) {
131+
t.Parallel()
132+
133+
sg := SnapshotGenerator{
134+
sourceURL: "source-url",
135+
targetURL: "target-url",
136+
targetConn: tc.conn,
137+
pgDumpFn: tc.pgdumpFn,
138+
pgRestoreFn: tc.pgrestoreFn,
139+
logger: log.NewNoopLogger(),
140+
}
141+
142+
err := sg.CreateSnapshot(context.Background(), tc.snapshot)
143+
require.ErrorIs(t, err, tc.wantErr)
144+
})
145+
}
146+
}

0 commit comments

Comments
 (0)