Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 121 additions & 7 deletions df_to_azure/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,87 @@


class SqlUpsert:
def __init__(self, table_name, schema, id_cols, columns, preserve_identity=False):
    """
    Build an upsert helper for a single target table.

    Parameters
    ----------
    table_name : str
        Name of the target table (without schema prefix).
    schema : str
        Schema of the target table.
    id_cols : list of str
        Columns used to match staging rows against target rows in the MERGE.
    columns : iterable of str
        All columns of the dataframe being upserted; whitespace is stripped.
    preserve_identity : bool, optional
        When True, allow explicit values in IDENTITY columns by wrapping the
        MERGE with SET IDENTITY_INSERT ON/OFF (default False).
    """
    self.table_name = table_name
    self.schema = schema
    self.id_cols = id_cols
    self.columns = [col.strip() for col in columns]
    self.preserve_identity = preserve_identity
    # Populated lazily by validate_identity_usage() via get_identity_columns().
    self.identity_columns = []

def get_identity_columns(self):
    """
    Query SQL Server to detect IDENTITY (auto-increment) columns in the target table.

    The table and schema names are passed as bound parameters rather than
    interpolated into the SQL text, so names containing quotes or other
    special characters cannot break or inject into the query.

    Returns
    -------
    list
        List of column names that have the IDENTITY property in the target
        table; empty if the table does not exist or has no IDENTITY column.
    """
    # sys.identity_columns already restricts the result to IDENTITY columns;
    # the joins only resolve the column, table and schema names.
    query = text(
        """
        SELECT c.name
        FROM sys.identity_columns ic
        INNER JOIN sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id
        INNER JOIN sys.tables t ON ic.object_id = t.object_id
        INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
        WHERE t.name = :table_name AND s.name = :schema_name
        """
    ).bindparams(table_name=self.table_name, schema_name=self.schema)

    with auth_azure() as con:
        result = con.execute(query)
        return [row[0] for row in result]

def validate_identity_usage(self):
"""
Validate that user isn't trying to upsert on IDENTITY columns without explicit permission.

This method checks if any of the id_cols (columns used for matching in upsert) are IDENTITY
columns. If so, and preserve_identity=False, it raises an informative error with alternatives.

If preserve_identity=True and IDENTITY columns are in id_cols, logs a warning about the risks.

Raises
------
UpsertError
If id_cols contain IDENTITY columns and preserve_identity=False.
"""
self.identity_columns = self.get_identity_columns()

# Check if any id_cols are IDENTITY columns
identity_in_id_cols = [col for col in self.id_cols if col in self.identity_columns]

if identity_in_id_cols and not self.preserve_identity:
# Build helpful error message based on the scenario
if len(self.id_cols) == 1 and len(identity_in_id_cols) == 1:
# Scenario A: IDENTITY is the only id_field
raise UpsertError(
f"Column '{identity_in_id_cols[0]}' is an auto-increment (IDENTITY) column "
f"and cannot be used for upsert matching.\n\n"
f"Suggested alternatives:\n"
f"1. Use method='append' instead if you want to insert new records with auto-generated IDs\n"
f"2. Add a business key column (e.g., 'user_email', 'external_id') and use that for id_field\n"
f"3. If you must preserve existing ID values (e.g., data migration), set preserve_identity=True\n"
f" WARNING: Using preserve_identity=True is not recommended as it can break ID sequence generation"
)
else:
# Scenario B: IDENTITY is part of composite key
other_cols = [col for col in self.id_cols if col not in identity_in_id_cols]
raise UpsertError(
f"Column(s) {identity_in_id_cols} are auto-increment (IDENTITY) columns "
f"and are part of your id_field {self.id_cols}.\n\n"
f"Suggested alternatives:\n"
f"1. Remove IDENTITY column(s) from id_field and use only: {other_cols}\n"
f"2. If you must preserve existing ID values (e.g., data migration), set preserve_identity=True\n"
f" WARNING: Using preserve_identity=True is not recommended as it can break ID sequence generation"
)

# If preserve_identity=True and IDENTITY columns are in id_cols, log warning
if self.preserve_identity and identity_in_id_cols:
logging.warning(
f"preserve_identity=True: IDENTITY_INSERT will be enabled for {self.schema}.{self.table_name}. "
f"This is not recommended and may cause ID sequence issues. "
f"Consider using non-IDENTITY columns for id_field instead."
)

def create_on_statement(self):
on = " AND ".join([f"s.[{id_col}] = t.[{id_col}]" for id_col in self.id_cols])
Expand All @@ -34,29 +110,67 @@ def create_insert_statement(self):
return insert, values

def create_merge_query(self):
"""
Generate MERGE statement with optional IDENTITY_INSERT handling.

If preserve_identity=True, wraps the MERGE statement with
SET IDENTITY_INSERT ON/OFF to allow explicit insertion of IDENTITY values.

Returns
-------
text
SQLAlchemy text object containing the CREATE PROCEDURE statement.
"""
insert = self.create_insert_statement()
query = f"""
CREATE PROCEDURE [UPSERT_{self.table_name}]
AS
MERGE {self.schema}.{self.table_name} t

merge_stmt = f"""MERGE {self.schema}.{self.table_name} t
USING staging.{self.table_name} s
ON {self.create_on_statement()}
WHEN MATCHED
THEN UPDATE SET
{self.create_update_statement()}
WHEN NOT MATCHED BY TARGET
THEN INSERT {insert[0]}
VALUES {insert[1]};
VALUES {insert[1]};"""

if self.preserve_identity:
# Wrap with IDENTITY_INSERT ON/OFF
query = f"""
CREATE PROCEDURE [UPSERT_{self.table_name}]
AS
SET IDENTITY_INSERT {self.schema}.{self.table_name} ON;
{merge_stmt}
SET IDENTITY_INSERT {self.schema}.{self.table_name} OFF;
"""
else:
query = f"""
CREATE PROCEDURE [UPSERT_{self.table_name}]
AS
{merge_stmt}
"""
logging.debug(query)

logging.debug(query)
return text(query)

def drop_procedure(self):
    """Return a statement dropping the table's upsert stored procedure, if present."""
    return text(f"DROP PROCEDURE IF EXISTS [UPSERT_{self.table_name}];")

