Skip to content

Commit 674cfd4

Browse files
committed
added support for tidb.
Signed-off-by: Sienna Meridian Satterwhite <sienna.satterwhite@gaimin.io>
1 parent 02f3e5f commit 674cfd4

9 files changed

Lines changed: 657 additions & 5 deletions

File tree

.github/workflows/ci.yml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,50 @@ jobs:
5757
run: |
5858
go test -tags sqlite -race -cover ./...
5959
60+
tidb-tests:
61+
name: TiDB tests - Go v${{ matrix.go-version }}
62+
runs-on: ubuntu-latest
63+
strategy:
64+
matrix:
65+
go-version:
66+
- "1.25"
67+
68+
services:
69+
tidb:
70+
image: pingcap/tidb:latest
71+
ports:
72+
- 4000:4000
73+
74+
steps:
75+
- uses: actions/checkout@v3
76+
- name: Setup Go ${{ matrix.go }}
77+
uses: actions/setup-go@v3
78+
with:
79+
go-version: ${{ matrix.go-version }}
80+
81+
- name: Setup TiDB environment
82+
run: |
83+
mysqldump -u root --port 4000 --version
84+
echo $HOME
85+
echo -e "[mysqldump]\ncolumn-statistics=0" > $HOME/.my.cnf
86+
87+
- name: Build and run soda
88+
env:
89+
SODA_DIALECT: "tidb"
90+
TIDB_PORT: 4000
91+
run: |
92+
go build -v -tags sqlite -o tsoda ./soda
93+
./tsoda drop -e $SODA_DIALECT -p ./testdata/migrations
94+
./tsoda create -e $SODA_DIALECT -p ./testdata/migrations
95+
./tsoda migrate -e $SODA_DIALECT -p ./testdata/migrations
96+
97+
- name: Test
98+
env:
99+
SODA_DIALECT: "tidb"
100+
TIDB_PORT: 4000
101+
run: |
102+
go test -tags sqlite -race -cover ./...
103+
60104
pg-tests:
61105
name: PostgreSQL tests - Go v${{ matrix.go-version }}
62106
runs-on: ubuntu-latest

