Skip to content

Commit cc87c87

Browse files
authored
🐛 transfer TEXT default (#152)
1 parent 53dfc1f commit cc87c87

File tree

3 files changed

+272
-2
lines changed

3 files changed

+272
-2
lines changed

src/sqlite3_to_mysql/transporter.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from .mysql_utils import (
4747
MYSQL_BLOB_COLUMN_TYPES,
4848
MYSQL_COLUMN_TYPES,
49-
MYSQL_COLUMN_TYPES_WITHOUT_DEFAULT,
5049
MYSQL_INSERT_METHOD,
5150
MYSQL_TEXT_COLUMN_TYPES,
5251
MYSQL_TEXT_COLUMN_TYPES_WITH_JSON,
@@ -109,6 +108,8 @@ def __init__(self, **kwargs: Unpack[SQLite3toMySQLParams]):
109108

110109
self._mysql_port = kwargs.get("mysql_port", 3306) or 3306
111110

111+
self._is_mariadb = False
112+
112113
if kwargs.get("mysql_socket") is not None:
113114
if not os.path.exists(str(kwargs.get("mysql_socket"))):
114115
raise FileNotFoundError("MySQL socket does not exist")
@@ -231,6 +232,7 @@ def __init__(self, **kwargs: Unpack[SQLite3toMySQLParams]):
231232
raise
232233

233234
self._mysql_version = self._get_mysql_version()
235+
self._is_mariadb = "-mariadb" in self._mysql_version.lower()
234236
self._mysql_json_support = check_mysql_json_support(self._mysql_version)
235237
self._mysql_fulltext_support = check_mysql_fulltext_support(self._mysql_version)
236238
self._allow_expr_defaults = check_mysql_expression_defaults_support(self._mysql_version)
@@ -329,6 +331,69 @@ def _create_database(self) -> None:
329331
def _valid_column_type(cls, column_type: str) -> t.Optional[t.Match[str]]:
330332
return cls.COLUMN_PATTERN.match(column_type.strip())
331333

334+
@classmethod
335+
def _base_mysql_column_type(cls, column_type: str) -> str:
336+
stripped: str = column_type.strip()
337+
if not stripped:
338+
return ""
339+
match = cls._valid_column_type(stripped)
340+
if match:
341+
return match.group(0).strip().upper()
342+
return stripped.split("(", 1)[0].strip().upper()
343+
344+
def _column_type_supports_default(self, base_type: str, allow_expr_defaults: bool) -> bool:
345+
normalized: str = base_type.upper()
346+
if not normalized:
347+
return True
348+
if normalized == "GEOMETRY":
349+
return False
350+
if normalized in MYSQL_BLOB_COLUMN_TYPES:
351+
return False
352+
if normalized in MYSQL_TEXT_COLUMN_TYPES_WITH_JSON:
353+
return allow_expr_defaults
354+
return True
355+
356+
@staticmethod
357+
def _parse_sql_expression(value: str) -> t.Optional[exp.Expression]:
358+
stripped: str = value.strip()
359+
if not stripped:
360+
return None
361+
for dialect in ("mysql", "sqlite"):
362+
try:
363+
return sqlglot.parse_one(stripped, read=dialect)
364+
except sqlglot_errors.ParseError:
365+
continue
366+
return None
367+
368+
def _format_textual_default(
369+
self,
370+
default_sql: str,
371+
allow_expr_defaults: bool,
372+
is_mariadb: bool,
373+
) -> str:
374+
"""Normalise textual DEFAULT expressions and wrap for MySQL via sqlglot."""
375+
stripped: str = default_sql.strip()
376+
if not stripped or stripped.upper() == "NULL":
377+
return stripped
378+
if not allow_expr_defaults:
379+
return stripped
380+
381+
expr: t.Optional[exp.Expression] = self._parse_sql_expression(stripped)
382+
if expr is None:
383+
if is_mariadb or stripped.startswith("("):
384+
return stripped
385+
return f"({stripped})"
386+
387+
formatted: str = expr.sql(dialect="mysql")
388+
if is_mariadb:
389+
return formatted
390+
391+
if isinstance(expr, exp.Paren):
392+
return formatted
393+
394+
wrapped = exp.Paren(this=expr.copy())
395+
return wrapped.sql(dialect="mysql")
396+
332397
def _translate_type_from_sqlite_to_mysql(self, column_type: str) -> str:
333398
normalized: t.Optional[str] = self._normalize_sqlite_column_type(column_type)
334399
if normalized and normalized.upper() != column_type.upper():
@@ -804,16 +869,25 @@ def _create_table(self, table_name: str, transfer_rowid: bool = False, skip_defa
804869
column["pk"] > 0 and column_type.startswith(("INT", "BIGINT")) and not compound_primary_key
805870
)
806871

872+
allow_expr_defaults: bool = getattr(self, "_allow_expr_defaults", False)
873+
is_mariadb: bool = getattr(self, "_is_mariadb", False)
874+
base_type: str = self._base_mysql_column_type(column_type)
875+
807876
# Build DEFAULT clause safely (preserve falsy defaults like 0/'')
808877
default_clause: str = ""
809878
if (
810879
not skip_default
811880
and column["dflt_value"] is not None
812-
and column_type not in MYSQL_COLUMN_TYPES_WITHOUT_DEFAULT
881+
and self._column_type_supports_default(base_type, allow_expr_defaults)
813882
and not auto_increment
814883
):
815884
td: str = self._translate_default_for_mysql(column_type, str(column["dflt_value"]))
816885
if td != "":
886+
stripped_td: str = td.strip()
887+
if base_type in MYSQL_TEXT_COLUMN_TYPES_WITH_JSON and stripped_td.upper() != "NULL":
888+
td = self._format_textual_default(stripped_td, allow_expr_defaults, is_mariadb)
889+
else:
890+
td = stripped_td
817891
default_clause = "DEFAULT " + td
818892
sql += " `{name}` {type} {notnull} {default} {auto_increment}, ".format(
819893
name=mysql_safe_name,

src/sqlite3_to_mysql/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class SQLite3toMySQLAttributes:
8787
_mysql: MySQLConnection
8888
_mysql_cur: MySQLCursor
8989
_mysql_version: str
90+
_is_mariadb: bool
9091
_mysql_json_support: bool
9192
_mysql_fulltext_support: bool
9293
_allow_expr_defaults: bool

tests/unit/sqlite3_to_mysql_test.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,201 @@ def test_create_table_invalid_default_retries_without_defaults(self, mocker: Moc
10261026
assert "DEFAULT CURRENT_TIMESTAMP" not in retry_sql
10271027
instance._logger.warning.assert_called_once()
10281028

1029+
def test_create_table_text_default_mariadb(self, mocker: MockerFixture) -> None:
1030+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1031+
instance._sqlite_table_xinfo_support = False
1032+
instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
1033+
instance._mysql_charset = "utf8mb4"
1034+
instance._mysql_collation = "utf8mb4_unicode_ci"
1035+
instance._logger = mocker.MagicMock()
1036+
instance._allow_expr_defaults = True
1037+
instance._is_mariadb = True
1038+
1039+
rows = [
1040+
{"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "'[]'", "pk": 0},
1041+
]
1042+
1043+
sqlite_cursor = mocker.MagicMock()
1044+
sqlite_cursor.fetchall.return_value = rows
1045+
instance._sqlite_cur = sqlite_cursor
1046+
1047+
instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")
1048+
1049+
mysql_cursor = mocker.MagicMock()
1050+
instance._mysql_cur = mysql_cursor
1051+
instance._mysql = mocker.MagicMock()
1052+
1053+
instance._create_table("demo")
1054+
1055+
executed_sql = mysql_cursor.execute.call_args[0][0]
1056+
assert "DEFAULT '[]'" in executed_sql
1057+
assert "DEFAULT ('[]')" not in executed_sql
1058+
1059+
def test_create_table_text_default_mysql_expression(self, mocker: MockerFixture) -> None:
1060+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1061+
instance._sqlite_table_xinfo_support = False
1062+
instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
1063+
instance._mysql_charset = "utf8mb4"
1064+
instance._mysql_collation = "utf8mb4_unicode_ci"
1065+
instance._logger = mocker.MagicMock()
1066+
instance._allow_expr_defaults = True
1067+
instance._is_mariadb = False
1068+
1069+
rows = [
1070+
{"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "'[]'", "pk": 0},
1071+
]
1072+
1073+
sqlite_cursor = mocker.MagicMock()
1074+
sqlite_cursor.fetchall.return_value = rows
1075+
instance._sqlite_cur = sqlite_cursor
1076+
1077+
instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")
1078+
1079+
mysql_cursor = mocker.MagicMock()
1080+
instance._mysql_cur = mysql_cursor
1081+
instance._mysql = mocker.MagicMock()
1082+
1083+
instance._create_table("demo")
1084+
1085+
executed_sql = mysql_cursor.execute.call_args[0][0]
1086+
assert "DEFAULT ('[]')" in executed_sql
1087+
1088+
def test_create_table_text_default_mysql_function_expression(self, mocker: MockerFixture) -> None:
1089+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1090+
instance._sqlite_table_xinfo_support = False
1091+
instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
1092+
instance._mysql_charset = "utf8mb4"
1093+
instance._mysql_collation = "utf8mb4_unicode_ci"
1094+
instance._logger = mocker.MagicMock()
1095+
instance._allow_expr_defaults = True
1096+
instance._is_mariadb = False
1097+
1098+
rows = [
1099+
{"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "json_array()", "pk": 0},
1100+
]
1101+
1102+
sqlite_cursor = mocker.MagicMock()
1103+
sqlite_cursor.fetchall.return_value = rows
1104+
instance._sqlite_cur = sqlite_cursor
1105+
1106+
instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")
1107+
instance._translate_default_for_mysql = mocker.MagicMock(return_value="JSON_ARRAY()")
1108+
1109+
mysql_cursor = mocker.MagicMock()
1110+
instance._mysql_cur = mysql_cursor
1111+
instance._mysql = mocker.MagicMock()
1112+
1113+
instance._create_table("demo")
1114+
1115+
executed_sql = mysql_cursor.execute.call_args[0][0]
1116+
assert "DEFAULT (JSON_ARRAY())" in executed_sql
1117+
1118+
def test_parse_sql_expression_falls_back_to_sqlite(self, mocker: MockerFixture) -> None:
1119+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1120+
sqlite_expr = exp.Literal.string("ok")
1121+
parse_mock = mocker.patch(
1122+
"sqlite3_to_mysql.transporter.sqlglot.parse_one",
1123+
side_effect=[sqlglot_errors.ParseError("mysql"), sqlite_expr],
1124+
)
1125+
1126+
result = instance._parse_sql_expression("value")
1127+
1128+
assert result is sqlite_expr
1129+
assert parse_mock.call_args_list[0].kwargs["read"] == "mysql"
1130+
assert parse_mock.call_args_list[1].kwargs["read"] == "sqlite"
1131+
1132+
def test_parse_sql_expression_returns_none_when_unparseable(self, mocker: MockerFixture) -> None:
1133+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1134+
parse_mock = mocker.patch(
1135+
"sqlite3_to_mysql.transporter.sqlglot.parse_one",
1136+
side_effect=[
1137+
sqlglot_errors.ParseError("mysql"),
1138+
sqlglot_errors.ParseError("sqlite"),
1139+
],
1140+
)
1141+
1142+
result = instance._parse_sql_expression("value")
1143+
1144+
assert result is None
1145+
assert parse_mock.call_count == 2
1146+
1147+
def test_format_textual_default_wraps_when_unparseable_mysql(self, mocker: MockerFixture) -> None:
1148+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1149+
mocker.patch.object(instance, "_parse_sql_expression", return_value=None)
1150+
1151+
result = instance._format_textual_default("raw_json()", True, False)
1152+
1153+
assert result == "(raw_json())"
1154+
1155+
def test_format_textual_default_mariadb_uses_literal_output(self, mocker: MockerFixture) -> None:
1156+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1157+
literal_expr = exp.Literal.string("[]")
1158+
mocker.patch.object(instance, "_parse_sql_expression", return_value=literal_expr)
1159+
1160+
result = instance._format_textual_default("'[]'", True, True)
1161+
1162+
assert result == "'[]'"
1163+
1164+
def test_format_textual_default_preserves_existing_parens(self, mocker: MockerFixture) -> None:
1165+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1166+
paren_expr = exp.Paren(this=exp.Literal.string("[]"))
1167+
mocker.patch.object(instance, "_parse_sql_expression", return_value=paren_expr)
1168+
1169+
result = instance._format_textual_default("('[]')", True, False)
1170+
1171+
assert result == "('[]')"
1172+
1173+
def test_format_textual_default_respects_disabled_expression_defaults(self) -> None:
1174+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1175+
1176+
result = instance._format_textual_default("'[]'", False, False)
1177+
1178+
assert result == "'[]'"
1179+
1180+
def test_base_mysql_column_type_handles_whitespace_and_unknown(self) -> None:
1181+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1182+
1183+
assert instance._base_mysql_column_type(" TEXT(255) ") == "TEXT"
1184+
assert instance._base_mysql_column_type("custom_type") == "CUSTOM_TYPE"
1185+
assert instance._base_mysql_column_type("(TEXT)") == ""
1186+
assert instance._base_mysql_column_type(" ") == ""
1187+
1188+
def test_column_type_supports_default_branches(self) -> None:
1189+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1190+
1191+
assert not instance._column_type_supports_default("GEOMETRY", True)
1192+
assert not instance._column_type_supports_default("BLOB", True)
1193+
assert not instance._column_type_supports_default("TEXT", False)
1194+
assert instance._column_type_supports_default("", True)
1195+
assert instance._column_type_supports_default("VARCHAR", False)
1196+
1197+
def test_parse_sql_expression_returns_none_for_blank(self) -> None:
1198+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1199+
1200+
assert instance._parse_sql_expression(" ") is None
1201+
1202+
def test_format_textual_default_handles_blank_and_null(self) -> None:
1203+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1204+
1205+
assert instance._format_textual_default(" ", True, False) == ""
1206+
assert instance._format_textual_default("NULL", True, False) == "NULL"
1207+
1208+
def test_format_textual_default_mariadb_preserves_unparseable(self, mocker: MockerFixture) -> None:
1209+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1210+
mocker.patch.object(instance, "_parse_sql_expression", return_value=None)
1211+
1212+
result = instance._format_textual_default("json_array()", True, True)
1213+
1214+
assert result == "json_array()"
1215+
1216+
def test_format_textual_default_preserves_parenthesised_unparseable(self, mocker: MockerFixture) -> None:
1217+
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
1218+
mocker.patch.object(instance, "_parse_sql_expression", return_value=None)
1219+
1220+
result = instance._format_textual_default("(select 1)", True, False)
1221+
1222+
assert result == "(select 1)"
1223+
10291224
def test_truncate_table_executes_when_table_exists(self, mocker: MockerFixture) -> None:
10301225
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
10311226
cursor = mocker.MagicMock()

0 commit comments

Comments
 (0)