def create_stored_procedure(self):
"""
Create the stored procedure for upsert operation.

This method first validates that IDENTITY columns are being used correctly,
then creates the stored procedure with the MERGE statement.

Raises
------
UpsertError
If IDENTITY columns are used incorrectly or if procedure creation fails.
"""
# Validate IDENTITY usage BEFORE creating procedure
self.validate_identity_usage()

with auth_azure() as con:
t = con.begin()
query_drop_procedure = self.drop_procedure()
Expand Down
5 changes: 5 additions & 0 deletions df_to_azure/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def df_to_azure(
parquet=False,
clean_staging=True,
container_name="parquet",
preserve_identity=False,
):
if parquet:
DfToParquet(
Expand All @@ -57,6 +58,7 @@ def df_to_azure(
create=create,
dtypes=dtypes,
clean_staging=clean_staging,
preserve_identity=preserve_identity,
).run()

return adf_client, run_response
Expand All @@ -77,6 +79,7 @@ def __init__(
create: bool = False,
dtypes: dict = None,
clean_staging: bool = True,
preserve_identity: bool = False,
):
super().__init__(
df=df,
Expand All @@ -92,6 +95,7 @@ def __init__(
self.decimal_precision = decimal_precision
self.dtypes = dtypes
self.clean_staging = clean_staging
self.preserve_identity = preserve_identity

def run(self):
if self.df.empty:
Expand Down Expand Up @@ -144,6 +148,7 @@ def upload_dataset(self):
schema=self.schema,
id_cols=self.id_field,
columns=self.df.columns,
preserve_identity=self.preserve_identity,
)
upsert.create_stored_procedure()
self.schema = "staging"
Expand Down
72 changes: 72 additions & 0 deletions df_to_azure/tests/test_identity_insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from pandas import DataFrame, read_sql_table
from pandas._testing import assert_frame_equal

from df_to_azure import df_to_azure
from df_to_azure.db import auth_azure, execute_stmt
from df_to_azure.exceptions import UpsertError

SCHEMA = "test"


def reset_identity_table(table_name: str) -> None:
    """Drop (if present) and recreate the test table with an IDENTITY primary key."""
    ddl = f"""
    IF OBJECT_ID('{SCHEMA}.{table_name}', 'U') IS NOT NULL
        DROP TABLE [{SCHEMA}].[{table_name}];
    CREATE TABLE [{SCHEMA}].[{table_name}](
        [id] INT IDENTITY(1,1) NOT NULL PRIMARY KEY,
        [value] NVARCHAR(255) NOT NULL
    );
    """
    execute_stmt(ddl)


def insert_values(table_name: str, values: list[str]) -> None:
    """Insert each value into [value], letting the IDENTITY column assign ids."""
    for value in values:
        # Double single quotes so the value is a valid T-SQL string literal.
        escaped = value.replace("'", "''")
        stmt = f"INSERT INTO [{SCHEMA}].[{table_name}] ([value]) VALUES ('{escaped}')"
        execute_stmt(stmt)


def test_upsert_identity_column_requires_preserve():
    """Upserting on an IDENTITY id_field without preserve_identity must fail loudly."""
    table_name = "identity_no_preserve"
    reset_identity_table(table_name)
    insert_values(table_name, ["original value"])

    frame = DataFrame({"id": [1], "value": ["updated value"]})

    with pytest.raises(UpsertError) as err:
        df_to_azure(
            df=frame,
            tablename=table_name,
            schema=SCHEMA,
            method="upsert",
            id_field="id",
            wait_till_finished=True,
        )

    # The error message should clearly name the offending IDENTITY column.
    assert "Column 'id' is an auto-increment (IDENTITY) column" in str(err.value)


def test_upsert_identity_column_with_preserve_identity():
    """With preserve_identity=True, explicit IDENTITY values are updated/inserted."""
    table_name = "identity_with_preserve"
    reset_identity_table(table_name)
    insert_values(table_name, ["original value"])

    frame = DataFrame({"id": [1, 10], "value": ["updated value", "migrated value"]})

    df_to_azure(
        df=frame,
        tablename=table_name,
        schema=SCHEMA,
        method="upsert",
        id_field="id",
        wait_till_finished=True,
        preserve_identity=True,
    )

    # Read the table back and compare against the expected post-upsert state:
    # id=1 updated in place, id=10 inserted with its explicit IDENTITY value.
    with auth_azure() as con:
        result = read_sql_table(table_name=table_name, con=con, schema=SCHEMA).sort_values("id")

    expected = DataFrame({"id": [1, 10], "value": ["updated value", "migrated value"]})
    assert_frame_equal(expected, result.reset_index(drop=True))