Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Demouth <yuya at demouth.net>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me>
Dmitry Zenovich <dzenovich at gmail.com>
Egor Smolyakov <egorsmkv at gmail.com>
Erwan Martin <hello at erwan.io>
Evan Elias <evan at skeema.net>
Expand Down
20 changes: 19 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,10 @@ func (mc *mysqlConn) finish() {

// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
return mc.sendSimpleCommandOK(ctx, comPing)
}

func (mc *mysqlConn) sendSimpleCommandOK(ctx context.Context, cmd byte) (err error) {
if mc.closed.Load() {
return driver.ErrBadConn
}
Expand All @@ -513,7 +517,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
defer mc.finish()

handleOk := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
if err = mc.writeCommandPacket(cmd); err != nil {
return mc.markBadConn(err)
}

Expand Down Expand Up @@ -681,6 +685,20 @@ func (mc *mysqlConn) startWatcher() {
}()
}

// Reset resets the server-side session state using COM_RESET_CONNECTION.
// It clears most per-session state (e.g., user variables, prepared statements)
// without re-authenticating.
// Usage hint: call via database/sql.Conn.Raw using a method assertion:
// conn.Raw(func(c any) error {
// if r, ok := c.(interface{ Reset(context.Context) error }); ok {
// return r.Reset(ctx)
// }
// return nil
// })
func (mc *mysqlConn) Reset(ctx context.Context) (err error) {
return mc.sendSimpleCommandOK(ctx, comResetConnection)
}

func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
Expand Down
137 changes: 85 additions & 52 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,65 +129,98 @@ func TestCheckNamedValue(t *testing.T) {
}
}

// TestCleanCancel tests passed context is cancelled at start.
// TestSimpleCommandOKCleanCancel tests passed context is cancelled at start.
// No packet should be sent. Connection should keep current status.
func TestCleanCancel(t *testing.T) {
mc := &mysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
defer mc.cleanup()

ctx, cancel := context.WithCancel(context.Background())
cancel()

for range 3 { // Repeat same behavior
err := mc.Ping(ctx)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %#v", err)
}

if mc.closed.Load() {
t.Error("expected mc is not closed, closed actually")
}

if mc.watching {
t.Error("expected watching is false, but true")
}
func TestSimpleCommandOKCleanCancel(t *testing.T) {
for _, test := range []struct {
name string
funcToCall func(ctx context.Context, mc *mysqlConn) error
} {
{name: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Ping(ctx) }},
{name: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Reset(ctx) }},
} {
test := test
t.Run(test.name, func(t *testing.T) {
mc := &mysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
defer mc.cleanup()

ctx, cancel := context.WithCancel(context.Background())
cancel()

for range 3 { // Repeat same behavior
err := test.funcToCall(ctx, mc)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %#v", err)
}

if mc.closed.Load() {
t.Error("expected mc is not closed, closed actually")
}

if mc.watching {
t.Error("expected watching is false, but true")
}
}
})
}
}

func TestPingMarkBadConnection(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := mc.Ping(context.Background())

if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
func TestSimpleCommandOKMarkBadConnection(t *testing.T) {
for _, test := range []struct {
name string
funcToCall func(mc *mysqlConn) error
} {
{name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }},
{name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }},
} {
test := test
t.Run(test.name, func(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := test.funcToCall(mc)

if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
}
})
}
}

func TestPingErrInvalidConn(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := mc.Ping(context.Background())

if err != nc.err {
t.Errorf("expected %#v, got %#v", nc.err, err)
func TestSimpleCommandOKErrInvalidConn(t *testing.T) {
for _, test := range []struct {
name string
funcToCall func(mc *mysqlConn) error
} {
{name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }},
{name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }},
} {
test := test
t.Run(test.name, func(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
mc := &mysqlConn{
netConn: nc,
buf: newBuffer(),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := test.funcToCall(mc)

if err != nc.err {
t.Errorf("expected %#v, got %#v", nc.err, err)
}
})
}
}

Expand Down
3 changes: 3 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ const (
comStmtReset
comSetOption
comStmtFetch
comDaemon
comBinlogDumpGTID
comResetConnection
)

// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
Expand Down
131 changes: 110 additions & 21 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2322,48 +2322,137 @@ func TestRejectReadOnly(t *testing.T) {
}

func TestPing(t *testing.T) {
ctx := context.Background()
runTests(t, dsn, func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
dbt.fail("Ping", "Ping", err)
}
})
}

