diff --git a/sqlalchemy_batch_inserts/__init__.py b/sqlalchemy_batch_inserts/__init__.py index e3e7289..384ee94 100644 --- a/sqlalchemy_batch_inserts/__init__.py +++ b/sqlalchemy_batch_inserts/__init__.py @@ -4,7 +4,7 @@ import sqlalchemy -__version__ = "0.0.4" +__version__ = "0.0.5" def _group_models_by_base_mapper(initial_models): @@ -19,7 +19,7 @@ def _get_column_python_type(column): def _has_normal_id_primary_key(base_mapper): - """Check if the primary key for base_mapper is an auto-incrementing integer `id` column""" + """Check if the primary key for base_mapper is an auto-incrementing integer column""" primary_key_cols = base_mapper.primary_key if len(primary_key_cols) != 1: return False @@ -32,16 +32,23 @@ def _has_normal_id_primary_key(base_mapper): python_column_type = None return ( - primary_key_col.name == "id" - and python_column_type == int - and primary_key_col.autoincrement in ("auto", True) - and primary_key_col.table == base_mapper.local_table + python_column_type == int + and primary_key_col.autoincrement in ("auto", True) + and primary_key_col.table == base_mapper.local_table ) +def _get_primary_key_name(base_mapper): + primary_key_cols = base_mapper.primary_key + [primary_key_col] = primary_key_cols + return primary_key_col.name + + def _get_id_sequence_name(base_mapper): - assert _has_normal_id_primary_key(base_mapper), "_get_id_sequence_name only supports id primary keys" - return "%s_id_seq" % base_mapper.entity.__tablename__ + assert _has_normal_id_primary_key(base_mapper), \ + "_get_id_sequence_name only supports single auto increment primary keys" + primary_key = _get_primary_key_name(base_mapper) + return f"{base_mapper.entity.__tablename__}_{primary_key}_seq" def tuples_to_scalar_list(tuples): @@ -153,11 +160,13 @@ def batch_populate_primary_keys( if skip_unsupported_models: continue else: - raise AssertionError("Expected models to have auto-incrementing `id` primary key") + raise AssertionError("Expected models to have auto-incrementing primary key") + + primary_key = _get_primary_key_name(base_mapper) # In general, batch_populate_primary_keys shouldn't assume anything about how people are creating # models - it is possible for models to have their ids already specified. - models = [model for model in models if model.id is None] + models = [model for model in models if getattr(model, primary_key) is None] if skip_if_single_model and len(models) <= 1: continue @@ -166,7 +175,7 @@ def batch_populate_primary_keys( models = sorted(models, key=lambda model: sqlalchemy.inspect(model).insert_order) for id_, model in zip(new_ids, models): - model.id = id_ + setattr(model, primary_key, id_) def enable_batch_inserting(sqla_session):