Skip to content

Commit f97d764

Browse files
author
ffffwh
committed
refactoring QueryAble
1 parent 2e38652 commit f97d764

File tree

5 files changed

+16
-16
lines changed

5 files changed

+16
-16
lines changed

driver/mysql/applier_gtid_executed.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package mysql
22

33
import (
4+
"context"
45
gosql "database/sql"
56
"fmt"
67
"strconv"
@@ -323,7 +324,7 @@ func SelectAllGtidExecuted(db sql.QueryAble, jid string, gtidSet *mysql.MysqlGTI
323324
query := fmt.Sprintf(`SELECT source_uuid,gtid,gtid_set FROM %v.%v where job_name=?`,
324325
g.DtleSchemaName, g.GtidExecutedTableV4)
325326

326-
rows, err := db.Query(query, jid)
327+
rows, err := db.QueryContext(context.TODO(), query, jid)
327328
if err != nil {
328329
return nil, err
329330
}

driver/mysql/base/utils.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package base
88

99
import (
1010
"bytes"
11+
"context"
1112
gosql "database/sql"
1213
"fmt"
1314
"github.com/hashicorp/go-hclog"
@@ -120,7 +121,7 @@ func GetSomeSysVars(db usql.QueryAble, logger g.LoggerType) (r struct {
120121
NetWriteTimeout int
121122
}) {
122123
query := `select @@version, @@time_zone, @@lower_case_table_names, @@net_write_timeout`
123-
r.Err = db.QueryRow(query).Scan(&r.Version, &r.TimeZome, &r.LowerCaseTableNames, &r.NetWriteTimeout)
124+
r.Err = db.QueryRowContext(context.TODO(), query).Scan(&r.Version, &r.TimeZome, &r.LowerCaseTableNames, &r.NetWriteTimeout)
124125
if r.Err != nil {
125126
return
126127
}
@@ -141,7 +142,7 @@ func GetSomeSysVars(db usql.QueryAble, logger g.LoggerType) (r struct {
141142
func ShowCreateTable(db usql.QueryAble, databaseName, tableName string) (statement string, err error) {
142143
var dummy, createTableStatement string
143144
query := fmt.Sprintf(`show create table %s.%s`, umconf.EscapeName(databaseName), umconf.EscapeName(tableName))
144-
err = db.QueryRow(query).Scan(&dummy, &createTableStatement)
145+
err = db.QueryRowContext(context.TODO(), query).Scan(&dummy, &createTableStatement)
145146
return createTableStatement, err
146147
}
147148

driver/mysql/dumper.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func (d *dumper) getChunkData() (nRows int64, err error) {
167167

168168
if d.doChecksum != 0 {
169169
if d.doChecksum == 2 || (d.doChecksum == 1 && d.Iteration == 0) {
170-
row := d.db.QueryRow(fmt.Sprintf("checksum table %v.%v", d.TableSchema, d.TableName))
170+
row := d.db.QueryRowContext(d.Ctx, fmt.Sprintf("checksum table %v.%v", d.TableSchema, d.TableName))
171171
var table string
172172
var cs int64
173173
err := row.Scan(&table, &cs)
@@ -181,7 +181,7 @@ func (d *dumper) getChunkData() (nRows int64, err error) {
181181

182182
// this must be increased after building query
183183
d.Iteration += 1
184-
rows, err := d.db.Query(query)
184+
rows, err := d.db.QueryContext(context.TODO(), query)
185185
if err != nil {
186186
d.Logger.Error("error at select chunk", "query", query)
187187
return 0, errors.Wrap(err, "select chunk")
@@ -249,7 +249,7 @@ func (d *dumper) getChunkData() (nRows int64, err error) {
249249
case <-timer.C:
250250
d.Logger.Debug("resultsChannel full. waiting and ping conn")
251251
var dummy int
252-
errPing := d.db.QueryRow("select 1").Scan(&dummy)
252+
errPing := d.db.QueryRowContext(d.Ctx, "select 1").Scan(&dummy)
253253
if errPing != nil {
254254
d.Logger.Debug("ping query row got error.", "err", errPing)
255255
}

driver/mysql/extractor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,7 @@ func (e *Extractor) CountTableRows(db sql.QueryAble, table *common.Table) (int64
863863
// It only requires select privilege on target table to select its information_schema item.
864864
query = fmt.Sprintf(`select table_rows from information_schema.tables where table_schema = ? and table_name = ?`)
865865
var rowsEstimate int64
866-
err := db.QueryRow(query, table.TableSchema, table.TableName).Scan(&rowsEstimate)
866+
err := db.QueryRowContext(e.ctx, query, table.TableSchema, table.TableName).Scan(&rowsEstimate)
867867
if err != nil {
868868
e.logger.Error("error when getting estimated row number (using information_schema)", "err", err,
869869
"schema", table.TableSchema, "table", table.TableName)

driver/mysql/sql/sqlutils.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ func QueryRowsMap(db QueryAble, query string, on_row func(RowMap) error, args ..
225225
}
226226
}()
227227

228-
rows, err := db.Query(query, args...)
228+
rows, err := db.QueryContext(context.TODO(), query, args...)
229229
defer rows.Close()
230230
if err != nil && err != gosql.ErrNoRows {
231231
return err
@@ -236,23 +236,21 @@ func QueryRowsMap(db QueryAble, query string, on_row func(RowMap) error, args ..
236236

237237
// from https://github.com/golang/go/issues/14468
238238
type QueryAble interface {
239-
Exec(query string, args ...interface{}) (gosql.Result, error)
240-
Prepare(query string) (*gosql.Stmt, error)
241-
Query(query string, args ...interface{}) (*gosql.Rows, error)
242-
QueryRow(query string, args ...interface{}) *gosql.Row
239+
ExecContext(ctx context.Context, query string, args ...interface{}) (gosql.Result, error)
240+
QueryContext(ctx context.Context, query string, args ...interface{}) (*gosql.Rows, error)
243241
QueryRowContext(ctx context.Context, query string, args ...interface{}) *gosql.Row
244242
}
245243

246244
func GetServerUUID(db QueryAble) (result string, err error) {
247-
err = db.QueryRow(`SELECT @@SERVER_UUID /*dtle*/`).Scan(&result)
245+
err = db.QueryRowContext(context.TODO(), `SELECT @@SERVER_UUID /*dtle*/`).Scan(&result)
248246
if err != nil {
249247
return "", err
250248
}
251249
return result, nil
252250
}
253251

254252
func ShowMasterStatus(db QueryAble) *gosql.Row {
255-
return db.QueryRow("show master status /*dtle*/")
253+
return db.QueryRowContext(context.TODO(), "show master status /*dtle*/")
256254
}
257255

258256
// queryResultData returns a raw array of rows for a given query, optionally reading and returning column names
@@ -296,7 +294,7 @@ func ShowDatabases(db QueryAble) ([]string, error) {
296294
dbs := make([]string, 0)
297295

298296
// Get table list
299-
rows, err := db.Query("SHOW DATABASES")
297+
rows, err := db.QueryContext(context.TODO(), "SHOW DATABASES")
300298
if err != nil {
301299
return dbs, err
302300
}
@@ -341,7 +339,7 @@ func ShowTables(db QueryAble, dbName string, showType bool) (tables []*common.Ta
341339
query = fmt.Sprintf("SHOW TABLES IN %s", escapedDbName)
342340
}
343341
g.Logger.Debug("ShowTables", "query", query)
344-
rows, err := db.Query(query)
342+
rows, err := db.QueryContext(context.TODO(), query)
345343
if err != nil {
346344
return tables, err
347345
}

0 commit comments

Comments
 (0)