Skip to content

Commit 5babf50

Browse files
authored
Merge pull request vitessio#2516 from alainjobart/mysqlconn
Mysqlconn migration
2 parents 6297b03 + 007d0a1 commit 5babf50

16 files changed

+201
-56
lines changed

go/mysql/mysql.go

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ const (
3232
func init() {
3333
// This needs to be called before threads begin to spawn.
3434
C.vt_library_init()
35+
sqldb.Register("libmysqlclient", Connect)
36+
37+
// Comment this out and uncomment call to sqldb.RegisterDefault in
38+
// go/mysqlconn/sqldb_conn.go to make it the default.
3539
sqldb.RegisterDefault(Connect)
3640
}
3741

go/mysqlconn/client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func Connect(ctx context.Context, params *sqldb.ConnParams) (*Conn, error) {
7272
}
7373
if err != nil {
7474
status <- connectResult{
75-
err: sqldb.NewSQLError(CRConnectionError, "", "net.Dial(%v,%v) failed: %v", netProto, addr, err),
75+
err: sqldb.NewSQLError(CRConnHostError, "", "net.Dial(%v,%v) failed: %v", netProto, addr, err),
7676
}
7777
return
7878
}

go/mysqlconn/client_test.go

+63-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package mysqlconn
22

33
import (
4+
"fmt"
45
"io/ioutil"
56
"net"
67
"os"
@@ -103,7 +104,58 @@ func TestConnectTimeout(t *testing.T) {
103104
ctx = context.Background()
104105
_, err = Connect(ctx, params)
105106
os.Remove(name)
106-
assertSQLError(t, err, CRConnectionError, SSSignalException, "connection refused")
107+
assertSQLError(t, err, CRConnHostError, SSSignalException, "connection refused")
108+
}
109+
110+
// testKillWithRealDatabase opens a connection, issues a command that
111+
// will sleep for a few seconds, waits a bit for MySQL to start
112+
// executing it, then kills the connection (using another
113+
// connection). We make sure we get the right error code.
114+
func testKillWithRealDatabase(t *testing.T, params *sqldb.ConnParams) {
115+
ctx := context.Background()
116+
conn, err := Connect(ctx, params)
117+
if err != nil {
118+
t.Fatal(err)
119+
}
120+
121+
errChan := make(chan error)
122+
go func() {
123+
_, err = conn.ExecuteFetch("select sleep(10) from dual", 1000, false)
124+
errChan <- err
125+
close(errChan)
126+
}()
127+
128+
killConn, err := Connect(ctx, params)
129+
if err != nil {
130+
t.Fatal(err)
131+
}
132+
defer killConn.Close()
133+
134+
if _, err := killConn.ExecuteFetch(fmt.Sprintf("kill %v", conn.ConnectionID), 1000, false); err != nil {
135+
t.Fatalf("Kill(%v) failed: %v", conn.ConnectionID, err)
136+
}
137+
138+
err = <-errChan
139+
assertSQLError(t, err, CRServerLost, SSSignalException, "EOF")
140+
}
141+
142+
// testDupEntryWithRealDatabase tests a duplicate key is properly raised.
143+
func testDupEntryWithRealDatabase(t *testing.T, params *sqldb.ConnParams) {
144+
ctx := context.Background()
145+
conn, err := Connect(ctx, params)
146+
if err != nil {
147+
t.Fatal(err)
148+
}
149+
defer conn.Close()
150+
151+
if _, err := conn.ExecuteFetch("create table dup_entry(id int, name int, primary key(id), unique index(name))", 0, false); err != nil {
152+
t.Fatalf("create table failed: %v", err)
153+
}
154+
if _, err := conn.ExecuteFetch("insert into dup_entry(id, name) values(1, 10)", 0, false); err != nil {
155+
t.Fatalf("first insert failed: %v", err)
156+
}
157+
_, err = conn.ExecuteFetch("insert into dup_entry(id, name) values(2, 10)", 0, false)
158+
assertSQLError(t, err, ERDupEntry, SSDupKey, "Duplicate entry")
107159
}
108160

109161
// TestWithRealDatabase runs a real MySQL database, and runs all kinds
@@ -127,6 +179,16 @@ func TestWithRealDatabase(t *testing.T) {
127179
t.Error(err)
128180
}
129181

182+
// Kill tests the query part of the API.
183+
t.Run("Kill", func(t *testing.T) {
184+
testKillWithRealDatabase(t, &params)
185+
})
186+
187+
// DupEntry tests a duplicate key returns the right error.
188+
t.Run("DupEntry", func(t *testing.T) {
189+
testDupEntryWithRealDatabase(t, &params)
190+
})
191+
130192
// Queries tests the query part of the API.
131193
t.Run("Queries", func(t *testing.T) {
132194
testQueriesWithRealDatabase(t, &params)

go/mysqlconn/conn.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ func (c *Conn) readOnePacket() ([]byte, error) {
9999
var header [4]byte
100100

101101
if _, err := io.ReadFull(c.reader, header[:]); err != nil {
102-
return nil, sqldb.NewSQLError(CRServerLost, SSUnknownComError, "io.ReadFull(header size) failed: %v", err)
102+
return nil, sqldb.NewSQLError(CRServerLost, SSSignalException, "io.ReadFull(header size) failed: %v", err)
103103
}
104104

105105
sequence := uint8(header[3])
106106
if sequence != c.sequence {
107-
return nil, sqldb.NewSQLError(CRServerLost, SSUnknownComError, "invalid sequence, expected %v got %v", c.sequence, sequence)
107+
return nil, sqldb.NewSQLError(CRServerLost, SSSignalException, "invalid sequence, expected %v got %v", c.sequence, sequence)
108108
}
109109

110110
c.sequence++
@@ -118,7 +118,7 @@ func (c *Conn) readOnePacket() ([]byte, error) {
118118

119119
data := make([]byte, length)
120120
if _, err := io.ReadFull(c.reader, data); err != nil {
121-
return nil, sqldb.NewSQLError(CRServerLost, SSUnknownComError, "io.ReadFull(packet body of length %v) failed: %v", length, err)
121+
return nil, sqldb.NewSQLError(CRServerLost, SSSignalException, "io.ReadFull(packet body of length %v) failed: %v", length, err)
122122
}
123123
return data, nil
124124
}
@@ -191,16 +191,16 @@ func (c *Conn) writePacket(data []byte) error {
191191
header[2] = byte(packetLength >> 16)
192192
header[3] = c.sequence
193193
if n, err := c.writer.Write(header[:]); err != nil {
194-
return sqldb.NewSQLError(CRServerLost, SSUnknownComError, "Write(header) failed: %v", err)
194+
return sqldb.NewSQLError(CRServerLost, SSSignalException, "Write(header) failed: %v", err)
195195
} else if n != 4 {
196-
return sqldb.NewSQLError(CRServerLost, SSUnknownComError, "Write(header) returned a short write: %v < 4", n)
196+
return sqldb.NewSQLError(CRServerLost, SSSignalException, "Write(header) returned a short write: %v < 4", n)
197197
}
198198

199199
// Write the body.
200200
if n, err := c.writer.Write(data[index : index+packetLength]); err != nil {
201-
return sqldb.NewSQLError(CRServerLost, SSUnknownComError, "Write(packet) failed: %v", err)
201+
return sqldb.NewSQLError(CRServerLost, SSSignalException, "Write(packet) failed: %v", err)
202202
} else if n != packetLength {
203-
return sqldb.NewSQLError(CRServerLost, SSUnknownComError, "Write(packet) returned a short write: %v < %v", n, packetLength)
203+
return sqldb.NewSQLError(CRServerLost, SSSignalException, "Write(packet) returned a short write: %v < %v", n, packetLength)
204204
}
205205

206206
// Update our state.
@@ -216,9 +216,9 @@ func (c *Conn) writePacket(data []byte) error {
216216
header[2] = 0
217217
header[3] = c.sequence
218218
if n, err := c.writer.Write(header[:]); err != nil {
219-
return sqldb.NewSQLError(CRServerLost, SSUnknownComError, "Write(empty header) failed: %v", err)
219+
return sqldb.NewSQLError(CRServerLost, SSSignalException, "Write(empty header) failed: %v", err)
220220
} else if n != 4 {
221-
return sqldb.NewSQLError(CRServerLost, SSUnknownComError, "Write(empty header) returned a short write: %v < 4", n)
221+
return sqldb.NewSQLError(CRServerLost, SSSignalException, "Write(empty header) returned a short write: %v < 4", n)
222222
}
223223
c.sequence++
224224
}
@@ -230,7 +230,7 @@ func (c *Conn) writePacket(data []byte) error {
230230

231231
func (c *Conn) flush() error {
232232
if err := c.writer.Flush(); err != nil {
233-
return sqldb.NewSQLError(CRServerLost, SSUnknownComError, "Flush() failed: %v", err)
233+
return sqldb.NewSQLError(CRServerLost, SSSignalException, "Flush() failed: %v", err)
234234
}
235235
return nil
236236
}

go/mysqlconn/constants.go

+15
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ const (
172172
// ERUnknownComError is ER_UNKNOWN_COM_ERROR
173173
ERUnknownComError = 1047
174174

175+
// ERDupEntry is ER_DUP_ENTRY
176+
ERDupEntry = 1062
177+
175178
// ERUnknownError is ER_UNKNOWN_ERROR
176179
ERUnknownError = 1105
177180

@@ -186,6 +189,9 @@ const (
186189
// SSSignalException is ER_SIGNAL_EXCEPTION
187190
SSSignalException = "HY000"
188191

192+
// SSDupKey is ER_DUP_KEY
193+
SSDupKey = "23000"
194+
189195
// SSAccessDeniedError is ER_ACCESS_DENIED_ERROR
190196
SSAccessDeniedError = "28000"
191197

@@ -262,3 +268,12 @@ var CharacterSetMap = map[string]uint8{
262268
"cp932": 95,
263269
"eucjpms": 97,
264270
}
271+
272+
// IsNum returns true if a MySQL type is a numeric value.
273+
// It is the same as IS_NUM defined in mysql.h.
274+
//
275+
// FIXME(alainjobart) This needs to use the constants in
276+
// replication/constants.go, so we are using numerical values here.
277+
func IsNum(typ uint8) bool {
278+
return ((typ <= 9 /* MYSQL_TYPE_INT24 */ && typ != 7 /* MYSQL_TYPE_TIMESTAMP */) || typ == 13 /* MYSQL_TYPE_YEAR */ || typ == 246 /* MYSQL_TYPE_NEWDECIMAL */)
279+
}

go/mysqlconn/doc.go

+16
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,20 @@ We should add the following protections for the server:
8888
Should start during initial handshake, maybe have a shorter value during
8989
handshake.
9090
91+
--
92+
NUM_FLAG flag:
93+
94+
It is added by the C client library if the field is numerical.
95+
96+
if (IS_NUM(client_field->type))
97+
client_field->flags|= NUM_FLAG;
98+
99+
This is somewhat useless. Also, that flag overlaps with GROUP_FLAG
100+
(which seems to be used by the server only for temporary tables in
101+
some cases, so it's not a big deal).
102+
103+
But eventually, we probably want to remove it entirely, as it is not
104+
transmitted over the wire. For now, we keep it for backward
105+
compatibility with the C client.
106+
91107
*/

go/mysqlconn/query.go

+11
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,16 @@ func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error {
117117
// a Field without that data, so we don't return the flags.
118118
if field.ColumnLength != 0 || field.Charset != 0 {
119119
field.Flags = uint32(flags)
120+
121+
// FIXME(alainjobart): This is something the MySQL
122+
// client library does: If the type is numerical, it
123+
// adds a NUM_FLAG to the flags. We're doing it here
124+
// only to be compatible with the C library. Once
125+
// we're not using that library any more, we'll remove this.
126+
// See doc.go.
127+
if IsNum(t) {
128+
field.Flags |= uint32(querypb.MySqlFlag_NUM_FLAG)
129+
}
120130
}
121131

122132
return nil
@@ -297,6 +307,7 @@ func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (*sqltyp
297307
if !wantfields {
298308
result.Fields = nil
299309
}
310+
result.RowsAffected = uint64(len(result.Rows))
300311
return result, nil
301312
case ErrPacket:
302313
// Error packet.

go/mysqlconn/query_test.go

+31-9
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ func TestQueries(t *testing.T) {
7777
sqltypes.NULL,
7878
},
7979
},
80+
RowsAffected: 2,
8081
})
8182

8283
// Typicall Select with TYPE_AND_NAME.
@@ -179,6 +180,7 @@ func TestQueries(t *testing.T) {
179180
sqltypes.NULL,
180181
},
181182
},
183+
RowsAffected: 2,
182184
})
183185

184186
// Typicall Select with TYPE_AND_NAME.
@@ -198,6 +200,7 @@ func TestQueries(t *testing.T) {
198200
sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
199201
},
200202
},
203+
RowsAffected: 2,
201204
})
202205

203206
// Typicall Select with TYPE_ONLY.
@@ -215,6 +218,7 @@ func TestQueries(t *testing.T) {
215218
sqltypes.MakeTrusted(querypb.Type_INT64, []byte("20")),
216219
},
217220
},
221+
RowsAffected: 2,
218222
})
219223

