Skip to content

Commit ec666ff

Browse files
author
Phil Bayfield
committed
improve error handling, client cleanup, added transactions, added new test
1 parent 204e624 commit ec666ff

File tree

5 files changed

+208
-9
lines changed

5 files changed

+208
-9
lines changed

error.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ type Errno int
99

1010
const (
1111
CR_UNKNOWN_ERROR Errno = 2000
12+
CR_SOCKET_CREATE_ERROR Errno = 2001
1213
CR_CONNECTION_ERROR Errno = 2002
1314
CR_CONN_HOST_ERROR Errno = 2003
15+
CR_IPSOCK_ERROR Errno = 2004
16+
CR_UNKNOWN_HOST Errno = 2005
1417
CR_SERVER_GONE_ERROR Errno = 2006
1518
CR_SERVER_HANDSHAKE_ERR Errno = 2012
1619
CR_SERVER_LOST Errno = 2013
@@ -24,8 +27,11 @@ type Error string
2427

2528
const (
2629
CR_UNKNOWN_ERROR_STR Error = "Unknown MySQL error"
30+
CR_SOCKET_CREATE_ERROR_STR Error = "Can't create UNIX socket (%d)"
2731
CR_CONNECTION_ERROR_STR Error = "Can't connect to local MySQL server through socket '%s'"
2832
CR_CONN_HOST_ERROR_STR Error = "Can't connect to MySQL server on '%s' (%d)"
33+
CR_IPSOCK_ERROR_STR Error = "Can't create TCP/IP socket (%d)"
34+
CR_UNKNOWN_HOST_STR Error = "Uknown MySQL server host '%s' (%d)"
2935
CR_SERVER_GONE_ERROR_STR Error = "MySQL server has gone away"
3036
CR_SERVER_HANDSHAKE_ERR_STR Error = "Error in server handshake"
3137
CR_SERVER_LOST_STR Error = "Lost connection to MySQL server during query"

mysql.go

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ func (c *Client) Close() (err os.Error) {
163163
c.log(1, "=== Begin close ===")
164164
// Check connection
165165
if !c.checkConn() {
166+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
166167
err = os.NewError("Must be connected to do this")
167168
return
168169
}
@@ -188,6 +189,7 @@ func (c *Client) ChangeDb(dbname string) (err os.Error) {
188189
c.log(1, "=== Begin change db to '%s' ===", dbname)
189190
// Pre-run checks
190191
if !c.checkConn() || c.checkResult() {
192+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
191193
err = os.NewError("Must be connected and not in a result set")
192194
return
193195
}
@@ -210,6 +212,7 @@ func (c *Client) Query(sql string) (err os.Error) {
210212
c.log(1, "=== Begin query '%s' ===", sql)
211213
// Pre-run checks
212214
if !c.checkConn() || c.checkResult() {
215+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
213216
err = os.NewError("Must be connected and not in a result set")
214217
return
215218
}
@@ -232,11 +235,13 @@ func (c *Client) StoreResult() (result *Result, err os.Error) {
232235
c.log(1, "=== Begin store result ===")
233236
// Check result
234237
if !c.checkResult() {
238+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
235239
err = os.NewError("A result is required to do this")
236240
return
237241
}
238242
// Check if result already used/stored
239243
if c.result.mode != RESULT_UNUSED {
244+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
240245
err = os.NewError("This result has already been used or stored")
241246
return
242247
}
@@ -263,11 +268,13 @@ func (c *Client) UseResult() (result *Result, err os.Error) {
263268
c.log(1, "=== Begin use result ===")
264269
// Check result
265270
if !c.checkResult() {
271+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
266272
err = os.NewError("A result is required to do this")
267273
return
268274
}
269275
// Check if result already used/stored
270276
if c.result.mode != RESULT_UNUSED {
277+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
271278
err = os.NewError("This result has already been used or stored")
272279
return
273280
}
@@ -288,6 +295,7 @@ func (c *Client) FreeResult() (err os.Error) {
288295
c.log(1, "=== Begin free result ===")
289296
// Check result
290297
if !c.checkResult() {
298+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
291299
err = os.NewError("A result is required to do this")
292300
return
293301
}
@@ -327,16 +335,18 @@ func (c *Client) MoreResults() bool {
327335
}
328336

329337
// Move to the next available result
330-
func (c *Client) NextResult() (err os.Error) {
338+
func (c *Client) NextResult() (more bool, err os.Error) {
331339
// Log next result
332340
c.log(1, "=== Begin next result ===")
333341
// Pre-run checks
334342
if !c.checkConn() || c.checkResult() {
343+
c.error(CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR)
335344
err = os.NewError("Must be connected and not in a result set")
336345
return
337346
}
338-
if !c.MoreResults() {
339-
err = os.NewError("No more results available")
347+
// Check for more results
348+
more = c.MoreResults()
349+
if !more {
340350
return
341351
}
342352
// Read result from server
@@ -347,6 +357,9 @@ func (c *Client) NextResult() (err os.Error) {
347357

348358
// Set autocommit
349359
func (c *Client) SetAutoCommit(state bool) (err os.Error) {
360+
// Log set autocommit
361+
c.log(1, "=== Begin set autocommit ===")
362+
// Use set autocommit query
350363
sql := "set autocommit="
351364
if state {
352365
sql += "1"
@@ -358,16 +371,25 @@ func (c *Client) SetAutoCommit(state bool) (err os.Error) {
358371

359372
// Start a transaction
360373
func (c *Client) Start() (err os.Error) {
374+
// Log start transaction
375+
c.log(1, "=== Begin start transaction ===")
376+
// Use start transaction query
361377
return c.Query("start transaction")
362378
}
363379

364380
// Commit a transaction
365381
func (c *Client) Commit() (err os.Error) {
382+
// Log commit
383+
c.log(1, "=== Begin commit ===")
384+
// Use commit query
366385
return c.Query("commit")
367386
}
368387

369388
// Rollback a transaction
370389
func (c *Client) Rollback() (err os.Error) {
390+
// Log rollback
391+
c.log(1, "=== Begin rollback ===")
392+
// Use rollback query
371393
return c.Query("rollback")
372394
}
373395

@@ -393,11 +415,26 @@ func (c *Client) Escape(s string) (esc string) {
393415

394416
// Initialise and prepare a new statement
395417
func (c *Client) Prepare(sql string) (stmt *Statement, err os.Error) {
418+
// Initialise a new statement
419+
stmt, err = c.InitStmt()
420+
if err != nil {
421+
return
422+
}
423+
// Prepare statement
424+
err = stmt.Prepare(sql)
396425
return
397426
}
398427

399428
// Initialise a new statment
400-
func (c *Client) StmtInit() (stmt *Statement, err os.Error) {
429+
func (c *Client) InitStmt() (stmt *Statement, err os.Error) {
430+
// Check connection
431+
if !c.checkConn() {
432+
err = os.NewError("Must be connected to do this")
433+
return
434+
}
435+
// Create new statement
436+
stmt = new(Statement)
437+
stmt.c = c
401438
return
402439
}
403440

@@ -478,6 +515,9 @@ func (c *Client) reset() {
478515
c.Errno = 0
479516
c.Error = ""
480517
c.sequence = 0
518+
c.AffectedRows = 0
519+
c.LastInsertId = 0
520+
c.Warnings = 0
481521
c.result = nil
482522
}
483523

mysql_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
package mysql
77

88
import (
9+
"fmt"
910
"os"
11+
"rand"
12+
"strconv"
1013
"testing"
1114
)
1215

@@ -18,6 +21,8 @@ const (
1821
// create user gomysql_test@localhost identified by 'abc123';
1922
// grant all privileges on gomysql_test.* to gomysql_test@localhost;
2023
// grant all privileges on gomysql_test2.* to gomysql_test@localhost;
24+
25+
// Testing settings
2126
TEST_HOST = "localhost"
2227
TEST_PORT = "3306"
2328
TEST_SOCK = "/var/run/mysqld/mysqld.sock"
@@ -28,6 +33,17 @@ const (
2833
TEST_DBNAME2 = "gomysql_test2" // This is a privileged database used to test changedb etc
2934
TEST_DBNAMEUP = "gomysql_test3" // This is an unprivileged database
3035
TEST_DBNAMEBAD = "gomysql_bad" // This is a nonexistant database
36+
37+
// Simple table queries
38+
CREATE_SIMPLE = "CREATE TABLE `simple` (`id` SERIAL NOT NULL, `number` BIGINT NOT NULL, `string` VARCHAR(32) NOT NULL, `text` TEXT NOT NULL, `datetime` DATETIME NOT NULL) ENGINE = InnoDB CHARACTER SET utf8 COLLATE utf8_unicode_ci COMMENT = 'GoMySQL Test Suite Simple Table';"
39+
SELECT_SIMPLE = "SELECT * FROM simple"
40+
INSERT_SIMPLE = "INSERT INTO simple VALUES (null, %d, '%s', '%s', NOW())"
41+
UPDATE_SIMPLE = "UPDATE simple SET `text` = '%s', `datetime` = NOW() WHERE id = %d"
42+
DROP_SIMPLE = "DROP TABLE `simple`"
43+
44+
// All types table queries
45+
CREATE_ALLTYPES = "CREATE TABLE `all_types` (`id` SERIAL NOT NULL, `tiny_int` TINYINT NOT NULL, `tiny_uint` TINYINT UNSIGNED NOT NULL, `small_int` SMALLINT NOT NULL, `small_uint` SMALLINT UNSIGNED NOT NULL, `medium_int` MEDIUMINT NOT NULL, `medium_uint` MEDIUMINT UNSIGNED NOT NULL, `int` INT NOT NULL, `uint` INT UNSIGNED NOT NULL, `big_int` BIGINT NOT NULL, `big_uint` BIGINT UNSIGNED NOT NULL, `decimal` DECIMAL(10,4) NOT NULL, `float` FLOAT NOT NULL, `double` DOUBLE NOT NULL, `real` REAL NOT NULL, `bit` BIT(32) NOT NULL, `boolean` BOOLEAN NOT NULL, `date` DATE NOT NULL, `datetime` DATETIME NOT NULL, `timestamp` TIMESTAMP NOT NULL, `time` TIME NOT NULL, `year` YEAR NOT NULL, `char` CHAR(32) NOT NULL, `varchar` VARCHAR(32) NOT NULL, `tiny_text` TINYTEXT NOT NULL, `text` TEXT NOT NULL, `medium_text` MEDIUMTEXT NOT NULL, `long_text` LONGTEXT NOT NULL, `binary` BINARY(32) NOT NULL, `var_binary` VARBINARY(32) NOT NULL, `tiny_blob` TINYBLOB NOT NULL, `medium_blob` MEDIUMBLOB NOT NULL, `blob` BLOB NOT NULL, `long_blob` LONGBLOB NOT NULL, `enum` ENUM('a','b','c','d','e') NOT NULL, `set` SET('a','b','c','d','e') NOT NULL, `geometry` GEOMETRY NOT NULL) ENGINE = InnoDB CHARACTER SET utf8 COLLATE utf8_unicode_ci COMMENT = 'GoMySQL Test Suite All Types Table'"
46+
DROP_ALLTYPES = "DROP TABLE `all_types`"
3147
)
3248

3349
var (
@@ -104,6 +120,122 @@ func TestDialUnixBadPass(t *testing.T) {
104120
}
105121
}
106122

123+
// Test queries on a simple table (create database, select, insert, update, drop database)
124+
func TestSimple(t *testing.T) {
125+
t.Logf("Running simple table tests")
126+
db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
127+
if err != nil {
128+
t.Logf("Error #%d: %s", db.Errno, db.Error)
129+
t.Fail()
130+
}
131+
t.Logf("Create table")
132+
err = db.Query(CREATE_SIMPLE)
133+
if err != nil {
134+
t.Logf("Error #%d: %s", db.Errno, db.Error)
135+
t.Fail()
136+
}
137+
t.Logf("Insert 1000 records")
138+
rowMap := make(map[uint64][]string)
139+
for i := 0; i < 1000; i++ {
140+
num, str1, str2 := rand.Int(), randString(32), randString(128)
141+
err = db.Query(fmt.Sprintf(INSERT_SIMPLE, num, str1, str2))
142+
if err != nil {
143+
t.Logf("Error #%d: %s", db.Errno, db.Error)
144+
t.Fail()
145+
}
146+
row := []string{fmt.Sprintf("%d", num), str1, str2}
147+
rowMap[db.LastInsertId] = row
148+
}
149+
t.Logf("Select inserted data")
150+
err = db.Query(SELECT_SIMPLE)
151+
if err != nil {
152+
t.Logf("Error #%d: %s", db.Errno, db.Error)
153+
t.Fail()
154+
}
155+
t.Logf("Use result")
156+
res, err := db.UseResult()
157+
if err != nil {
158+
t.Logf("Error #%d: %s", db.Errno, db.Error)
159+
t.Fail()
160+
}
161+
t.Logf("Validate inserted data")
162+
for {
163+
row := res.FetchRow()
164+
if row == nil {
165+
break
166+
}
167+
id, _ := strconv.Atoui64(row[0].(string))
168+
num, str1, str2 := row[1].(string), row[2].(string), row[3].(string)
169+
if rowMap[id][0] != num || rowMap[id][1] != str1 || rowMap[id][2] != str2 {
170+
t.Logf("String from database doesn't match local string")
171+
t.Fail()
172+
}
173+
}
174+
t.Logf("Free result")
175+
err = res.Free()
176+
if err != nil {
177+
t.Logf("Error #%d: %s", db.Errno, db.Error)
178+
t.Fail()
179+
}
180+
t.Logf("Update some records")
181+
for i := uint64(0); i < 1000; i += 5 {
182+
rowMap[i+1][2] = randString(256)
183+
err = db.Query(fmt.Sprintf(UPDATE_SIMPLE, rowMap[i+1][2], i+1))
184+
if err != nil {
185+
t.Logf("Error #%d: %s", db.Errno, db.Error)
186+
t.Fail()
187+
}
188+
if db.AffectedRows != 1 {
189+
t.Logf("Expected 1 effected row but got %d", db.AffectedRows)
190+
t.Fail()
191+
}
192+
}
193+
t.Logf("Select updated data")
194+
err = db.Query(SELECT_SIMPLE)
195+
if err != nil {
196+
t.Logf("Error #%d: %s", db.Errno, db.Error)
197+
t.Fail()
198+
}
199+
t.Logf("Store result")
200+
res, err = db.StoreResult()
201+
if err != nil {
202+
t.Logf("Error #%d: %s", db.Errno, db.Error)
203+
t.Fail()
204+
}
205+
t.Logf("Validate updated data")
206+
for {
207+
row := res.FetchRow()
208+
if row == nil {
209+
break
210+
}
211+
id, _ := strconv.Atoui64(row[0].(string))
212+
num, str1, str2 := row[1].(string), row[2].(string), row[3].(string)
213+
if rowMap[id][0] != num || rowMap[id][1] != str1 || rowMap[id][2] != str2 {
214+
t.Logf("%#v %#v", rowMap[id], row)
215+
t.Logf("String from database doesn't match local string")
216+
t.Fail()
217+
}
218+
}
219+
t.Logf("Free result")
220+
err = res.Free()
221+
if err != nil {
222+
t.Logf("Error #%d: %s", db.Errno, db.Error)
223+
t.Fail()
224+
}
225+
226+
t.Logf("Drop table")
227+
err = db.Query(DROP_SIMPLE)
228+
if err != nil {
229+
t.Logf("Error #%d: %s", db.Errno, db.Error)
230+
t.Fail()
231+
}
232+
err = db.Close()
233+
if err != nil {
234+
t.Logf("Error #%d: %s", db.Errno, db.Error)
235+
t.Fail()
236+
}
237+
}
238+
107239
// Benchmark connect/handshake via TCP
108240
func BenchmarkDialTCP(b *testing.B) {
109241
for i := 0; i < b.N; i++ {
@@ -117,3 +249,14 @@ func BenchmarkDialUnix(b *testing.B) {
117249
DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
118250
}
119251
}
252+
253+
// Create a random string
254+
func randString(strLen int) (randStr string) {
255+
strChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
256+
for i := 0; i < strLen; i++ {
257+
randUint := rand.Uint32()
258+
pos := randUint % uint32(len(strChars))
259+
randStr += string(strChars[pos])
260+
}
261+
return
262+
}

password.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@ import (
1313
const SCRAMBLE_LENGTH_323 = 8
1414

1515
// Random struct, see libmysql/password.c
16-
type rand struct {
16+
type randStruct struct {
1717
maxValue uint32
1818
maxValueDbl float64
1919
seed1 uint32
2020
seed2 uint32
2121
}
2222

2323
// Initialise rand struct, see libmysql/password.c
24-
func randominit(seed1, seed2 uint32) *rand {
25-
return &rand{
24+
func randominit(seed1, seed2 uint32) *randStruct {
25+
return &randStruct{
2626
maxValue: 0x3FFFFFFF,
2727
maxValueDbl: 0x3FFFFFFF,
2828
seed1: seed1 % 0x3FFFFFFF,
@@ -31,7 +31,7 @@ func randominit(seed1, seed2 uint32) *rand {
3131
}
3232

3333
// Generate a random number, see libmysql/password.c
34-
func (r *rand) myRnd() float64 {
34+
func (r *randStruct) myRnd() float64 {
3535
r.seed1 = (r.seed1*3 + r.seed2) % r.maxValue
3636
r.seed2 = (r.seed1 + r.seed2 + 33) % r.maxValue
3737
return float64(r.seed1) / r.maxValueDbl

0 commit comments

Comments
 (0)