Skip to content

Commit 4c6e1ce

Browse files
authored
feat(interceptor): add PrepareContext support (#15)
1 parent d226a94 commit 4c6e1ce

File tree

4 files changed

+65
-18
lines changed

4 files changed

+65
-18
lines changed

README.md

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,14 @@ db, _ := sql.Open("interceptor", "dsn")
9494

9595
db.ExecContext(ctx, "INSERT INTO users VALUES (1, 'John Doe')")
9696
// stderr: INFO ExecContext query="INSERT INTO users VALUES (1, 'John Doe')"
97+
9798
db.QueryContext(ctx, "SELECT id, name FROM users")
9899
// stderr: INFO QueryContext query="SELECT id, name FROM users"
99100
```
100101

101-
> [!note]
102-
> To keep the implementation simple, only `ExecContext` and `QueryContext` callbacks are supported.
103-
> If you need to intercept other database operations, such as `sql.DB.BeginTx`, consider using [ngrok/sqlmw][3] instead.
104-
105102
Integration tests cover the following databases and drivers:
106-
- PostgreSQL with [jackx/pgx][4]
107-
- MySQL with [go-sql-driver/mysql][5]
103+
- PostgreSQL with [jackx/pgx][3]
104+
- MySQL with [go-sql-driver/mysql][4]
108105

109106
## 🚧 TODOs
110107

@@ -114,6 +111,5 @@ Integration tests cover the following databases and drivers:
114111

115112
[1]: https://github.com/golang/go/issues/61637
116113
[2]: https://grpc.io/docs/guides/interceptors
117-
[3]: https://github.com/ngrok/sqlmw
118-
[4]: https://github.com/jackc/pgx
119-
[5]: https://github.com/go-sql-driver/mysql
114+
[3]: https://github.com/jackc/pgx
115+
[4]: https://github.com/go-sql-driver/mysql

interceptor.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ var (
1010
_ driver.DriverContext = Interceptor{}
1111
)
1212

13+
// TODO: document that database/sql falls back to Prepare if the driver returns ErrSkip for Exec/Query.
14+
1315
// Interceptor is a [driver.Driver] wrapper that allows to register callbacks for database queries.
1416
// It must first be registered with [sql.Register] with the same name that is then passed to [sql.Open]:
1517
//
@@ -31,6 +33,11 @@ type Interceptor struct {
3133
// The implementation must call queryer.QueryContext(ctx, query, args) and return the result.
3234
// Optional.
3335
QueryContext func(ctx context.Context, query string, args []driver.NamedValue, queryer driver.QueryerContext) (driver.Rows, error)
36+
37+
// PrepareContext is a callback for [sql.DB.PrepareContext].
38+
// The implementation must call preparer.ConnPrepareContext(ctx, query) and return the result.
39+
// Optional.
40+
PrepareContext func(ctx context.Context, query string, preparer driver.ConnPrepareContext) (driver.Stmt, error)
3441
}
3542

3643
// Open implements [driver.Driver].
@@ -54,9 +61,10 @@ func (i Interceptor) OpenConnector(name string) (driver.Connector, error) {
5461
}
5562

5663
var (
57-
_ driver.Conn = wrappedConn{}
58-
_ driver.ExecerContext = wrappedConn{}
59-
_ driver.QueryerContext = wrappedConn{}
64+
_ driver.Conn = wrappedConn{}
65+
_ driver.ExecerContext = wrappedConn{}
66+
_ driver.QueryerContext = wrappedConn{}
67+
_ driver.ConnPrepareContext = wrappedConn{}
6068
)
6169

6270
type wrappedConn struct {
@@ -90,6 +98,18 @@ func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driv
9098

9199
var _ driver.Connector = wrappedConnector{}
92100

101+
// PrepareContext implements [driver.ConnPrepareContext].
102+
func (c wrappedConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
103+
preparer, ok := c.Conn.(driver.ConnPrepareContext)
104+
if !ok {
105+
panic("queries: driver does not implement driver.ConnPrepareContext")
106+
}
107+
if c.interceptor.PrepareContext != nil {
108+
return c.interceptor.PrepareContext(ctx, query, preparer)
109+
}
110+
return preparer.PrepareContext(ctx, query)
111+
}
112+
93113
type wrappedConnector struct {
94114
driver.Connector
95115
interceptor Interceptor

interceptor_test.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,24 @@ import (
1515
func TestInterceptor(t *testing.T) {
1616
ctx := context.Background()
1717

18-
var execIntercepted bool
19-
var queryIntercepted bool
18+
var execCalled bool
19+
var queryCalled bool
20+
var prepareCalled bool
2021

2122
interceptor := queries.Interceptor{
2223
Driver: mockDriver{conn: spyConn{}},
2324
ExecContext: func(ctx context.Context, query string, args []driver.NamedValue, execer driver.ExecerContext) (driver.Result, error) {
24-
execIntercepted = true
25+
execCalled = true
2526
return execer.ExecContext(ctx, query, args)
2627
},
2728
QueryContext: func(ctx context.Context, query string, args []driver.NamedValue, queryer driver.QueryerContext) (driver.Rows, error) {
28-
queryIntercepted = true
29+
queryCalled = true
2930
return queryer.QueryContext(ctx, query, args)
3031
},
32+
PrepareContext: func(ctx context.Context, query string, preparer driver.ConnPrepareContext) (driver.Stmt, error) {
33+
prepareCalled = true
34+
return preparer.PrepareContext(ctx, query)
35+
},
3136
}
3237

3338
driverName := t.Name() + "_interceptor"
@@ -39,11 +44,15 @@ func TestInterceptor(t *testing.T) {
3944

4045
_, err = db.ExecContext(ctx, "")
4146
assert.IsErr[E](t, err, errCalled)
42-
assert.Equal[E](t, execIntercepted, true)
47+
assert.Equal[E](t, execCalled, true)
4348

4449
_, err = db.QueryContext(ctx, "") //nolint:gocritic // sqlQuery: unused result is fine here.
4550
assert.IsErr[E](t, err, errCalled)
46-
assert.Equal[E](t, queryIntercepted, true)
51+
assert.Equal[E](t, queryCalled, true)
52+
53+
_, err = db.PrepareContext(ctx, "")
54+
assert.IsErr[E](t, err, errCalled)
55+
assert.Equal[E](t, prepareCalled, true)
4756
}
4857

4958
func TestInterceptor_passthrough(t *testing.T) {
@@ -65,6 +74,9 @@ func TestInterceptor_passthrough(t *testing.T) {
6574

6675
_, err = db.QueryContext(ctx, "") //nolint:gocritic // sqlQuery: unused result is fine here.
6776
assert.IsErr[E](t, err, errCalled)
77+
78+
_, err = db.PrepareContext(ctx, "")
79+
assert.IsErr[E](t, err, errCalled)
6880
}
6981

7082
func TestInterceptor_unimplemented(t *testing.T) {
@@ -86,6 +98,9 @@ func TestInterceptor_unimplemented(t *testing.T) {
8698

8799
queryFn := func() { _, _ = db.QueryContext(ctx, "") } //nolint:gocritic // sqlQuery: unused result is fine here.
88100
assert.Panics[E](t, queryFn, "queries: driver does not implement driver.QueryerContext")
101+
102+
prepareFn := func() { _, _ = db.PrepareContext(ctx, "") }
103+
assert.Panics[E](t, prepareFn, "queries: driver does not implement driver.ConnPrepareContext")
89104
}
90105

91106
func TestInterceptor_driver(t *testing.T) {
@@ -123,3 +138,7 @@ func (spyConn) ExecContext(context.Context, string, []driver.NamedValue) (driver
123138
func (spyConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) {
124139
return nil, errCalled
125140
}
141+
142+
func (spyConn) PrepareContext(context.Context, string) (driver.Stmt, error) {
143+
return nil, errCalled
144+
}

tests/integration_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ func TestIntegration(t *testing.T) {
4040
for name, database := range DBs {
4141
var execCalls int
4242
var queryCalls int
43+
var prepareCalls int
4344

4445
interceptor := queries.Interceptor{
4546
Driver: database.driver,
@@ -53,6 +54,11 @@ func TestIntegration(t *testing.T) {
5354
t.Logf("[%s] QueryContext: %s %v", name, query, namedToAny(args))
5455
return queryer.QueryContext(ctx, query, args)
5556
},
57+
PrepareContext: func(ctx context.Context, query string, preparer driver.ConnPrepareContext) (driver.Stmt, error) {
58+
prepareCalls++
59+
t.Logf("[%s] PrepareContext: %s", name, query)
60+
return preparer.PrepareContext(ctx, query)
61+
},
5662
}
5763

5864
driverName := name + "_interceptor"
@@ -111,6 +117,12 @@ func TestIntegration(t *testing.T) {
111117
assert.NoErr[F](t, tx.Commit())
112118
assert.Equal[E](t, execCalls, 2)
113119
assert.Equal[E](t, queryCalls, 5*2)
120+
if name == "mysql" {
121+
// github.com/go-sql-driver/mysql falls back to PrepareContext for queries with arguments.
122+
assert.Equal[E](t, prepareCalls, 1)
123+
} else {
124+
assert.Equal[E](t, prepareCalls, 0)
125+
}
114126
}
115127
}
116128

0 commit comments

Comments
 (0)