220224
// Typicall Select with ALL.
@@ -230,7 +234,10 @@ func TestQueries(t *testing.T) {
230234
ColumnLength: 0x80020304,
231235
Charset: 0x1234,
232236
Decimals: 36,
233-
Flags: 16387, // NOT_NULL_FLAG, PRI_KEY_FLAG, PART_KEY_FLAG
237+
Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG |
238+
querypb.MySqlFlag_PRI_KEY_FLAG |
239+
querypb.MySqlFlag_PART_KEY_FLAG |
240+
querypb.MySqlFlag_NUM_FLAG),
234241
},
235242
},
236243
Rows: [][]sqltypes.Value{
@@ -244,6 +251,7 @@ func TestQueries(t *testing.T) {
244251
sqltypes.MakeTrusted(querypb.Type_INT64, []byte("30")),
245252
},
246253
},
254+
RowsAffected: 3,
247255
})
248256
}
249257

@@ -313,8 +321,8 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *
313321
if !reflect.DeepEqual(got, &expected) {
314322
for i, f := range got.Fields {
315323
if !reflect.DeepEqual(f, expected.Fields[i]) {
316-
t.Errorf("Got field(%v) = %v", i, f)
317-
t.Errorf("Expected field(%v) = %v", i, expected.Fields[i])
324+
t.Logf("Got field(%v) = %v", i, f)
325+
t.Logf("Expected field(%v) = %v", i, expected.Fields[i])
318326
}
319327
}
320328
t.Fatalf("ExecuteFetch(wantfields=%v) returned:\n%v\nBut was expecting:\n%v", wantfields, got, expected)
@@ -409,12 +417,16 @@ func testQueriesWithRealDatabase(t *testing.T, params *sqldb.ConnParams) {
409417
}
410418

411419
// Try a simple DDL.
412-
if _, err := conn.ExecuteFetch("create table a(id int, name varchar(128), primary key(id))", 0, false); err != nil {
420+
result, err := conn.ExecuteFetch("create table a(id int, name varchar(128), primary key(id))", 0, false)
421+
if err != nil {
413422
t.Fatalf("create table failed: %v", err)
414423
}
424+
if result.RowsAffected != 0 {
425+
t.Errorf("create table returned RowsAffected %v, was expecting 0", result.RowsAffected)
426+
}
415427

416428
// Try a simple insert.
417-
result, err := conn.ExecuteFetch("insert into a(id, name) values(10, 'nice name')", 1000, true)
429+
result, err = conn.ExecuteFetch("insert into a(id, name) values(10, 'nice name')", 1000, true)
418430
if err != nil {
419431
t.Fatalf("insert failed: %v", err)
420432
}
@@ -440,7 +452,8 @@ func testQueriesWithRealDatabase(t *testing.T, params *sqldb.ConnParams) {
440452
Charset: CharacterSetBinary,
441453
Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG |
442454
querypb.MySqlFlag_PRI_KEY_FLAG |
443-
querypb.MySqlFlag_PART_KEY_FLAG),
455+
querypb.MySqlFlag_PART_KEY_FLAG |
456+
querypb.MySqlFlag_NUM_FLAG),
444457
},
445458
{
446459
Name: "name",
@@ -459,6 +472,7 @@ func testQueriesWithRealDatabase(t *testing.T, params *sqldb.ConnParams) {
459472
sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
460473
},
461474
},
475+
RowsAffected: 1,
462476
}
463477
if !reflect.DeepEqual(result, expectedResult) {
464478
// MySQL 5.7 is adding the NO_DEFAULT_VALUE_FLAG to Flags.
@@ -470,10 +484,13 @@ func testQueriesWithRealDatabase(t *testing.T, params *sqldb.ConnParams) {
470484

471485
// Insert a few rows.
472486
for i := 0; i < 100; i++ {
473-
_, err := conn.ExecuteFetch(fmt.Sprintf("insert into a(id, name) values(%v, 'nice name %v')", 1000+i, i), 1000, true)
487+
result, err := conn.ExecuteFetch(fmt.Sprintf("insert into a(id, name) values(%v, 'nice name %v')", 1000+i, i), 1000, true)
474488
if err != nil {
475489
t.Fatalf("ExecuteFetch(%v) failed: %v", i, err)
476490
}
491+
if result.RowsAffected != 1 {
492+
t.Errorf("insert into returned RowsAffected %v, was expecting 1", result.RowsAffected)
493+
}
477494
}
478495

479496
// And use a streaming query to read them back.
@@ -482,9 +499,13 @@ func testQueriesWithRealDatabase(t *testing.T, params *sqldb.ConnParams) {
482499
readRowsUsingStream(t, conn, 101)
483500

484501
// And drop the table.
485-
if _, err := conn.ExecuteFetch("drop table a", 0, false); err != nil {
502+
result, err = conn.ExecuteFetch("drop table a", 0, false)
503+
if err != nil {
486504
t.Fatalf("drop table failed: %v", err)
487505
}
506+
if result.RowsAffected != 0 {
507+
t.Errorf("insert into returned RowsAffected %v, was expecting 0", result.RowsAffected)
508+
}
488509
}
489510

490511
func readRowsUsingStream(t *testing.T, conn *Conn, expectedCount int) {
@@ -506,7 +527,8 @@ func readRowsUsingStream(t *testing.T, conn *Conn, expectedCount int) {
506527
Charset: CharacterSetBinary,
507528
Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG |
508529
querypb.MySqlFlag_PRI_KEY_FLAG |
509-
querypb.MySqlFlag_PART_KEY_FLAG),
530+
querypb.MySqlFlag_PART_KEY_FLAG |
531+
querypb.MySqlFlag_NUM_FLAG),
510532
},
511533
{
512534
Name: "name",

0 commit comments

Comments
 (0)