diff --git a/AUTHORS b/AUTHORS index 05e71df4..514736d1 100644 --- a/AUTHORS +++ b/AUTHORS @@ -42,6 +42,7 @@ Demouth Diego Dupin Dirkjan Bussink DisposaBoy +Dmitry Zenovich Egor Smolyakov Erwan Martin Evan Elias diff --git a/connection.go b/connection.go index 5648e47d..55c09bd2 100644 --- a/connection.go +++ b/connection.go @@ -503,6 +503,10 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + return mc.sendSimpleCommandOK(ctx, comPing) +} + +func (mc *mysqlConn) sendSimpleCommandOK(ctx context.Context, cmd byte) (err error) { if mc.closed.Load() { return driver.ErrBadConn } @@ -513,7 +517,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { defer mc.finish() handleOk := mc.clearResult() - if err = mc.writeCommandPacket(comPing); err != nil { + if err = mc.writeCommandPacket(cmd); err != nil { return mc.markBadConn(err) } @@ -681,6 +685,20 @@ func (mc *mysqlConn) startWatcher() { }() } +// Reset resets the server-side session state using COM_RESET_CONNECTION. +// It clears most per-session state (e.g., user variables, prepared statements) +// without re-authenticating. +// Usage hint: call via database/sql.Conn.Raw using a method assertion: +// conn.Raw(func(c any) error { +// if r, ok := c.(interface{ Reset(context.Context) error }); ok { +// return r.Reset(ctx) +// } +// return nil +// }) +func (mc *mysqlConn) Reset(ctx context.Context) (err error) { + return mc.sendSimpleCommandOK(ctx, comResetConnection) +} + func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { nv.Value, err = converter{}.ConvertValue(nv.Value) return diff --git a/connection_test.go b/connection_test.go index 440ecbff..6de595e9 100644 --- a/connection_test.go +++ b/connection_test.go @@ -129,65 +129,98 @@ func TestCheckNamedValue(t *testing.T) { } } -// TestCleanCancel tests passed context is cancelled at start. +// TestSimpleCommandOKCleanCancel tests passed context is cancelled at start. // No packet should be sent. Connection should keep current status. -func TestCleanCancel(t *testing.T) { - mc := &mysqlConn{ - closech: make(chan struct{}), - } - mc.startWatcher() - defer mc.cleanup() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - for range 3 { // Repeat same behavior - err := mc.Ping(ctx) - if err != context.Canceled { - t.Errorf("expected context.Canceled, got %#v", err) - } - - if mc.closed.Load() { - t.Error("expected mc is not closed, closed actually") - } - - if mc.watching { - t.Error("expected watching is false, but true") - } +func TestSimpleCommandOKCleanCancel(t *testing.T) { + for _, test := range []struct { + name string + funcToCall func(ctx context.Context, mc *mysqlConn) error + } { + {name: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Ping(ctx) }}, + {name: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error { return mc.Reset(ctx) }}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for range 3 { // Repeat same behavior + err := test.funcToCall(ctx, mc) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.Load() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } + }) } } -func TestPingMarkBadConnection(t *testing.T) { - nc := badConnection{err: errors.New("boom")} - mc := &mysqlConn{ - netConn: nc, - buf: newBuffer(), - maxAllowedPacket: defaultMaxAllowedPacket, - closech: make(chan struct{}), - cfg: NewConfig(), - } - - err := mc.Ping(context.Background()) - - if err != driver.ErrBadConn { - t.Errorf("expected driver.ErrBadConn, got %#v", err) +func TestSimpleCommandOKMarkBadConnection(t *testing.T) { + for _, test := range []struct { + name string + funcToCall func(mc *mysqlConn) error + } { + {name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }}, + {name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + mc := &mysqlConn{ + netConn: nc, + buf: newBuffer(), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + cfg: NewConfig(), + } + + err := test.funcToCall(mc) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } + }) } } -func TestPingErrInvalidConn(t *testing.T) { - nc := badConnection{err: errors.New("failed to write"), n: 10} - mc := &mysqlConn{ - netConn: nc, - buf: newBuffer(), - maxAllowedPacket: defaultMaxAllowedPacket, - closech: make(chan struct{}), - cfg: NewConfig(), - } - - err := mc.Ping(context.Background()) - - if err != nc.err { - t.Errorf("expected %#v, got %#v", nc.err, err) +func TestSimpleCommandOKErrInvalidConn(t *testing.T) { + for _, test := range []struct { + name string + funcToCall func(mc *mysqlConn) error + } { + {name: "Ping", funcToCall: func(mc *mysqlConn) error { return mc.Ping(context.Background()) }}, + {name: "Reset", funcToCall: func(mc *mysqlConn) error { return mc.Reset(context.Background()) }}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + mc := &mysqlConn{ + netConn: nc, + buf: newBuffer(), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + cfg: NewConfig(), + } + + err := test.funcToCall(mc) + + if err != nc.err { + t.Errorf("expected %#v, got %#v", nc.err, err) + } + }) } } diff --git a/const.go b/const.go index 6f0cdf30..85a27112 100644 --- a/const.go +++ b/const.go @@ -115,6 +115,9 @@ const ( comStmtReset comSetOption comStmtFetch + comDaemon + comBinlogDumpGTID + comResetConnection ) // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType diff --git a/driver_test.go b/driver_test.go index ec0f2877..77accbd7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2322,48 +2322,137 @@ func TestRejectReadOnly(t *testing.T) { } func TestPing(t *testing.T) { - ctx := context.Background() runTests(t, dsn, func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { dbt.fail("Ping", "Ping", err) } }) +} + +func TestSimpleCommandOK(t *testing.T) { + ctx := context.Background() + for _, test := range []struct{ + method string + query string + funcToCall func(ctx context.Context, mc *mysqlConn) error + } { + {method: "Pinger", query: "Ping", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Ping(ctx)}}, + {method: "Conn", query: "Reset", funcToCall: func(ctx context.Context, mc *mysqlConn) error {return mc.Reset(ctx)}}, + } { + test := test + t.Run(test.method+"_"+test.query, func(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + conn, err := dbt.db.Conn(ctx) + if err != nil { + dbt.fail("db", "Conn", err) + } + defer conn.Close() + + // Check that affectedRows and insertIds are cleared after each call. + conn.Raw(func(conn any) error { + c := conn.(*mysqlConn) + + // Issue a query that sets affectedRows and insertIds. + q, err := c.QueryContext(ctx, `SELECT 1`, nil) + if err != nil { + dbt.fail("Conn", "QueryContext", err) + } + if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) { + dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want) + } + if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) { + dbt.Fatalf("bad insertIds: got %v, want=%v", got, want) + } + if err := q.Close(); err != nil { + dbt.fail("Rows", "Close", err) + } + // Verify that Ping()/Reset() clears both fields. + for range 2 { + if err := test.funcToCall(ctx, c); err != nil { + // Skip Reset on servers lacking COM_RESET_CONNECTION support. + if test.query == "Reset" { + maybeSkip(t, err, 1047) // ER_UNKNOWN_COM_ERROR + maybeSkip(t, err, 1235) // ER_NOT_SUPPORTED_YET + } + dbt.fail(test.method, test.query, err) + } + if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) { + t.Errorf("bad affectedRows: got %v, want=%v", got, want) + } + if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) { + t.Errorf("bad insertIds: got %v, want=%v", got, want) + } + } + return nil + }) + }) + }) + } +} + +func TestReset(t *testing.T) { + ctx := context.Background() runTests(t, dsn, func(dbt *DBTest) { conn, err := dbt.db.Conn(ctx) if err != nil { dbt.fail("db", "Conn", err) } + defer conn.Close() - // Check that affectedRows and insertIds are cleared after each call. + // Verify that COM_RESET_CONNECTION clears session state (e.g., user variables). conn.Raw(func(conn any) error { c := conn.(*mysqlConn) - // Issue a query that sets affectedRows and insertIds. - q, err := c.Query(`SELECT 1`, nil) + _, err = c.ExecContext(ctx, "SET @a := 1", nil) if err != nil { - dbt.fail("Conn", "Query", err) + dbt.fail("Conn", "ExecContext", err) } - if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) { - dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want) + var rows driver.Rows + rows, err = c.QueryContext(ctx, "SELECT @a", nil) + if err != nil { + dbt.fail("Conn", "QueryContext", err) + } + result := []driver.Value{nil} + err = rows.Next(result) + if err != nil { + dbt.fail("Rows", "Next", err) } - if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) { - dbt.Fatalf("bad insertIds: got %v, want=%v", got, want) + err = rows.Close() + if err != nil { + dbt.fail("Rows", "Close", err) + } + if !(reflect.DeepEqual([]driver.Value{int64(1)}, result) || + reflect.DeepEqual([]driver.Value{[]byte("1")}, result)) { + dbt.Fatalf("failed to set @a to 1 with SET: got %v, want int64(1) or []byte(\"1\")", result) } - q.Close() - // Verify that Ping() clears both fields. - for range 2 { - if err := c.Ping(ctx); err != nil { - dbt.fail("Pinger", "Ping", err) - } - if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) { - t.Errorf("bad affectedRows: got %v, want=%v", got, want) - } - if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) { - t.Errorf("bad affectedRows: got %v, want=%v", got, want) - } + err = c.Reset(ctx) + if err != nil { + // Allow skipping on unsupported COM_RESET_CONNECTION + maybeSkip(t, err, 1047) // ER_UNKNOWN_COM_ERROR + maybeSkip(t, err, 1235) // ER_NOT_SUPPORTED_YET + dbt.fail("Conn", "Reset", err) + } + + rows, err = c.QueryContext(ctx, "SELECT @a", nil) + if err != nil { + dbt.fail("Conn", "QueryContext", err) + } + // Seed with a sentinel to ensure Rows.Next overwrites it with nil. + result = []driver.Value{"sentinel-non-nil"} + err = rows.Next(result) + if err != nil { + dbt.fail("Rows", "Next", err) + } + err = rows.Close() + if err != nil { + dbt.fail("Rows", "Close", err) } + if !reflect.DeepEqual([]driver.Value{nil}, result) { + dbt.Fatalf("Reset did not reset the session (@a is still set): got %v, want=%v", result, []driver.Value{nil}) + } + return nil }) })