Skip to content

Commit e116c94

Browse files
authored
Add composite primary key support (#44)
* Add composite primary key support * Update docs
1 parent 7d6b511 commit e116c94

13 files changed

+176
-81
lines changed

docs/advanced/primary_key.md

+16-8
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,28 @@
22

33
由于在 python 内部 `id` 为关键字,因此,我们设定默认主键入参为 `pk`。这仅用于函数入参,并不要求模型主键必须定义为 `pk`
44

5-
```py title="e.g." hl_lines="2"
6-
async def delete(self, db: AsyncSession, primary_key: int) -> int:
7-
return self.delete_model(db, pk=primary_key)
8-
```
9-
10-
## 主键定义
11-
125
!!! tip 自动主键
136

147
我们在 SQLAlchemy CRUD Plus 内部通过 [inspect()](https://docs.sqlalchemy.org/en/20/core/inspection.html) 自动搜索表主键,
158
而非强制绑定主键列必须命名为 `id`
169

17-
```py title="e.g." hl_lines="4"
10+
## 单个主键
11+
12+
```py title="e.g."
1813
class ModelIns(Base):
1914
# define primary_key
2015
primary_key: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)
16+
17+
18+
class ModelIns2(Base):
19+
# define primary_key
20+
primary_key: Mapped[str] = mapped_column(primary_key=True, index=True)
21+
```
22+
23+
## 复合主键
24+
25+
```python title="e.g."
26+
class ModelIns(Base):
27+
primary_key: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)
28+
primary_key2: Mapped[str] = mapped_column(primary_key=True, index=True)
2129
```

docs/usage/delete_model.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,20 @@ class CRUDIns(CRUDPlus[ModelIns]):
2424
async def delete_model(
2525
self,
2626
session: AsyncSession,
27-
pk: int,
27+
pk: Any | Sequence[Any],
2828
flush: bool = False,
2929
commit: bool = False,
3030
) -> int:
3131
```
3232

3333
**Parameters:**
3434

35-
| Name | Type | Description | Default |
36-
|---------|--------------|----------------------------------|---------|
37-
| session | AsyncSession | 数据库会话 | 必填 |
38-
| pk | int | [主键](../advanced/primary_key.md) | 必填 |
39-
| flush | bool | [冲洗](../advanced/flush.md) | `False` |
40-
| commit | bool | [提交](../advanced/commit.md) | `False` |
35+
| Name | Type | Description | Default |
36+
|---------|--------------------------|----------------------------------|---------|
37+
| session | AsyncSession | 数据库会话 | 必填 |
38+
| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 |
39+
| flush | bool | [冲洗](../advanced/flush.md) | `False` |
40+
| commit | bool | [提交](../advanced/commit.md) | `False` |
4141

4242
**Returns:**
4343

