77from  django .db  import  IntegrityError , NotSupportedError 
88from  django .db .models  import  Count 
99from  django .db .models .aggregates  import  Aggregate , Variance 
10- from  django .db .models .expressions  import  Case , Col , Ref , Value , When 
10+ from  django .db .models .expressions  import  Case , Col , OrderBy ,  Ref , Value , When 
1111from  django .db .models .functions .comparison  import  Coalesce 
1212from  django .db .models .functions .math  import  Power 
1313from  django .db .models .lookups  import  IsNull 
@@ -32,6 +32,34 @@ def __init__(self, *args, **kwargs):
3232        # A list of OrderBy objects for this query. 
3333        self .order_by_objs  =  None 
3434
35+     def  _unfold_column (self , col ):
36+         """ 
37+         Flatten a field by returning its target or by replacing dots with 
38+         GROUP_SEPARATOR for foreign fields. 
39+         """ 
40+         if  self .collection_name  ==  col .alias :
41+             return  col .target .column 
42+         # If this is a foreign field, replace the normal dot (.) with 
43+         # GROUP_SEPARATOR since FieldPath field names may not contain '.'. 
44+         return  f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column }  
45+ 
46+     def  _fold_columns (self , unfold_columns ):
47+         """ 
48+         Convert flat columns into a nested dictionary, grouping fields by 
49+         table name. 
50+         """ 
51+         result  =  defaultdict (dict )
52+         for  key  in  unfold_columns :
53+             value  =  f"$_id.{ key }  
54+             if  self .GROUP_SEPARATOR  in  key :
55+                 table , field  =  key .split (self .GROUP_SEPARATOR )
56+                 result [table ][field ] =  value 
57+             else :
58+                 result [key ] =  value 
59+         # Convert defaultdict to dict so it doesn't appear as 
60+         # "defaultdict(<CLASS 'dict'>, ..." in query logging. 
61+         return  dict (result )
62+ 
3563    def  _get_group_alias_column (self , expr , annotation_group_idx ):
3664        """Generate a dummy field for use in the ids fields in $group.""" 
3765        replacement  =  None 
@@ -42,11 +70,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4270            alias  =  f"__annotation_group{ next (annotation_group_idx )}  
4371            col  =  self ._get_column_from_expression (expr , alias )
4472            replacement  =  col 
45-         if  self .collection_name  ==  col .alias :
46-             return  col .target .column , replacement 
47-         # If this is a foreign field, replace the normal dot (.) with 
48-         # GROUP_SEPARATOR since FieldPath field names may not contain '.'. 
49-         return  f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column }  , replacement 
73+         return  self ._unfold_column (col ), replacement 
5074
5175    def  _get_column_from_expression (self , expr , alias ):
5276        """ 
@@ -186,17 +210,8 @@ def _build_aggregation_pipeline(self, ids, group):
186210        else :
187211            group ["_id" ] =  ids 
188212            pipeline .append ({"$group" : group })
189-             projected_fields  =  defaultdict (dict )
190-             for  key  in  ids :
191-                 value  =  f"$_id.{ key }  
192-                 if  self .GROUP_SEPARATOR  in  key :
193-                     table , field  =  key .split (self .GROUP_SEPARATOR )
194-                     projected_fields [table ][field ] =  value 
195-                 else :
196-                     projected_fields [key ] =  value 
197-             # Convert defaultdict to dict so it doesn't appear as 
198-             # "defaultdict(<CLASS 'dict'>, ..." in query logging. 
199-             pipeline .append ({"$addFields" : dict (projected_fields )})
213+             projected_fields  =  self ._fold_columns (ids )
214+             pipeline .append ({"$addFields" : projected_fields })
200215            if  "_id"  not  in projected_fields :
201216                pipeline .append ({"$unset" : "_id" })
202217        return  pipeline 
@@ -349,23 +364,30 @@ def build_query(self, columns=None):
349364        """Check if the query is supported and prepare a MongoQuery.""" 
350365        self .check_query ()
351366        query  =  self .query_class (self )
352-         query .lookup_pipeline  =  self .get_lookup_pipeline ()
353367        ordering_fields , sort_ordering , extra_fields  =  self ._get_ordering ()
354-         query .project_fields  =  self .get_project_fields (columns , ordering_fields )
355368        query .ordering  =  sort_ordering 
356-         # If columns is None, then get_project_fields() won't add 
357-         # ordering_fields to $project. Use $addFields (extra_fields) instead. 
358-         if  columns  is  None :
359-             extra_fields  +=  ordering_fields 
369+         if  self .query .combinator :
370+             if  not  getattr (self .connection .features , f"supports_select_{ self .query .combinator }  ):
371+                 raise  NotSupportedError (
372+                     f"{ self .query .combinator }  
373+                 )
374+             query .combinator_pipeline  =  self .get_combinator_queries ()
375+         else :
376+             query .project_fields  =  self .get_project_fields (columns , ordering_fields )
377+             # If columns is None, then get_project_fields() won't add 
378+             # ordering_fields to $project. Use $addFields (extra_fields) instead. 
379+             if  columns  is  None :
380+                 extra_fields  +=  ordering_fields 
381+             query .lookup_pipeline  =  self .get_lookup_pipeline ()
382+             where  =  self .get_where ()
383+             try :
384+                 expr  =  where .as_mql (self , self .connection ) if  where  else  {}
385+             except  FullResultSet :
386+                 query .mongo_query  =  {}
387+             else :
388+                 query .mongo_query  =  {"$expr" : expr }
360389        if  extra_fields :
361390            query .extra_fields  =  self .get_project_fields (extra_fields , force_expression = True )
362-         where  =  self .get_where ()
363-         try :
364-             expr  =  where .as_mql (self , self .connection ) if  where  else  {}
365-         except  FullResultSet :
366-             query .mongo_query  =  {}
367-         else :
368-             query .mongo_query  =  {"$expr" : expr }
369391        return  query 
370392
371393    def  get_columns (self ):
@@ -391,6 +413,9 @@ def project_field(column):
391413            if  hasattr (column , "target" ):
392414                # column is a Col. 
393415                target  =  column .target .column 
416+             # Handle Order By columns as refs columns. 
417+             elif  isinstance (column , OrderBy ) and  isinstance (column .expression , Ref ):
418+                 target  =  column .expression .refs 
394419            else :
395420                # column is a Transform in values()/values_list() that needs a 
396421                # name for $proj. 
@@ -412,6 +437,75 @@ def collection_name(self):
412437    def  collection (self ):
413438        return  self .connection .get_collection (self .collection_name )
414439
440+     def  get_combinator_queries (self ):
441+         parts  =  []
442+         compilers  =  [
443+             query .get_compiler (self .using , self .connection , self .elide_empty )
444+             for  query  in  self .query .combined_queries 
445+         ]
446+         main_query_columns  =  self .get_columns ()
447+         main_query_fields , _  =  zip (* main_query_columns , strict = True )
448+         for  compiler_  in  compilers :
449+             try :
450+                 # If the columns list is limited, then all combined queries 
451+                 # must have the same columns list. Set the selects defined on 
452+                 # the query on all combined queries, if not already set. 
453+                 if  not  compiler_ .query .values_select  and  self .query .values_select :
454+                     compiler_ .query  =  compiler_ .query .clone ()
455+                     compiler_ .query .set_values (
456+                         (
457+                             * self .query .extra_select ,
458+                             * self .query .values_select ,
459+                             * self .query .annotation_select ,
460+                         )
461+                     )
462+                 compiler_ .pre_sql_setup ()
463+                 columns  =  compiler_ .get_columns ()
464+                 parts .append ((compiler_ .build_query (columns ), compiler_ , columns ))
465+             except  EmptyResultSet :
466+                 # Omit the empty queryset with UNION. 
467+                 if  self .query .combinator  ==  "union" :
468+                     continue 
469+                 raise 
470+         # Raise EmptyResultSet if all the combinator queries are empty. 
471+         if  not  parts :
472+             raise  EmptyResultSet 
473+         # Make the combinator's stages. 
474+         combinator_pipeline  =  None 
475+         for  part , compiler_ , columns  in  parts :
476+             inner_pipeline  =  part .get_pipeline ()
477+             # Standardize result fields. 
478+             fields  =  {}
479+             # When a .count() is called, the main_query_field has length 1 
480+             # otherwise it has the same length as columns. 
481+             for  alias , (ref , expr ) in  zip (main_query_fields , columns , strict = False ):
482+                 if  isinstance (expr , Col ) and  expr .alias  !=  compiler_ .collection_name :
483+                     fields [expr .alias ] =  1 
484+                 else :
485+                     fields [alias ] =  f"${ ref }   if  alias  !=  ref  else  1 
486+             inner_pipeline .append ({"$project" : fields })
487+             # Combine query with the current combinator pipeline. 
488+             if  combinator_pipeline :
489+                 combinator_pipeline .append (
490+                     {"$unionWith" : {"coll" : compiler_ .collection_name , "pipeline" : inner_pipeline }}
491+                 )
492+             else :
493+                 combinator_pipeline  =  inner_pipeline 
494+         if  not  self .query .combinator_all :
495+             ids  =  {}
496+             for  alias , expr  in  main_query_columns :
497+                 # Unfold foreign fields. 
498+                 if  isinstance (expr , Col ) and  expr .alias  !=  self .collection_name :
499+                     ids [self ._unfold_column (expr )] =  expr .as_mql (self , self .connection )
500+                 else :
501+                     ids [alias ] =  f"${ alias }  
502+             combinator_pipeline .append ({"$group" : {"_id" : ids }})
503+             projected_fields  =  self ._fold_columns (ids )
504+             combinator_pipeline .append ({"$addFields" : projected_fields })
505+             if  "_id"  not  in projected_fields :
506+                 combinator_pipeline .append ({"$unset" : "_id" })
507+         return  combinator_pipeline 
508+ 
415509    def  get_lookup_pipeline (self ):
416510        result  =  []
417511        for  alias  in  tuple (self .query .alias_map ):
0 commit comments