|
22 | 22 | "lte": "__le__",
|
23 | 23 | }
|
24 | 24 |
|
25 |
| -MODEL = typing.TypeVar("MODEL", bound="Model") |
26 |
| - |
27 | 25 |
|
28 | 26 | def _update_auto_now_fields(values, fields):
|
29 | 27 | for key, value in fields.items():
|
@@ -468,38 +466,34 @@ async def update(self, **kwargs) -> None:
|
468 | 466 | await self.database.execute(expr)
|
469 | 467 |
|
470 | 468 | async def bulk_update(
|
471 |
| - self, objs: typing.List[MODEL], fields: typing.List[str] |
| 469 | + self, objs: typing.List["Model"], fields: typing.List[str] |
472 | 470 | ) -> None:
|
473 | 471 | fields = {
|
474 | 472 | key: field.validator
|
475 | 473 | for key, field in self.model_cls.fields.items()
|
476 | 474 | if key in fields
|
477 | 475 | }
|
478 | 476 | validator = typesystem.Schema(fields=fields)
|
| 477 | + objs = [ |
| 478 | + { |
| 479 | + key: _convert_value(value) |
| 480 | + for key, value in obj.__dict__.items() |
| 481 | + if key in fields |
| 482 | + } |
| 483 | + for obj in objs |
| 484 | + ] |
479 | 485 | new_objs = [
|
480 |
| - _update_auto_now_fields(validator.validate(value), self.model_cls.fields) |
481 |
| - for value in [ |
482 |
| - { |
483 |
| - key: _convert_value(value) |
484 |
| - for key, value in obj.__dict__.items() |
485 |
| - if key in fields |
486 |
| - } |
487 |
| - for obj in objs |
488 |
| - ] |
| 486 | + _update_auto_now_fields(validator.validate(obj), self.model_cls.fields) |
| 487 | + for obj in objs |
489 | 488 | ]
|
490 |
| - expr = ( |
491 |
| - self.table.update() |
492 |
| - .where( |
493 |
| - getattr(self.table.c, self.pkname) == sqlalchemy.bindparam(self.pkname) |
494 |
| - ) |
495 |
| - .values( |
496 |
| - { |
497 |
| - field: sqlalchemy.bindparam(field) |
498 |
| - for obj in new_objs |
499 |
| - for field in obj.keys() |
500 |
| - } |
501 |
| - ) |
502 |
| - ) |
| 489 | + pk_column = getattr(self.table.c, self.pkname) |
| 490 | + expr = self.table.update().where(pk_column == sqlalchemy.bindparam(self.pkname)) |
| 491 | + kwargs = { |
| 492 | + field: sqlalchemy.bindparam(field) |
| 493 | + for obj in new_objs |
| 494 | + for field in obj.keys() |
| 495 | + } |
| 496 | + expr = expr.values(kwargs) |
503 | 497 | pk_list = [{self.pkname: getattr(obj, self.pkname)} for obj in objs]
|
504 | 498 | joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)]
|
505 | 499 | await self.database.execute_many(str(expr), joined_list)
|
|
0 commit comments