diff --git a/src/sqlite3_to_mysql/transporter.py b/src/sqlite3_to_mysql/transporter.py
index 905008f..290c4a5 100644
--- a/src/sqlite3_to_mysql/transporter.py
+++ b/src/sqlite3_to_mysql/transporter.py
@@ -46,7 +46,6 @@
 from .mysql_utils import (
     MYSQL_BLOB_COLUMN_TYPES,
     MYSQL_COLUMN_TYPES,
-    MYSQL_COLUMN_TYPES_WITHOUT_DEFAULT,
     MYSQL_INSERT_METHOD,
     MYSQL_TEXT_COLUMN_TYPES,
     MYSQL_TEXT_COLUMN_TYPES_WITH_JSON,
@@ -109,6 +108,8 @@ def __init__(self, **kwargs: Unpack[SQLite3toMySQLParams]):
 
         self._mysql_port = kwargs.get("mysql_port", 3306) or 3306
 
+        self._is_mariadb = False
+
         if kwargs.get("mysql_socket") is not None:
             if not os.path.exists(str(kwargs.get("mysql_socket"))):
                 raise FileNotFoundError("MySQL socket does not exist")
@@ -231,6 +232,7 @@ def __init__(self, **kwargs: Unpack[SQLite3toMySQLParams]):
             raise
 
         self._mysql_version = self._get_mysql_version()
+        self._is_mariadb = "-mariadb" in self._mysql_version.lower()
         self._mysql_json_support = check_mysql_json_support(self._mysql_version)
         self._mysql_fulltext_support = check_mysql_fulltext_support(self._mysql_version)
         self._allow_expr_defaults = check_mysql_expression_defaults_support(self._mysql_version)
@@ -329,6 +331,69 @@ def _create_database(self) -> None:
     def _valid_column_type(cls, column_type: str) -> t.Optional[t.Match[str]]:
         return cls.COLUMN_PATTERN.match(column_type.strip())
 
+    @classmethod
+    def _base_mysql_column_type(cls, column_type: str) -> str:
+        stripped: str = column_type.strip()
+        if not stripped:
+            return ""
+        match = cls._valid_column_type(stripped)
+        if match:
+            return match.group(0).strip().upper()
+        return stripped.split("(", 1)[0].strip().upper()
+
+    def _column_type_supports_default(self, base_type: str, allow_expr_defaults: bool) -> bool:
+        normalized: str = base_type.upper()
+        if not normalized:
+            return True
+        if normalized == "GEOMETRY":
+            return False
+        if normalized in MYSQL_BLOB_COLUMN_TYPES:
+            return False
+        if normalized in MYSQL_TEXT_COLUMN_TYPES_WITH_JSON:
+            return allow_expr_defaults
+        return True
+
+    @staticmethod
+    def _parse_sql_expression(value: str) -> t.Optional[exp.Expression]:
+        stripped: str = value.strip()
+        if not stripped:
+            return None
+        for dialect in ("mysql", "sqlite"):
+            try:
+                return sqlglot.parse_one(stripped, read=dialect)
+            except sqlglot_errors.ParseError:
+                continue
+        return None
+
+    def _format_textual_default(
+        self,
+        default_sql: str,
+        allow_expr_defaults: bool,
+        is_mariadb: bool,
+    ) -> str:
+        """Normalise textual DEFAULT expressions and wrap for MySQL via sqlglot."""
+        stripped: str = default_sql.strip()
+        if not stripped or stripped.upper() == "NULL":
+            return stripped
+        if not allow_expr_defaults:
+            return stripped
+
+        expr: t.Optional[exp.Expression] = self._parse_sql_expression(stripped)
+        if expr is None:
+            if is_mariadb or stripped.startswith("("):
+                return stripped
+            return f"({stripped})"
+
+        formatted: str = expr.sql(dialect="mysql")
+        if is_mariadb:
+            return formatted
+
+        if isinstance(expr, exp.Paren):
+            return formatted
+
+        wrapped = exp.Paren(this=expr.copy())
+        return wrapped.sql(dialect="mysql")
+
     def _translate_type_from_sqlite_to_mysql(self, column_type: str) -> str:
         normalized: t.Optional[str] = self._normalize_sqlite_column_type(column_type)
         if normalized and normalized.upper() != column_type.upper():
@@ -804,16 +869,25 @@ def _create_table(self, table_name: str, transfer_rowid: bool = False, skip_defa
                 column["pk"] > 0 and column_type.startswith(("INT", "BIGINT")) and not compound_primary_key
             )
 
+            allow_expr_defaults: bool = getattr(self, "_allow_expr_defaults", False)
+            is_mariadb: bool = getattr(self, "_is_mariadb", False)
+            base_type: str = self._base_mysql_column_type(column_type)
+
             # Build DEFAULT clause safely (preserve falsy defaults like 0/'')
             default_clause: str = ""
             if (
                 not skip_default
                 and column["dflt_value"] is not None
-                and column_type not in MYSQL_COLUMN_TYPES_WITHOUT_DEFAULT
+                and self._column_type_supports_default(base_type, allow_expr_defaults)
                 and not auto_increment
             ):
                 td: str = self._translate_default_for_mysql(column_type, str(column["dflt_value"]))
                 if td != "":
+                    stripped_td: str = td.strip()
+                    if base_type in MYSQL_TEXT_COLUMN_TYPES_WITH_JSON and stripped_td.upper() != "NULL":
+                        td = self._format_textual_default(stripped_td, allow_expr_defaults, is_mariadb)
+                    else:
+                        td = stripped_td
                     default_clause = "DEFAULT " + td
             sql += " `{name}` {type} {notnull} {default} {auto_increment}, ".format(
                 name=mysql_safe_name,
diff --git a/src/sqlite3_to_mysql/types.py b/src/sqlite3_to_mysql/types.py
index 0344658..6b74075 100644
--- a/src/sqlite3_to_mysql/types.py
+++ b/src/sqlite3_to_mysql/types.py
@@ -87,6 +87,7 @@ class SQLite3toMySQLAttributes:
     _mysql: MySQLConnection
     _mysql_cur: MySQLCursor
     _mysql_version: str
+    _is_mariadb: bool
     _mysql_json_support: bool
     _mysql_fulltext_support: bool
     _allow_expr_defaults: bool
diff --git a/tests/unit/sqlite3_to_mysql_test.py b/tests/unit/sqlite3_to_mysql_test.py
index 5cb67bf..358dd20 100644
--- a/tests/unit/sqlite3_to_mysql_test.py
+++ b/tests/unit/sqlite3_to_mysql_test.py
@@ -1026,6 +1026,201 @@ def test_create_table_invalid_default_retries_without_defaults(self, mocker: Moc
         assert "DEFAULT CURRENT_TIMESTAMP" not in retry_sql
         instance._logger.warning.assert_called_once()
 
+    def test_create_table_text_default_mariadb(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        instance._sqlite_table_xinfo_support = False
+        instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
+        instance._mysql_charset = "utf8mb4"
+        instance._mysql_collation = "utf8mb4_unicode_ci"
+        instance._logger = mocker.MagicMock()
+        instance._allow_expr_defaults = True
+        instance._is_mariadb = True
+
+        rows = [
+            {"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "'[]'", "pk": 0},
+        ]
+
+        sqlite_cursor = mocker.MagicMock()
+        sqlite_cursor.fetchall.return_value = rows
+        instance._sqlite_cur = sqlite_cursor
+
+        instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")
+
+        mysql_cursor = mocker.MagicMock()
+        instance._mysql_cur = mysql_cursor
+        instance._mysql = mocker.MagicMock()
+
+        instance._create_table("demo")
+
+        executed_sql = mysql_cursor.execute.call_args[0][0]
+        assert "DEFAULT '[]'" in executed_sql
+        assert "DEFAULT ('[]')" not in executed_sql
+
+    def test_create_table_text_default_mysql_expression(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        instance._sqlite_table_xinfo_support = False
+        instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
+        instance._mysql_charset = "utf8mb4"
+        instance._mysql_collation = "utf8mb4_unicode_ci"
+        instance._logger = mocker.MagicMock()
+        instance._allow_expr_defaults = True
+        instance._is_mariadb = False
+
+        rows = [
+            {"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "'[]'", "pk": 0},
+        ]
+
+        sqlite_cursor = mocker.MagicMock()
+        sqlite_cursor.fetchall.return_value = rows
+        instance._sqlite_cur = sqlite_cursor
+
+        instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")
+
+        mysql_cursor = mocker.MagicMock()
+        instance._mysql_cur = mysql_cursor
+        instance._mysql = mocker.MagicMock()
+
+        instance._create_table("demo")
+
+        executed_sql = mysql_cursor.execute.call_args[0][0]
+        assert "DEFAULT ('[]')" in executed_sql
+
+    def test_create_table_text_default_mysql_function_expression(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        instance._sqlite_table_xinfo_support = False
+        instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
+        instance._mysql_charset = "utf8mb4"
+        instance._mysql_collation = "utf8mb4_unicode_ci"
+        instance._logger = mocker.MagicMock()
+        instance._allow_expr_defaults = True
+        instance._is_mariadb = False
+
+        rows = [
+            {"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "json_array()", "pk": 0},
+        ]
+
+        sqlite_cursor = mocker.MagicMock()
+        sqlite_cursor.fetchall.return_value = rows
+        instance._sqlite_cur = sqlite_cursor
+
+        instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")
+        instance._translate_default_for_mysql = mocker.MagicMock(return_value="JSON_ARRAY()")
+
+        mysql_cursor = mocker.MagicMock()
+        instance._mysql_cur = mysql_cursor
+        instance._mysql = mocker.MagicMock()
+
+        instance._create_table("demo")
+
+        executed_sql = mysql_cursor.execute.call_args[0][0]
+        assert "DEFAULT (JSON_ARRAY())" in executed_sql
+
+    def test_parse_sql_expression_falls_back_to_sqlite(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        sqlite_expr = exp.Literal.string("ok")
+        parse_mock = mocker.patch(
+            "sqlite3_to_mysql.transporter.sqlglot.parse_one",
+            side_effect=[sqlglot_errors.ParseError("mysql"), sqlite_expr],
+        )
+
+        result = instance._parse_sql_expression("value")
+
+        assert result is sqlite_expr
+        assert parse_mock.call_args_list[0].kwargs["read"] == "mysql"
+        assert parse_mock.call_args_list[1].kwargs["read"] == "sqlite"
+
+    def test_parse_sql_expression_returns_none_when_unparseable(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        parse_mock = mocker.patch(
+            "sqlite3_to_mysql.transporter.sqlglot.parse_one",
+            side_effect=[
+                sqlglot_errors.ParseError("mysql"),
+                sqlglot_errors.ParseError("sqlite"),
+            ],
+        )
+
+        result = instance._parse_sql_expression("value")
+
+        assert result is None
+        assert parse_mock.call_count == 2
+
+    def test_format_textual_default_wraps_when_unparseable_mysql(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        mocker.patch.object(instance, "_parse_sql_expression", return_value=None)
+
+        result = instance._format_textual_default("raw_json()", True, False)
+
+        assert result == "(raw_json())"
+
+    def test_format_textual_default_mariadb_uses_literal_output(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        literal_expr = exp.Literal.string("[]")
+        mocker.patch.object(instance, "_parse_sql_expression", return_value=literal_expr)
+
+        result = instance._format_textual_default("'[]'", True, True)
+
+        assert result == "'[]'"
+
+    def test_format_textual_default_preserves_existing_parens(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        paren_expr = exp.Paren(this=exp.Literal.string("[]"))
+        mocker.patch.object(instance, "_parse_sql_expression", return_value=paren_expr)
+
+        result = instance._format_textual_default("('[]')", True, False)
+
+        assert result == "('[]')"
+
+    def test_format_textual_default_respects_disabled_expression_defaults(self) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+
+        result = instance._format_textual_default("'[]'", False, False)
+
+        assert result == "'[]'"
+
+    def test_base_mysql_column_type_handles_whitespace_and_unknown(self) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+
+        assert instance._base_mysql_column_type(" TEXT(255) ") == "TEXT"
+        assert instance._base_mysql_column_type("custom_type") == "CUSTOM_TYPE"
+        assert instance._base_mysql_column_type("(TEXT)") == ""
+        assert instance._base_mysql_column_type(" ") == ""
+
+    def test_column_type_supports_default_branches(self) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+
+        assert not instance._column_type_supports_default("GEOMETRY", True)
+        assert not instance._column_type_supports_default("BLOB", True)
+        assert not instance._column_type_supports_default("TEXT", False)
+        assert instance._column_type_supports_default("", True)
+        assert instance._column_type_supports_default("VARCHAR", False)
+
+    def test_parse_sql_expression_returns_none_for_blank(self) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+
+        assert instance._parse_sql_expression(" ") is None
+
+    def test_format_textual_default_handles_blank_and_null(self) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+
+        assert instance._format_textual_default(" ", True, False) == ""
+        assert instance._format_textual_default("NULL", True, False) == "NULL"
+
+    def test_format_textual_default_mariadb_preserves_unparseable(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        mocker.patch.object(instance, "_parse_sql_expression", return_value=None)
+
+        result = instance._format_textual_default("json_array()", True, True)
+
+        assert result == "json_array()"
+
+    def test_format_textual_default_preserves_parenthesised_unparseable(self, mocker: MockerFixture) -> None:
+        instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
+        mocker.patch.object(instance, "_parse_sql_expression", return_value=None)
+
+        result = instance._format_textual_default("(select 1)", True, False)
+
+        assert result == "(select 1)"
+
     def test_truncate_table_executes_when_table_exists(self, mocker: MockerFixture) -> None:
         instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
         cursor = mocker.MagicMock()