1- from typing import Dict , List , Tuple , Union
1+ from typing import Dict , List , Tuple , Union , Iterable
2+ from itertools import chain
23
34import django
45
@@ -152,7 +153,7 @@ def bulk_insert(self, rows, return_model=False):
152153
153154 Arguments:
154155 rows:
155- An array of dictionaries, where each dictionary
156+ An iterable of dictionaries, where each dictionary
156157 describes the fields to insert.
157158
158159 return_model (default: False):
@@ -296,7 +297,7 @@ def upsert_and_get(
296297 def bulk_upsert (
297298 self ,
298299 conflict_target : List ,
299- rows : List [Dict ],
300+ rows : Iterable [Dict ],
300301 index_predicate : str = None ,
301302 return_model : bool = False ,
302303 ):
@@ -322,21 +323,21 @@ def bulk_upsert(
322323 A list of either the dicts of the rows upserted, including the pk or
323324 the models of the rows upserted
324325 """
325-
326- if not rows or len (rows ) <= 0 :
326+ is_empty = lambda r : all ([ False for _ in r ])
327+ if not rows or is_empty (rows ):
327328 return []
328329
329330 self .on_conflict (
330331 conflict_target , ConflictAction .UPDATE , index_predicate
331332 )
332333 return self .bulk_insert (rows , return_model )
333334
334- def _build_insert_compiler (self , rows : List [Dict ]):
335+ def _build_insert_compiler (self , rows : Iterable [Dict ]):
335336 """Builds the SQL compiler for a insert query.
336337
337338 Arguments:
338339 rows:
339- A list of dictionaries, where each entry
340+ An iterable of dictionaries, where each entry
340341 describes a record to insert.
341342
342343 Returns:
@@ -349,8 +350,10 @@ def _build_insert_compiler(self, rows: List[Dict]):
349350 # we need to be certain that each row specifies the exact same
350351 # amount of fields/columns
351352 objs = []
352- field_count = len (rows [0 ])
353- for index , row in enumerate (rows ):
353+ rows_iter = iter (rows )
354+ first_row = next (rows_iter )
355+ field_count = len (first_row )
356+ for index , row in enumerate (chain ([first_row ], rows_iter )):
354357 if field_count != len (row ):
355358 raise SuspiciousOperation (
356359 (
@@ -366,7 +369,7 @@ def _build_insert_compiler(self, rows: List[Dict]):
366369 self ._for_write = True
367370
368371 # get the fields to be used during update/insert
369- insert_fields , update_fields = self ._get_upsert_fields (rows [ 0 ] )
372+ insert_fields , update_fields = self ._get_upsert_fields (first_row )
370373
371374 # build a normal insert query
372375 query = PostgresInsertQuery (self .model )
@@ -597,7 +600,7 @@ def upsert_and_get(
597600 def bulk_upsert (
598601 self ,
599602 conflict_target : List ,
600- rows : List [Dict ],
603+ rows : Iterable [Dict ],
601604 index_predicate : str = None ,
602605 return_model : bool = False ,
603606 ):
0 commit comments