docs/usage/select_model.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,17 @@ class CRUDIns(CRUDPlus[ModelIns]):
2424
async def select_model(
2525
self,
2626
session: AsyncSession,
27-
pk: int,
27+
pk: Any | Sequence[Any],
2828
*whereclause: ColumnExpressionArgument[bool],
2929
) -> Model | None:
3030
```
3131

3232
**Parameters:**
3333

34-
| Name | Type | Description | Default |
35-
|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------|
36-
| session | AsyncSession | 数据库会话 | 必填 |
37-
| pk | int | [主键](../advanced/primary_key.md) | 必填 |
34+
| Name | Type | Description | Default |
35+
|--------------|----------------------------------|-----------------------------------------------------------------------------------------------------|---------|
36+
| session | AsyncSession | 数据库会话 | 必填 |
37+
| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 |
3838
| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | |
3939

4040
**Returns:**

docs/usage/update_model.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
3131
async def update_model(
3232
self,
3333
session: AsyncSession,
34-
pk: int,
34+
pk: Any | Sequence[Any],
3535
obj: UpdateSchema | dict[str, Any],
3636
flush: bool = False,
3737
commit: bool = False,
@@ -44,7 +44,7 @@ async def update_model(
4444
| Name | Type | Description | Default |
4545
|---------|-------------------------------|----------------------------------|---------|
4646
| session | AsyncSession | 数据库会话 | 必填 |
47-
| pk | int | [主键](../advanced/primary_key.md) | 必填 |
47+
| pk | `Any `\| `Sequence[Any]` | [主键](../advanced/primary_key.md) | 必填 |
4848
| obj | `TypeVar `\|` dict[str, Any]` | 更新数据参数 | 必填 |
4949
| flush | bool | [冲洗](../advanced/flush.md) | `False` |
5050
| commit | bool | [提交](../advanced/commit.md) | `False` |

sqlalchemy_crud_plus/crud.py

+52-51
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, model: Type[Model]):
2626
self.model = model
2727
self.primary_key = self._get_primary_key()
2828

29-
def _get_primary_key(self) -> Column:
29+
def _get_primary_key(self) -> Column | list[Column]:
3030
"""
3131
Dynamically retrieve the primary key column(s) for the model.
3232
"""
@@ -35,7 +35,21 @@ def _get_primary_key(self) -> Column:
3535
if len(primary_key) == 1:
3636
return primary_key[0]
3737
else:
38-
raise CompositePrimaryKeysError('Composite primary keys are not supported')
38+
return list(primary_key)
39+
40+
def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[bool]:
41+
"""
42+
Get the primary key filter(s).
43+
44+
:param pk: Single value for simple primary key, or tuple for composite primary key.
45+
:return:
46+
"""
47+
if isinstance(self.primary_key, list):
48+
if len(pk) != len(self.primary_key):
49+
raise CompositePrimaryKeysError(f'Expected {len(self.primary_key)} values for composite primary key')
50+
return [column == value for column, value in zip(self.primary_key, pk)]
51+
else:
52+
return [self.primary_key == pk]
3953

4054
async def create_model(
4155
self,
@@ -55,10 +69,7 @@ async def create_model(
5569
:param kwargs: Additional model data not included in the pydantic schema.
5670
:return:
5771
"""
58-
if not kwargs:
59-
ins = self.model(**obj.model_dump())
60-
else:
61-
ins = self.model(**obj.model_dump(), **kwargs)
72+
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
6273

6374
session.add(ins)
6475

@@ -89,10 +100,7 @@ async def create_models(
89100
"""
90101
ins_list = []
91102
for obj in objs:
92-
if not kwargs:
93-
ins = self.model(**obj.model_dump())
94-
else:
95-
ins = self.model(**obj.model_dump(), **kwargs)
103+
ins = self.model(**obj.model_dump()) if not kwargs else self.model(**obj.model_dump(), **kwargs)
96104
ins_list.append(ins)
97105

98106
session.add_all(ins_list)
@@ -118,12 +126,12 @@ async def count(
118126
:param kwargs: Query expressions.
119127
:return:
120128
"""
121-
filter_list = list(whereclause)
129+
filters = list(whereclause)
122130

123131
if kwargs:
124-
filter_list.extend(parse_filters(self.model, **kwargs))
132+
filters.extend(parse_filters(self.model, **kwargs))
125133

126-
stmt = select(func.count()).select_from(self.model).where(*filter_list)
134+
stmt = select(func.count()).select_from(self.model).where(*filters)
127135
query = await session.execute(stmt)
128136
total_count = query.scalar()
129137
return total_count if total_count is not None else 0
@@ -154,21 +162,20 @@ async def exists(
154162
async def select_model(
155163
self,
156164
session: AsyncSession,
157-
pk: int,
165+
pk: Any | Sequence[Any],
158166
*whereclause: ColumnExpressionArgument[bool],
159167
) -> Model | None:
160168
"""
161-
Query by ID
169+
Query by primary key(s)
162170
163171
:param session: The SQLAlchemy async session.
164-
:param pk: The database primary key value.
172+
:param pk: Single value for simple primary key, or tuple for composite primary key.
165173
:param whereclause: The WHERE clauses to apply to the query.
166174
:return:
167175
"""
168-
filter_list = list(whereclause)
169-
_filters = [self.primary_key == pk]
170-
_filters.extend(filter_list)
171-
stmt = select(self.model).where(*_filters)
176+
filters = self._get_pk_filter(pk)
177+
filters + list(whereclause)
178+
stmt = select(self.model).where(*filters)
172179
query = await session.execute(stmt)
173180
return query.scalars().first()
174181

@@ -186,10 +193,8 @@ async def select_model_by_column(
186193
:param kwargs: Query expressions.
187194
:return:
188195
"""
189-
filter_list = list(whereclause)
190-
_filters = parse_filters(self.model, **kwargs)
191-
_filters.extend(filter_list)
192-
stmt = select(self.model).where(*_filters)
196+
filters = parse_filters(self.model, **kwargs) + list(whereclause)
197+
stmt = select(self.model).where(*filters)
193198
query = await session.execute(stmt)
194199
return query.scalars().first()
195200

@@ -201,10 +206,8 @@ async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) -
201206
:param kwargs: Query expressions.
202207
:return:
203208
"""
204-
filter_list = list(whereclause)
205-
_filters = parse_filters(self.model, **kwargs)
206-
_filters.extend(filter_list)
207-
stmt = select(self.model).where(*_filters)
209+
filters = parse_filters(self.model, **kwargs) + list(whereclause)
210+
stmt = select(self.model).where(*filters)
208211
return stmt
209212

210213
async def select_order(
@@ -270,7 +273,7 @@ async def select_models_order(
270273
async def update_model(
271274
self,
272275
session: AsyncSession,
273-
pk: int,
276+
pk: Any | Sequence[Any],
274277
obj: UpdateSchema | dict[str, Any],
275278
flush: bool = False,
276279
commit: bool = False,
@@ -280,21 +283,17 @@ async def update_model(
280283
Update an instance by model's primary key
281284
282285
:param session: The SQLAlchemy async session.
283-
:param pk: The database primary key value.
286+
:param pk: Single value for simple primary key, or tuple for composite primary key.
284287
:param obj: A pydantic schema or dictionary containing the update data
285288
:param flush: If `True`, flush all object changes to the database. Default is `False`.
286289
:param commit: If `True`, commits the transaction immediately. Default is `False`.
287290
:param kwargs: Additional model data not included in the pydantic schema.
288291
:return:
289292
"""
290-
if isinstance(obj, dict):
291-
instance_data = obj
292-
else:
293-
instance_data = obj.model_dump(exclude_unset=True)
294-
if kwargs:
295-
instance_data.update(kwargs)
296-
297-
stmt = update(self.model).where(self.primary_key == pk).values(**instance_data)
293+
filters = self._get_pk_filter(pk)
294+
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
295+
instance_data.update(kwargs)
296+
stmt = update(self.model).where(*filters).values(**instance_data)
298297
result = await session.execute(stmt)
299298

300299
if flush:
@@ -325,15 +324,13 @@ async def update_model_by_column(
325324
:return:
326325
"""
327326
filters = parse_filters(self.model, **kwargs)
327+
328328
total_count = await self.count(session, *filters)
329329
if not allow_multiple and total_count > 1:
330330
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
331-
if isinstance(obj, dict):
332-
instance_data = obj
333-
else:
334-
instance_data = obj.model_dump(exclude_unset=True)
335331

336-
stmt = update(self.model).where(*filters).values(**instance_data) # type: ignore
332+
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
333+
stmt = update(self.model).where(*filters).values(**instance_data)
337334
result = await session.execute(stmt)
338335

339336
if flush:
@@ -346,20 +343,22 @@ async def update_model_by_column(
346343
async def delete_model(
347344
self,
348345
session: AsyncSession,
349-
pk: int,
346+
pk: Any | Sequence[Any],
350347
flush: bool = False,
351348
commit: bool = False,
352349
) -> int:
353350
"""
354351
Delete an instance by model's primary key
355352
356353
:param session: The SQLAlchemy async session.
357-
:param pk: The database primary key value.
354+
:param pk: Single value for simple primary key, or tuple for composite primary key.
358355
:param flush: If `True`, flush all object changes to the database. Default is `False`.
359356
:param commit: If `True`, commits the transaction immediately. Default is `False`.
360357
:return:
361358
"""
362-
stmt = delete(self.model).where(self.primary_key == pk)
359+
filters = self._get_pk_filter(pk)
360+
361+
stmt = delete(self.model).where(*filters)
363362
result = await session.execute(stmt)
364363

365364
if flush:
@@ -392,14 +391,16 @@ async def delete_model_by_column(
392391
:return:
393392
"""
394393
filters = parse_filters(self.model, **kwargs)
394+
395395
total_count = await self.count(session, *filters)
396396
if not allow_multiple and total_count > 1:
397397
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
398-
if logical_deletion:
399-
deleted_flag = {deleted_flag_column: True}
400-
stmt = update(self.model).where(*filters).values(**deleted_flag)
401-
else:
402-
stmt = delete(self.model).where(*filters)
398+
399+
stmt = (
400+
update(self.model).where(*filters).values(**{deleted_flag_column: True})
401+
if logical_deletion
402+
else delete(self.model).where(*filters)
403+
)
403404

404405
result = await session.execute(stmt)
405406

tests/conftest.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
77

8-
from tests.model import Base, Ins
8+
from tests.model import Base, Ins, InsPks
99

1010
_async_engine = create_async_engine('sqlite+aiosqlite:///:memory:', future=True)
1111
_async_session = async_sessionmaker(_async_engine, autoflush=False, expire_on_commit=False)
@@ -29,3 +29,12 @@ async def create_test_model():
2929
async with _async_session.begin() as session:
3030
data = [Ins(name=f'name_{i}') for i in range(1, 10)]
3131
session.add_all(data)
32+
33+
34+
@pytest_asyncio.fixture
35+
async def create_test_model_pks():
36+
async with _async_session.begin() as session:
37+
data = [InsPks(id=i, name=f'name_{i}', sex='men') for i in range(1, 5)]
38+
session.add_all(data)
39+
data = [InsPks(id=i, name=f'name_{i}', sex='women') for i in range(6, 10)]
40+
session.add_all(data)

tests/model.py

+11
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,14 @@ class Ins(Base):
2020
del_flag: Mapped[bool] = mapped_column(default=False)
2121
created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now)
2222
updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now)
23+
24+
25+
class InsPks(Base):
26+
__tablename__ = 'ins_pks'
27+
28+
id: Mapped[int] = mapped_column(primary_key=True, index=True)
29+
name: Mapped[str] = mapped_column(String(64))
30+
sex: Mapped[str] = mapped_column(String(16), primary_key=True, index=True)
31+
del_flag: Mapped[bool] = mapped_column(default=False)
32+
created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now)
33+
updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now)

tests/schema.py

+6
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@
55

66
class ModelTest(BaseModel):
77
name: str
8+
9+
10+
class ModelTestPks(BaseModel):
11+
id: int
12+
name: str
13+
sex: str

0 commit comments

Comments
 (0)