Skip to content

Commit b57020d

Browse files
committed
feat: Add a more ergonomic API for enabling auditing on models.
1 parent e4f9d94 commit b57020d

File tree

6 files changed

+133
-18
lines changed

6 files changed

+133
-18
lines changed

src/sqlalchemy_postgresql_audit/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"uninstall_audit_triggers",
88
]
99

10-
from .session import set_session_vars
11-
from .plugin import enable
10+
from .api import audit_model
1211
from .ddl import install_audit_triggers, uninstall_audit_triggers
12+
from .plugin import enable
13+
from .session import set_session_vars
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import functools
2+
import uuid
3+
from sqlalchemy.dialects.postgresql import UUID
4+
from sqlalchemy import Column, text
5+
6+
try:
7+
from sqlalchemy.orm.decl_api import DeclarativeMeta
8+
except ImportError:
9+
from sqlalchemy.ext.declarative.api import DeclarativeMeta
10+
11+
from sqlalchemy_postgresql_audit.event_listeners.sqlalchemy import create_audit_table
12+
13+
14+
def audit_model(_func=None, *, enabled=True, **spec):
15+
def decorated(model_cls):
16+
create_audit_model(model_cls, enabled=enabled, **spec)
17+
18+
return model_cls
19+
20+
if _func is None:
21+
return decorated
22+
return decorated(_func)
23+
24+
25+
def create_audit_model(model_cls, *, enabled=True, primary_key=None, **spec):
26+
table = model_cls.__table__
27+
metadata = model_cls.metadata
28+
29+
table = create_audit_table(table, metadata, enabled=enabled, primary_key=primary_key, **spec)
30+
model_base = _find_model_base(model_cls)
31+
32+
cls = type(
33+
'{model_cls}Audit'.format(model_cls=model_cls.__name__),
34+
(model_base,),
35+
{'__table__': table},
36+
)
37+
38+
return cls
39+
40+
41+
def create_audit_table(table, metadata, *, enabled=True, primary_key=None, **spec):
42+
existing_info = table.info
43+
existing_info['audit.options'] = {'enabled': enabled, **spec}
44+
45+
if primary_key is None:
46+
primary_key = Column(
47+
'audit_uuid',
48+
UUID(as_uuid=True),
49+
primary_key=True,
50+
default=uuid.uuid4,
51+
server_default=text("uuid_generate_v4()"),
52+
)
53+
54+
return create_audit_table(table, metadata, primary_key=primary_key)
55+
56+
57+
def _find_model_base(model_cls):
58+
for cls in model_cls.__mro__:
59+
if isinstance(cls, DeclarativeMeta) and not hasattr(cls, '__mapper__'):
60+
return cls
61+
62+
raise ValueError("Invalid model, does not subclass a `DeclarativeMeta`.")

src/sqlalchemy_postgresql_audit/ddl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def get_create_trigger_ddl(
5656
# or it is one of our session settings values
5757
else:
5858
if col.name in (
59+
"audit_uuid",
5960
"audit_operation",
6061
"audit_operation_timestamp",
6162
"audit_current_user",
Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import threading
22

3-
from sqlalchemy import Table
4-
from sqlalchemy.events import event
3+
from sqlalchemy import Table, event
54

65
_event_listeners_enabled = False
76

@@ -17,20 +16,19 @@ def enable_event_listeners():
1716

1817

1918
def _enable_sqlalchemy_event_listeners():
20-
from sqlalchemy_postgresql_audit.event_listeners.sqlalchemy import (
21-
create_audit_table,
22-
)
19+
from sqlalchemy_postgresql_audit.event_listeners.sqlalchemy import \
20+
create_audit_table
2321

2422
event.listens_for(Table, "after_parent_attach")(create_audit_table)
2523

2624

2725
def _enable_alembic_event_listeners():
2826
try:
29-
from sqlalchemy_postgresql_audit.event_listeners.alembic import (
30-
compare_for_table,
31-
)
3227
from alembic.autogenerate.compare import comparators
3328

29+
from sqlalchemy_postgresql_audit.event_listeners.alembic import \
30+
compare_for_table
31+
3432
comparators.dispatch_for("table")(compare_for_table)
3533
except ImportError:
3634
pass

src/sqlalchemy_postgresql_audit/event_listeners/sqlalchemy.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919

2020

21-
def create_audit_table(target, parent):
21+
def create_audit_table(target, parent, primary_key=None):
2222
"""Create an audit table and generate procedure/trigger DDL.
2323
2424
Naming conventions can be defined for a few of the named elements:
@@ -83,21 +83,29 @@ def create_audit_table(target, parent):
8383
"schema": audit_spec["schema"] or "public",
8484
}
8585

86-
columns = [
87-
Column(col.name, col.type, nullable=True) for col in target.columns.values()
88-
]
86+
column_elements = []
87+
if primary_key is not None:
88+
column_elements.append(primary_key)
89+
90+
column_elements.extend([
91+
Column("audit_operation", String(1), nullable=False),
92+
Column("audit_operation_timestamp", DateTime, nullable=False),
93+
Column("audit_current_user", String(64), nullable=False),
94+
])
95+
8996
session_setting_columns = [col.copy() for col in audit_spec["session_settings"]]
9097
for col in session_setting_columns:
9198
col.name = "audit_{}".format(col.name)
99+
column_elements.extend(session_setting_columns)
92100

93-
column_elements = session_setting_columns + columns
101+
table_columns = [
102+
Column(col.name, col.type, nullable=True) for col in target.columns.values()
103+
]
104+
column_elements.extend(table_columns)
94105

95106
audit_table = Table(
96107
audit_table_name,
97108
target.metadata,
98-
Column("audit_operation", String(1), nullable=False),
99-
Column("audit_operation_timestamp", DateTime, nullable=False),
100-
Column("audit_current_user", String(64), nullable=False),
101109
*column_elements,
102110
schema=audit_spec["schema"]
103111
)
@@ -119,3 +127,4 @@ def create_audit_table(target, parent):
119127

120128
audit_table.info["audit.is_audit_table"] = True
121129
target.info["audit.is_audited"] = True
130+
return audit_table
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from sqlalchemy import Column, Integer, MetaData, Table
2+
from sqlalchemy.ext.declarative import declarative_base
3+
4+
import sqlalchemy_postgresql_audit
5+
6+
7+
def setup():
8+
sqlalchemy_postgresql_audit.enable()
9+
10+
11+
def test_vanilla_model():
12+
Base = declarative_base()
13+
metadata = Base.metadata
14+
15+
@sqlalchemy_postgresql_audit.audit_model
16+
class Model(Base):
17+
__tablename__ = 'foo'
18+
19+
id = Column("id", Integer, primary_key=True)
20+
21+
audit_table = metadata.tables["foo_audit"]
22+
23+
assert audit_table.info["audit.is_audit_table"]
24+
assert Model.__table__.info["audit.is_audited"]
25+
26+
27+
def test_model_with_info():
28+
Base = declarative_base()
29+
metadata = Base.metadata
30+
31+
@sqlalchemy_postgresql_audit.audit_model
32+
class Model(Base):
33+
__tablename__ = 'foo'
34+
__table_args__ = {
35+
'info': {'example': 4}
36+
}
37+
38+
id = Column("id", Integer, primary_key=True)
39+
40+
audit_table = metadata.tables["foo_audit"]
41+
42+
assert audit_table.info["audit.is_audit_table"]
43+
assert Model.__table__.info["audit.is_audited"]
44+
assert 'example' in Model.__table__.info

0 commit comments

Comments
 (0)