Skip to content

Commit 87c9565

Browse files
author
ffffwh
committed
set wait_timeout #1052-2
1 parent f97d764 commit 87c9565

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

driver/mysql/extractor.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,6 +1284,10 @@ func (e *Extractor) mysqlDump() (retErr error) {
12841284
}
12851285
e.logger.Debug("got gtid")
12861286
}
1287+
1288+
if originVal, err := sql.GetSetWaitTimeout(tx); err != nil {
1289+
e.logger.Warn("GetSetWaitTimeout. failed. error ignored", "err", err, "originVal", originVal)
1290+
}
12871291
step++
12881292

12891293
// ------

driver/mysql/sql/sqlutils.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
)
2727

2828
const (
29-
ConnMaxLifetime = 300 * time.Second // #376
29+
WaitTimeout = 300 * 60 // #376
3030
)
3131

3232
// RowMap represents one row in a result set. Its objective is to allow
@@ -140,7 +140,7 @@ func CreateDB(mysql_uri string) (*gosql.DB, error) {
140140
if err != nil {
141141
return nil, err
142142
}
143-
db.SetConnMaxLifetime(ConnMaxLifetime)
143+
db.SetConnMaxLifetime(WaitTimeout * time.Second)
144144

145145
return db, nil
146146
}
@@ -153,6 +153,11 @@ func CreateConns(ctx context.Context, db *gosql.DB, count int) ([]*Conn, error)
153153
return nil, err
154154
}
155155

156+
originVal, err := GetSetWaitTimeout(conn)
157+
if err != nil {
158+
g.Logger.Warn("CreateConns. GetSetWaitTimeout. failed. error ignored", "err", err, "originVal", originVal)
159+
}
160+
156161
_, err = conn.ExecContext(ctx, "SET @@session.foreign_key_checks = 0")
157162
if err != nil {
158163
return nil, err
@@ -405,3 +410,19 @@ func CloseConns(dbs ...*Conn) error {
405410
}
406411
return nil
407412
}
413+
414+
// GetSetWaitTimeout sets session wait_timeout to `WaitTimeout` unless it is larger.
415+
func GetSetWaitTimeout(db QueryAble) (originVal int, err error) {
416+
row := db.QueryRowContext(context.TODO(), "select @@wait_timeout")
417+
err = row.Scan(&originVal)
418+
if err != nil {
419+
return originVal, err
420+
}
421+
if originVal < WaitTimeout {
422+
_, err = db.ExecContext(context.TODO(), fmt.Sprintf("set wait_timeout = %v", WaitTimeout))
423+
if err != nil {
424+
return originVal, err
425+
}
426+
}
427+
return originVal, nil
428+
}

0 commit comments

Comments
 (0)