func TestSimpleCommandOK(t *testing.T) {
ctx := context.Background()
for _, test := range []struct{
method string
query string
funcToCall func(ctx context.Context, mc *mysqlConn) error
} {
{method: "Pinger", query: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Ping(ctx)}},
{method: "Conn", query: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Reset(ctx)}},
} {
test := test
t.Run(test.method+"_"+test.query, func(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
conn, err := dbt.db.Conn(ctx)
if err != nil {
dbt.fail("db", "Conn", err)
}
defer conn.Close()

// Check that affectedRows and insertIds are cleared after each call.
conn.Raw(func(conn any) error {
c := conn.(*mysqlConn)

// Issue a query that sets affectedRows and insertIds.
q, err := c.QueryContext(ctx, `SELECT 1`, nil)
if err != nil {
dbt.fail("Conn", "QueryContext", err)
}
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
}
if err := q.Close(); err != nil {
dbt.fail("Rows", "Close", err)
}

// Verify that Ping()/Reset() clears both fields.
for range 2 {
if err := test.funcToCall(ctx, c); err != nil {
// Skip Reset on servers lacking COM_RESET_CONNECTION support.
if test.query == "Reset" {
maybeSkip(t, err, 1047) // ER_UNKNOWN_COM_ERROR
maybeSkip(t, err, 1235) // ER_NOT_SUPPORTED_YET
}
dbt.fail(test.method, test.query, err)
}
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad insertIds: got %v, want=%v", got, want)
}
}
return nil
})
})
})
}
}

func TestReset(t *testing.T) {
ctx := context.Background()
runTests(t, dsn, func(dbt *DBTest) {
conn, err := dbt.db.Conn(ctx)
if err != nil {
dbt.fail("db", "Conn", err)
}
defer conn.Close()

// Check that affectedRows and insertIds are cleared after each call.
// Verify that COM_RESET_CONNECTION clears session state (e.g., user variables).
conn.Raw(func(conn any) error {
c := conn.(*mysqlConn)

// Issue a query that sets affectedRows and insertIds.
q, err := c.Query(`SELECT 1`, nil)
_, err = c.ExecContext(ctx, "SET @a := 1", nil)
if err != nil {
dbt.fail("Conn", "Query", err)
dbt.fail("Conn", "ExecContext", err)
}
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
var rows driver.Rows
rows, err = c.QueryContext(ctx, "SELECT @a", nil)
if err != nil {
dbt.fail("Conn", "QueryContext", err)
}
result := []driver.Value{nil}
err = rows.Next(result)
if err != nil {
dbt.fail("Rows", "Next", err)
}
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
err = rows.Close()
if err != nil {
dbt.fail("Rows", "Close", err)
}
if !(reflect.DeepEqual([]driver.Value{int64(1)}, result) ||
reflect.DeepEqual([]driver.Value{[]byte("1")}, result)) {
dbt.Fatalf("failed to set @a to 1 with SET: got %v, want int64(1) or []byte(\"1\")", result)
}
q.Close()

// Verify that Ping() clears both fields.
for range 2 {
if err := c.Ping(ctx); err != nil {
dbt.fail("Pinger", "Ping", err)
}
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
err = c.Reset(ctx)
if err != nil {
// Allow skipping on unsupported COM_RESET_CONNECTION
maybeSkip(t, err, 1047) // ER_UNKNOWN_COM_ERROR
maybeSkip(t, err, 1235) // ER_NOT_SUPPORTED_YET
dbt.fail("Conn", "Reset", err)
}

rows, err = c.QueryContext(ctx, "SELECT @a", nil)
if err != nil {
dbt.fail("Conn", "QueryContext", err)
}
// Seed with a sentinel to ensure Rows.Next overwrites it with nil.
result = []driver.Value{"sentinel-non-nil"}
err = rows.Next(result)
if err != nil {
dbt.fail("Rows", "Next", err)
}
err = rows.Close()
if err != nil {
dbt.fail("Rows", "Close", err)
}
if !reflect.DeepEqual([]driver.Value{nil}, result) {
dbt.Fatalf("Reset did not reset the session (@a is still set): got %v, want=%v", result, []driver.Value{nil})
}

return nil
})
})
Expand Down