77from django .db import IntegrityError , NotSupportedError
88from django .db .models import Count
99from django .db .models .aggregates import Aggregate , Variance
10- from django .db .models .expressions import Case , Col , Ref , Value , When
10+ from django .db .models .expressions import Case , Col , OrderBy , Ref , Value , When
1111from django .db .models .functions .comparison import Coalesce
1212from django .db .models .functions .math import Power
1313from django .db .models .lookups import IsNull
@@ -32,6 +32,33 @@ def __init__(self, *args, **kwargs):
3232 # A list of OrderBy objects for this query.
3333 self .order_by_objs = None
3434
35+ def _unfold_column (self , col ):
36+ """
37+ Flatten a field by returning its target or by replacing dots with GROUP_SEPARATOR
38+ for foreign fields.
39+ """
40+ if self .collection_name == col .alias :
41+ return col .target .column
42+ # If this is a foreign field, replace the normal dot (.) with
43+ # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44+ return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } "
45+
46+ def _fold_columns (self , unfold_columns ):
47+ """
48+ Convert flat columns into a nested dictionary, grouping fields by table names.
49+ """
50+ result = defaultdict (dict )
51+ for key in unfold_columns :
52+ value = f"$_id.{ key } "
53+ if self .GROUP_SEPARATOR in key :
54+ table , field = key .split (self .GROUP_SEPARATOR )
55+ result [table ][field ] = value
56+ else :
57+ result [key ] = value
58+ # Convert defaultdict to dict so it doesn't appear as
59+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
60+ return dict (result )
61+
3562 def _get_group_alias_column (self , expr , annotation_group_idx ):
3663 """Generate a dummy field for use in the ids fields in $group."""
3764 replacement = None
@@ -42,11 +69,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4269 alias = f"__annotation_group{ next (annotation_group_idx )} "
4370 col = self ._get_column_from_expression (expr , alias )
4471 replacement = col
45- if self .collection_name == col .alias :
46- return col .target .column , replacement
47- # If this is a foreign field, replace the normal dot (.) with
48- # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49- return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } " , replacement
72+ return self ._unfold_column (col ), replacement
5073
5174 def _get_column_from_expression (self , expr , alias ):
5275 """
@@ -186,17 +209,8 @@ def _build_aggregation_pipeline(self, ids, group):
186209 else :
187210 group ["_id" ] = ids
188211 pipeline .append ({"$group" : group })
189- projected_fields = defaultdict (dict )
190- for key in ids :
191- value = f"$_id.{ key } "
192- if self .GROUP_SEPARATOR in key :
193- table , field = key .split (self .GROUP_SEPARATOR )
194- projected_fields [table ][field ] = value
195- else :
196- projected_fields [key ] = value
197- # Convert defaultdict to dict so it doesn't appear as
198- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
199- pipeline .append ({"$addFields" : dict (projected_fields )})
212+ projected_fields = self ._fold_columns (ids )
213+ pipeline .append ({"$addFields" : projected_fields })
200214 if "_id" not in projected_fields :
201215 pipeline .append ({"$unset" : "_id" })
202216 return pipeline
@@ -349,23 +363,30 @@ def build_query(self, columns=None):
349363 """Check if the query is supported and prepare a MongoQuery."""
350364 self .check_query ()
351365 query = self .query_class (self )
352- query .lookup_pipeline = self .get_lookup_pipeline ()
353366 ordering_fields , sort_ordering , extra_fields = self ._get_ordering ()
354- query .project_fields = self .get_project_fields (columns , ordering_fields )
355367 query .ordering = sort_ordering
356- # If columns is None, then get_project_fields() won't add
357- # ordering_fields to $project. Use $addFields (extra_fields) instead.
358- if columns is None :
359- extra_fields += ordering_fields
368+ if self .query .combinator :
369+ if not getattr (self .connection .features , f"supports_select_{ self .query .combinator } " ):
370+ raise NotSupportedError (
371+ f"{ self .query .combinator } is not supported on this database backend."
372+ )
373+ query .combinator_pipeline = self .get_combinator_queries ()
374+ else :
375+ query .project_fields = self .get_project_fields (columns , ordering_fields )
376+ # If columns is None, then get_project_fields() won't add
377+ # ordering_fields to $project. Use $addFields (extra_fields) instead.
378+ if columns is None :
379+ extra_fields += ordering_fields
380+ query .lookup_pipeline = self .get_lookup_pipeline ()
381+ where = self .get_where ()
382+ try :
383+ expr = where .as_mql (self , self .connection ) if where else {}
384+ except FullResultSet :
385+ query .mongo_query = {}
386+ else :
387+ query .mongo_query = {"$expr" : expr }
360388 if extra_fields :
361389 query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
362- where = self .get_where ()
363- try :
364- expr = where .as_mql (self , self .connection ) if where else {}
365- except FullResultSet :
366- query .mongo_query = {}
367- else :
368- query .mongo_query = {"$expr" : expr }
369390 return query
370391
371392 def get_columns (self ):
@@ -391,6 +412,9 @@ def project_field(column):
391412 if hasattr (column , "target" ):
392413 # column is a Col.
393414 target = column .target .column
415+ # Handle Order By columns as refs columns.
416+ elif isinstance (column , OrderBy ) and isinstance (column .expression , Ref ):
417+ target = column .expression .refs
394418 else :
395419 # column is a Transform in values()/values_list() that needs a
396420 # name for $proj.
@@ -412,6 +436,75 @@ def collection_name(self):
412436 def collection (self ):
413437 return self .connection .get_collection (self .collection_name )
414438
439+ def get_combinator_queries (self ):
440+ parts = []
441+ compilers = [
442+ query .get_compiler (self .using , self .connection , self .elide_empty )
443+ for query in self .query .combined_queries
444+ ]
445+ main_query_columns = self .get_columns ()
446+ main_query_fields , _ = zip (* main_query_columns , strict = True )
447+ for compiler_ in compilers :
448+ try :
449+ # If the columns list is limited, then all combined queries
450+ # must have the same columns list. Set the selects defined on
451+ # the query on all combined queries, if not already set.
452+ if not compiler_ .query .values_select and self .query .values_select :
453+ compiler_ .query = compiler_ .query .clone ()
454+ compiler_ .query .set_values (
455+ (
456+ * self .query .extra_select ,
457+ * self .query .values_select ,
458+ * self .query .annotation_select ,
459+ )
460+ )
461+ compiler_ .pre_sql_setup ()
462+ columns = compiler_ .get_columns ()
463+ parts .append ((compiler_ .build_query (columns ), compiler_ , columns ))
464+ except EmptyResultSet :
465+ # Omit the empty queryset with UNION.
466+ if self .query .combinator == "union" :
467+ continue
468+ raise
469+ # Raise EmptyResultSet if all the combinator queries are empty.
470+ if not parts :
471+ raise EmptyResultSet
472+ # Make the combinator's stages.
473+ combinator_pipeline = None
474+ for part , compiler_ , columns in parts :
475+ inner_pipeline = part .get_pipeline ()
476+ # Standardize result fields.
477+ fields = {}
478+ # When a .count() is called, the main_query_field has length 1
479+ # otherwise it has the same length as columns.
480+ for alias , (ref , expr ) in zip (main_query_fields , columns , strict = False ):
481+ if isinstance (expr , Col ) and expr .alias != compiler_ .collection_name :
482+ fields [expr .alias ] = 1
483+ else :
484+ fields [alias ] = f"${ ref } " if alias != ref else 1
485+ inner_pipeline .append ({"$project" : fields })
486+ # Combine query with the current combinator pipeline.
487+ if combinator_pipeline :
488+ combinator_pipeline .append (
489+ {"$unionWith" : {"coll" : compiler_ .collection_name , "pipeline" : inner_pipeline }}
490+ )
491+ else :
492+ combinator_pipeline = inner_pipeline
493+ if not self .query .combinator_all :
494+ ids = {}
495+ for alias , expr in main_query_columns :
496+ # Unfold foreign fields.
497+ if isinstance (expr , Col ) and expr .alias != self .collection_name :
498+ ids [self ._unfold_column (expr )] = expr .as_mql (self , self .connection )
499+ else :
500+ ids [alias ] = f"${ alias } "
501+ combinator_pipeline .append ({"$group" : {"_id" : ids }})
502+ projected_fields = self ._fold_columns (ids )
503+ combinator_pipeline .append ({"$addFields" : projected_fields })
504+ if "_id" not in projected_fields :
505+ combinator_pipeline .append ({"$unset" : "_id" })
506+ return combinator_pipeline
507+
415508 def get_lookup_pipeline (self ):
416509 result = []
417510 for alias in tuple (self .query .alias_map ):
0 commit comments