@@ -26,7 +26,7 @@ def __init__(self, model: Type[Model]):
26
26
self .model = model
27
27
self .primary_key = self ._get_primary_key ()
28
28
29
- def _get_primary_key (self ) -> Column :
29
+ def _get_primary_key (self ) -> Column | list [ Column ] :
30
30
"""
31
31
Dynamically retrieve the primary key column(s) for the model.
32
32
"""
@@ -35,7 +35,21 @@ def _get_primary_key(self) -> Column:
35
35
if len (primary_key ) == 1 :
36
36
return primary_key [0 ]
37
37
else :
38
- raise CompositePrimaryKeysError ('Composite primary keys are not supported' )
38
+ return list (primary_key )
39
+
40
+ def _get_pk_filter (self , pk : Any | Sequence [Any ]) -> list [bool ]:
41
+ """
42
+ Get the primary key filter(s).
43
+
44
+ :param pk: Single value for simple primary key, or tuple for composite primary key.
45
+ :return:
46
+ """
47
+ if isinstance (self .primary_key , list ):
48
+ if len (pk ) != len (self .primary_key ):
49
+ raise CompositePrimaryKeysError (f'Expected { len (self .primary_key )} values for composite primary key' )
50
+ return [column == value for column , value in zip (self .primary_key , pk )]
51
+ else :
52
+ return [self .primary_key == pk ]
39
53
40
54
async def create_model (
41
55
self ,
@@ -55,10 +69,7 @@ async def create_model(
55
69
:param kwargs: Additional model data not included in the pydantic schema.
56
70
:return:
57
71
"""
58
- if not kwargs :
59
- ins = self .model (** obj .model_dump ())
60
- else :
61
- ins = self .model (** obj .model_dump (), ** kwargs )
72
+ ins = self .model (** obj .model_dump ()) if not kwargs else self .model (** obj .model_dump (), ** kwargs )
62
73
63
74
session .add (ins )
64
75
@@ -89,10 +100,7 @@ async def create_models(
89
100
"""
90
101
ins_list = []
91
102
for obj in objs :
92
- if not kwargs :
93
- ins = self .model (** obj .model_dump ())
94
- else :
95
- ins = self .model (** obj .model_dump (), ** kwargs )
103
+ ins = self .model (** obj .model_dump ()) if not kwargs else self .model (** obj .model_dump (), ** kwargs )
96
104
ins_list .append (ins )
97
105
98
106
session .add_all (ins_list )
@@ -118,12 +126,12 @@ async def count(
118
126
:param kwargs: Query expressions.
119
127
:return:
120
128
"""
121
- filter_list = list (whereclause )
129
+ filters = list (whereclause )
122
130
123
131
if kwargs :
124
- filter_list .extend (parse_filters (self .model , ** kwargs ))
132
+ filters .extend (parse_filters (self .model , ** kwargs ))
125
133
126
- stmt = select (func .count ()).select_from (self .model ).where (* filter_list )
134
+ stmt = select (func .count ()).select_from (self .model ).where (* filters )
127
135
query = await session .execute (stmt )
128
136
total_count = query .scalar ()
129
137
return total_count if total_count is not None else 0
@@ -154,21 +162,20 @@ async def exists(
154
162
async def select_model (
155
163
self ,
156
164
session : AsyncSession ,
157
- pk : int ,
165
+ pk : Any | Sequence [ Any ] ,
158
166
* whereclause : ColumnExpressionArgument [bool ],
159
167
) -> Model | None :
160
168
"""
161
- Query by ID
169
+ Query by primary key(s)
162
170
163
171
:param session: The SQLAlchemy async session.
164
- :param pk: The database primary key value .
172
+ :param pk: Single value for simple primary key, or tuple for composite primary key .
165
173
:param whereclause: The WHERE clauses to apply to the query.
166
174
:return:
167
175
"""
168
- filter_list = list (whereclause )
169
- _filters = [self .primary_key == pk ]
170
- _filters .extend (filter_list )
171
- stmt = select (self .model ).where (* _filters )
176
+ filters = self ._get_pk_filter (pk )
177
+ filters + list (whereclause )
178
+ stmt = select (self .model ).where (* filters )
172
179
query = await session .execute (stmt )
173
180
return query .scalars ().first ()
174
181
@@ -186,10 +193,8 @@ async def select_model_by_column(
186
193
:param kwargs: Query expressions.
187
194
:return:
188
195
"""
189
- filter_list = list (whereclause )
190
- _filters = parse_filters (self .model , ** kwargs )
191
- _filters .extend (filter_list )
192
- stmt = select (self .model ).where (* _filters )
196
+ filters = parse_filters (self .model , ** kwargs ) + list (whereclause )
197
+ stmt = select (self .model ).where (* filters )
193
198
query = await session .execute (stmt )
194
199
return query .scalars ().first ()
195
200
@@ -201,10 +206,8 @@ async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) -
201
206
:param kwargs: Query expressions.
202
207
:return:
203
208
"""
204
- filter_list = list (whereclause )
205
- _filters = parse_filters (self .model , ** kwargs )
206
- _filters .extend (filter_list )
207
- stmt = select (self .model ).where (* _filters )
209
+ filters = parse_filters (self .model , ** kwargs ) + list (whereclause )
210
+ stmt = select (self .model ).where (* filters )
208
211
return stmt
209
212
210
213
async def select_order (
@@ -270,7 +273,7 @@ async def select_models_order(
270
273
async def update_model (
271
274
self ,
272
275
session : AsyncSession ,
273
- pk : int ,
276
+ pk : Any | Sequence [ Any ] ,
274
277
obj : UpdateSchema | dict [str , Any ],
275
278
flush : bool = False ,
276
279
commit : bool = False ,
@@ -280,21 +283,17 @@ async def update_model(
280
283
Update an instance by model's primary key
281
284
282
285
:param session: The SQLAlchemy async session.
283
- :param pk: The database primary key value .
286
+ :param pk: Single value for simple primary key, or tuple for composite primary key .
284
287
:param obj: A pydantic schema or dictionary containing the update data
285
288
:param flush: If `True`, flush all object changes to the database. Default is `False`.
286
289
:param commit: If `True`, commits the transaction immediately. Default is `False`.
287
290
:param kwargs: Additional model data not included in the pydantic schema.
288
291
:return:
289
292
"""
290
- if isinstance (obj , dict ):
291
- instance_data = obj
292
- else :
293
- instance_data = obj .model_dump (exclude_unset = True )
294
- if kwargs :
295
- instance_data .update (kwargs )
296
-
297
- stmt = update (self .model ).where (self .primary_key == pk ).values (** instance_data )
293
+ filters = self ._get_pk_filter (pk )
294
+ instance_data = obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True )
295
+ instance_data .update (kwargs )
296
+ stmt = update (self .model ).where (* filters ).values (** instance_data )
298
297
result = await session .execute (stmt )
299
298
300
299
if flush :
@@ -325,15 +324,13 @@ async def update_model_by_column(
325
324
:return:
326
325
"""
327
326
filters = parse_filters (self .model , ** kwargs )
327
+
328
328
total_count = await self .count (session , * filters )
329
329
if not allow_multiple and total_count > 1 :
330
330
raise MultipleResultsError (f'Only one record is expected to be update, found { total_count } records.' )
331
- if isinstance (obj , dict ):
332
- instance_data = obj
333
- else :
334
- instance_data = obj .model_dump (exclude_unset = True )
335
331
336
- stmt = update (self .model ).where (* filters ).values (** instance_data ) # type: ignore
332
+ instance_data = obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True )
333
+ stmt = update (self .model ).where (* filters ).values (** instance_data )
337
334
result = await session .execute (stmt )
338
335
339
336
if flush :
@@ -346,20 +343,22 @@ async def update_model_by_column(
346
343
async def delete_model (
347
344
self ,
348
345
session : AsyncSession ,
349
- pk : int ,
346
+ pk : Any | Sequence [ Any ] ,
350
347
flush : bool = False ,
351
348
commit : bool = False ,
352
349
) -> int :
353
350
"""
354
351
Delete an instance by model's primary key
355
352
356
353
:param session: The SQLAlchemy async session.
357
- :param pk: The database primary key value .
354
+ :param pk: Single value for simple primary key, or tuple for composite primary key .
358
355
:param flush: If `True`, flush all object changes to the database. Default is `False`.
359
356
:param commit: If `True`, commits the transaction immediately. Default is `False`.
360
357
:return:
361
358
"""
362
- stmt = delete (self .model ).where (self .primary_key == pk )
359
+ filters = self ._get_pk_filter (pk )
360
+
361
+ stmt = delete (self .model ).where (* filters )
363
362
result = await session .execute (stmt )
364
363
365
364
if flush :
@@ -392,14 +391,16 @@ async def delete_model_by_column(
392
391
:return:
393
392
"""
394
393
filters = parse_filters (self .model , ** kwargs )
394
+
395
395
total_count = await self .count (session , * filters )
396
396
if not allow_multiple and total_count > 1 :
397
397
raise MultipleResultsError (f'Only one record is expected to be delete, found { total_count } records.' )
398
- if logical_deletion :
399
- deleted_flag = {deleted_flag_column : True }
400
- stmt = update (self .model ).where (* filters ).values (** deleted_flag )
401
- else :
402
- stmt = delete (self .model ).where (* filters )
398
+
399
+ stmt = (
400
+ update (self .model ).where (* filters ).values (** {deleted_flag_column : True })
401
+ if logical_deletion
402
+ else delete (self .model ).where (* filters )
403
+ )
403
404
404
405
result = await session .execute (stmt )
405
406
0 commit comments