config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func Test_LoadsConnectionsFromConfig(t *testing.T) {
1313

1414
r.NoError(LoadConfigFile())
1515
if DialectSupported("sqlite3") {
16-
r.Equal(5, len(Connections))
16+
r.Equal(6, len(Connections))
1717
} else {
1818
r.Equal(4, len(Connections))
1919
}

database.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@ mysql:
88
options:
99
readTimeout: 5s
1010

11+
tidb:
12+
dialect: "tidb"
13+
database: "pop_test"
14+
host: '{{ envOr "TIDB_HOST" "127.0.0.1" }}'
15+
port: '{{ envOr "TIDB_PORT" "4000" }}'
16+
user: '{{ envOr "TIDB_USER" "root" }}'
17+
password: '{{ envOr "TIDB_PASSWORD" "" }}'
18+
options:
19+
readTimeout: 10s
20+
1121
postgres:
1222
url: '{{ envOr "POSTGRESQL_URL" "postgres://postgres:postgres%23@localhost:5433/pop_test?sslmode=disable" }}'
1323
pool: 25

dialect_tidb.go

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
package pop
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"io"
8+
"os/exec"
9+
"strings"
10+
11+
_mysql "github.com/go-sql-driver/mysql" // Load MySQL Go driver
12+
"github.com/gobuffalo/fizz"
13+
"github.com/gobuffalo/fizz/translators"
14+
"github.com/jmoiron/sqlx"
15+
"github.com/ory/pop/v6/columns"
16+
"github.com/ory/pop/v6/internal/defaults"
17+
"github.com/ory/pop/v6/logging"
18+
)
19+
20+
const nameTiDB = "tidb"
21+
const hostTiDB = "127.0.0.1"
22+
const portTiDB = "4000"
23+
24+
func init() {
25+
AvailableDialects = append(AvailableDialects, nameTiDB)
26+
urlParser[nameTiDB] = urlParserTiDB
27+
finalizer[nameTiDB] = finalizerTiDB
28+
newConnection[nameTiDB] = newTiDB
29+
}
30+
31+
var _ dialect = &tidb{}
32+
33+
type tidb struct {
34+
commonDialect
35+
}
36+
37+
func (m *tidb) Name() string {
38+
return nameTiDB
39+
}
40+
41+
func (m *tidb) DefaultDriver() string {
42+
return nameMySQL
43+
}
44+
45+
func (tidb) Quote(key string) string {
46+
return fmt.Sprintf("`%s`", key)
47+
}
48+
49+
func (m *tidb) Details() *ConnectionDetails {
50+
return m.ConnectionDetails
51+
}
52+
53+
func (m *tidb) URL() string {
54+
cd := m.ConnectionDetails
55+
if cd.URL != "" {
56+
return strings.TrimPrefix(cd.URL, "mysql://")
57+
}
58+
59+
user := fmt.Sprintf("%s:%s@", cd.User, cd.Password)
60+
user = strings.Replace(user, ":@", "@", 1)
61+
if user == "@" || strings.HasPrefix(user, ":") {
62+
user = ""
63+
}
64+
65+
addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port)
66+
// in case of unix domain socket, tricky.
67+
// it is better to check Host is not valid inet address or has '/'.
68+
if cd.Port == "socket" {
69+
addr = fmt.Sprintf("unix(%s)", cd.Host)
70+
}
71+
72+
s := "%s%s/%s?%s"
73+
return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString(""))
74+
}
75+
76+
func (m *tidb) urlWithoutDB() string {
77+
cd := m.ConnectionDetails
78+
return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1)
79+
}
80+
81+
func (m *tidb) MigrationURL() string {
82+
return m.URL()
83+
}
84+
85+
func (m *tidb) Create(c *Connection, model *Model, cols columns.Columns) error {
86+
if err := genericCreate(c, model, cols, m); err != nil {
87+
return fmt.Errorf("tidb create: %w", err)
88+
}
89+
return nil
90+
}
91+
92+
func (m *tidb) Update(c *Connection, model *Model, cols columns.Columns) error {
93+
if err := genericUpdate(c, model, cols, m); err != nil {
94+
return fmt.Errorf("tidb update: %w", err)
95+
}
96+
return nil
97+
}
98+
99+
func (m *tidb) UpdateQuery(c *Connection, model *Model, cols columns.Columns, query Query) (int64, error) {
100+
if n, err := genericUpdateQuery(c, model, cols, m, query, sqlx.QUESTION); err != nil {
101+
return n, fmt.Errorf("tidb update query: %w", err)
102+
} else {
103+
return n, nil
104+
}
105+
}
106+
107+
func (m *tidb) Destroy(c *Connection, model *Model) error {
108+
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.Quote(model.TableName()), model.IDField())
109+
_, err := genericExec(c, stmt, model.ID())
110+
if err != nil {
111+
return fmt.Errorf("tidb destroy: %w", err)
112+
}
113+
return nil
114+
}
115+
116+
func (m *tidb) Delete(c *Connection, model *Model, query Query) error {
117+
sqlQuery, args := query.ToSQL(model)
118+
// * MySQL does not support table alias for DELETE syntax until 8.0.
119+
// * Do not generate SQL manually if they may have `WHERE IN`.
120+
// * Spaces are intentionally added to make it easy to see on the log.
121+
sqlQuery = asRegex.ReplaceAllString(sqlQuery, " ")
122+
123+
_, err := genericExec(c, sqlQuery, args...)
124+
return err
125+
}
126+
127+
func (m *tidb) SelectOne(c *Connection, model *Model, query Query) error {
128+
if err := genericSelectOne(c, model, query); err != nil {
129+
return fmt.Errorf("tidb select one: %w", err)
130+
}
131+
return nil
132+
}
133+
134+
func (m *tidb) SelectMany(c *Connection, models *Model, query Query) error {
135+
if err := genericSelectMany(c, models, query); err != nil {
136+
return fmt.Errorf("tidb select many: %w", err)
137+
}
138+
return nil
139+
}
140+
141+
// CreateDB creates a new database, from the given connection credentials
142+
func (m *tidb) CreateDB() error {
143+
deets := m.ConnectionDetails
144+
db, _, err := openPotentiallyInstrumentedConnection(context.Background(), m, m.urlWithoutDB())
145+
if err != nil {
146+
return fmt.Errorf("error creating TiDB database %s: %w", deets.Database, err)
147+
}
148+
defer db.Close()
149+
charset := defaults.String(deets.option("charset"), "utf8mb4")
150+
encoding := defaults.String(deets.option("collation"), "utf8mb4_general_ci")
151+
query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding)
152+
log(logging.SQL, query)
153+
154+
_, err = db.Exec(query)
155+
if err != nil {
156+
return fmt.Errorf("error creating TiDB database %s: %w", deets.Database, err)
157+
}
158+
159+
log(logging.Info, "created database %s", deets.Database)
160+
return nil
161+
}
162+
163+
// DropDB drops an existing database, from the given connection credentials
164+
func (m *tidb) DropDB() error {
165+
deets := m.ConnectionDetails
166+
db, _, err := openPotentiallyInstrumentedConnection(context.Background(), m, m.urlWithoutDB())
167+
if err != nil {
168+
return fmt.Errorf("error dropping TiDB database %s: %w", deets.Database, err)
169+
}
170+
defer db.Close()
171+
query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database)
172+
log(logging.SQL, query)
173+
174+
_, err = db.Exec(query)
175+
if err != nil {
176+
return fmt.Errorf("error dropping TiDB database %s: %w", deets.Database, err)
177+
}
178+
179+
log(logging.Info, "dropped database %s", deets.Database)
180+
return nil
181+
}
182+
183+
func (m *tidb) TranslateSQL(sql string) string {
184+
return sql
185+
}
186+
187+
func (m *tidb) FizzTranslator() fizz.Translator {
188+
t := translators.NewMySQL(m.URL(), m.Details().Database)
189+
return t
190+
}
191+
192+
func (m *tidb) DumpSchema(w io.Writer) error {
193+
deets := m.Details()
194+
cmd := exec.Command("mysqldump", "--protocol", "TCP", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
195+
if deets.Port == "socket" {
196+
cmd = exec.Command("mysqldump", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
197+
}
198+
return genericDumpSchema(deets, cmd, w)
199+
}
200+
201+
// LoadSchema executes a schema sql file against the configured database.
202+
func (m *tidb) LoadSchema(r io.Reader) error {
203+
return genericLoadSchema(m, r)
204+
}
205+
206+
// TruncateAll truncates all tables for the given connection.
207+
func (m *tidb) TruncateAll(tx *Connection) error {
208+
var stmts []string
209+
err := tx.RawQuery(tidbTruncate, m.Details().Database, tx.MigrationTableName()).All(&stmts)
210+
if err != nil {
211+
return err
212+
}
213+
if len(stmts) == 0 {
214+
return nil
215+
}
216+
217+
var qb bytes.Buffer
218+
// #49: Disable foreign keys before truncation
219+
qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ")
220+
qb.WriteString(strings.Join(stmts, " "))
221+
// #49: Re-enable foreign keys after truncation
222+
qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;")
223+
224+
return tx.RawQuery(qb.String()).Exec()
225+
}
226+
227+
func (m *tidb) AfterOpen(c *Connection) error {
228+
// ref: ory/kratos#1551
229+
err := c.RawQuery("SET SESSION transaction_isolation = 'REPEATABLE-READ';").Exec()
230+
if err != nil {
231+
return fmt.Errorf("tidb: setting transaction isolation level: %w", err)
232+
}
233+
return nil
234+
}
235+
236+
func newTiDB(deets *ConnectionDetails) (dialect, error) {
237+
cd := &tidb{
238+
commonDialect: commonDialect{ConnectionDetails: deets},
239+
}
240+
return cd, nil
241+
}
242+
243+
func urlParserTiDB(cd *ConnectionDetails) error {
244+
cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://"))
245+
if err != nil {
246+
return fmt.Errorf("the URL '%s' is not supported by MySQL/TiDB driver: %w", cd.URL, err)
247+
}
248+
249+
cd.User = cfg.User
250+
cd.Password = cfg.Passwd
251+
cd.Database = cfg.DBName
252+
253+
// NOTE: use cfg.Params if want to fill options with full parameters
254+
cd.setOption("collation", cfg.Collation)
255+
256+
if cfg.Net == "unix" {
257+
cd.Port = "socket" // trick. see: `URL()`
258+
cd.Host = cfg.Addr
259+
} else {
260+
tmp := strings.Split(cfg.Addr, ":")
261+
cd.Host = tmp[0]
262+
if len(tmp) > 1 {
263+
cd.Port = tmp[1]
264+
}
265+
}
266+
267+
return nil
268+
}
269+
270+
func finalizerTiDB(cd *ConnectionDetails) {
271+
cd.Host = defaults.String(cd.Host, hostTiDB)
272+
cd.Port = defaults.String(cd.Port, portTiDB)
273+
274+
defs := map[string]string{
275+
"readTimeout": "3s",
276+
"collation": "utf8mb4_general_ci",
277+
}
278+
forced := map[string]string{
279+
"parseTime": "true",
280+
"multiStatements": "true",
281+
}
282+
283+
for k, def := range defs {
284+
cd.setOptionWithDefault(k, cd.option(k), def)
285+
}
286+
287+
for k, v := range forced {
288+
// respect user specified options but print warning!
289+
cd.setOptionWithDefault(k, cd.option(k), v)
290+
if cd.option(k) != v { // when user-defined option exists
291+
log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.option(k))
292+
log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.option(k))
293+
} // or override with `cd.Options[k] = v`?
294+
if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
295+
log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
296+
} // or fix user specified url?
297+
}
298+
}
299+
300+
const tidbTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name <> ? AND table_type <> 'VIEW'"

0 commit comments

Comments
 (0)