Skip to content

Commit 08a3d1a

Browse files
authored
Merge pull request #83 from bigsassy/master
Ensure bulk_upsert accepts any iterable for rows
2 parents 7cf6850 + c692681 commit 08a3d1a

File tree

3 files changed

+73
-11
lines changed

3 files changed

+73
-11
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@ dist/
2424
# Ignore temporary tox environments
2525
.tox/
2626
.pytest_cache/
27+
28+
# Ignore PyCharm / IntelliJ files
29+
.idea/

psqlextra/manager/manager.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Dict, List, Tuple, Union
1+
from typing import Dict, List, Tuple, Union, Iterable
2+
from itertools import chain
23

34
import django
45

@@ -152,7 +153,7 @@ def bulk_insert(self, rows, return_model=False):
152153
153154
Arguments:
154155
rows:
155-
An array of dictionaries, where each dictionary
156+
An iterable of dictionaries, where each dictionary
156157
describes the fields to insert.
157158
158159
return_model (default: False):
@@ -296,7 +297,7 @@ def upsert_and_get(
296297
def bulk_upsert(
297298
self,
298299
conflict_target: List,
299-
rows: List[Dict],
300+
rows: Iterable[Dict],
300301
index_predicate: str = None,
301302
return_model: bool = False,
302303
):
@@ -322,21 +323,21 @@ def bulk_upsert(
322323
A list of either the dicts of the rows upserted, including the pk or
323324
the models of the rows upserted
324325
"""
325-
326-
if not rows or len(rows) <= 0:
326+
is_empty = lambda r: all([False for _ in r])
327+
if not rows or is_empty(rows):
327328
return []
328329

329330
self.on_conflict(
330331
conflict_target, ConflictAction.UPDATE, index_predicate
331332
)
332333
return self.bulk_insert(rows, return_model)
333334

334-
def _build_insert_compiler(self, rows: List[Dict]):
335+
def _build_insert_compiler(self, rows: Iterable[Dict]):
335336
"""Builds the SQL compiler for a insert query.
336337
337338
Arguments:
338339
rows:
339-
A list of dictionaries, where each entry
340+
An iterable of dictionaries, where each entry
340341
describes a record to insert.
341342
342343
Returns:
@@ -349,8 +350,10 @@ def _build_insert_compiler(self, rows: List[Dict]):
349350
# we need to be certain that each row specifies the exact same
350351
# amount of fields/columns
351352
objs = []
352-
field_count = len(rows[0])
353-
for index, row in enumerate(rows):
353+
rows_iter = iter(rows)
354+
first_row = next(rows_iter)
355+
field_count = len(first_row)
356+
for index, row in enumerate(chain([first_row], rows_iter)):
354357
if field_count != len(row):
355358
raise SuspiciousOperation(
356359
(
@@ -366,7 +369,7 @@ def _build_insert_compiler(self, rows: List[Dict]):
366369
self._for_write = True
367370

368371
# get the fields to be used during update/insert
369-
insert_fields, update_fields = self._get_upsert_fields(rows[0])
372+
insert_fields, update_fields = self._get_upsert_fields(first_row)
370373

371374
# build a normal insert query
372375
query = PostgresInsertQuery(self.model)
@@ -597,7 +600,7 @@ def upsert_and_get(
597600
def bulk_upsert(
598601
self,
599602
conflict_target: List,
600-
rows: List[Dict],
603+
rows: Iterable[Dict],
601604
index_predicate: str = None,
602605
return_model: bool = False,
603606
):

tests/test_upsert.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,59 @@ def test_bulk_upsert_return_models():
146146
for index, obj in enumerate(objs, 1):
147147
assert isinstance(obj, model)
148148
assert obj.id == index
149+
150+
151+
def test_bulk_upsert_accepts_getitem_iterable():
152+
"""Tests whether an iterable only implementing
153+
the __getitem__ method works correctly."""
154+
155+
class GetItemIterable:
156+
def __init__(self, items):
157+
self.items = items
158+
def __getitem__(self, key):
159+
return self.items[key]
160+
161+
model = get_fake_model(
162+
{
163+
"id": models.BigAutoField(primary_key=True),
164+
"name": models.CharField(max_length=255, unique=True),
165+
}
166+
)
167+
168+
rows = GetItemIterable([dict(name="John Smith"), dict(name="Jane Doe")])
169+
170+
objs = model.objects.bulk_upsert(
171+
conflict_target=["name"], rows=rows, return_model=True
172+
)
173+
174+
for index, obj in enumerate(objs, 1):
175+
assert isinstance(obj, model)
176+
assert obj.id == index
177+
178+
179+
def test_bulk_upsert_accepts_iter_iterable():
180+
"""Tests whether an iterable only implementing
181+
the __iter__ method works correctly."""
182+
183+
class IterIterable:
184+
def __init__(self, items):
185+
self.items = items
186+
def __iter__(self):
187+
return iter(self.items)
188+
189+
model = get_fake_model(
190+
{
191+
"id": models.BigAutoField(primary_key=True),
192+
"name": models.CharField(max_length=255, unique=True),
193+
}
194+
)
195+
196+
rows = IterIterable([dict(name="John Smith"), dict(name="Jane Doe")])
197+
198+
objs = model.objects.bulk_upsert(
199+
conflict_target=["name"], rows=rows, return_model=True
200+
)
201+
202+
for index, obj in enumerate(objs, 1):
203+
assert isinstance(obj, model)
204+
assert obj.id == index

0 commit comments

Comments
 (0)