@@ -32,6 +32,33 @@ def __init__(self, *args, **kwargs):
3232        # A list of OrderBy objects for this query. 
3333        self .order_by_objs  =  None 
3434
35+     def  _unfold_column (self , col ):
36+         """ 
37+         Flattens a field by returning its target or by replacing dots with GROUP_SEPARATOR 
38+         for foreign fields. 
39+         """ 
40+         if  self .collection_name  ==  col .alias :
41+             return  col .target .column 
42+         # If this is a foreign field, replace the normal dot (.) with 
43+         # GROUP_SEPARATOR since FieldPath field names may not contain '.'. 
44+         return  f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column }  
45+ 
46+     def  _fold_columns (self , unfold_columns ):
47+         """ 
48+         Converts flat columns into a nested dictionary, grouping fields by table names. 
49+         """ 
50+         result  =  defaultdict (dict )
51+         for  key  in  unfold_columns :
52+             value  =  f"$_id.{ key }  
53+             if  self .GROUP_SEPARATOR  in  key :
54+                 table , field  =  key .split (self .GROUP_SEPARATOR )
55+                 result [table ][field ] =  value 
56+             else :
57+                 result [key ] =  value 
58+         # Convert defaultdict to dict so it doesn't appear as 
59+         # "defaultdict(<CLASS 'dict'>, ..." in query logging. 
60+         return  dict (result )
61+ 
3562    def  _get_group_alias_column (self , expr , annotation_group_idx ):
3663        """Generate a dummy field for use in the ids fields in $group.""" 
3764        replacement  =  None 
@@ -42,11 +69,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4269            alias  =  f"__annotation_group{ next (annotation_group_idx )}  
4370            col  =  self ._get_column_from_expression (expr , alias )
4471            replacement  =  col 
45-         if  self .collection_name  ==  col .alias :
46-             return  col .target .column , replacement 
47-         # If this is a foreign field, replace the normal dot (.) with 
48-         # GROUP_SEPARATOR since FieldPath field names may not contain '.'. 
49-         return  f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column }  , replacement 
72+         return  self ._unfold_column (col ), replacement 
5073
5174    def  _get_column_from_expression (self , expr , alias ):
5275        """ 
@@ -186,17 +209,8 @@ def _build_aggregation_pipeline(self, ids, group):
186209        else :
187210            group ["_id" ] =  ids 
188211            pipeline .append ({"$group" : group })
189-             projected_fields  =  defaultdict (dict )
190-             for  key  in  ids :
191-                 value  =  f"$_id.{ key }  
192-                 if  self .GROUP_SEPARATOR  in  key :
193-                     table , field  =  key .split (self .GROUP_SEPARATOR )
194-                     projected_fields [table ][field ] =  value 
195-                 else :
196-                     projected_fields [key ] =  value 
197-             # Convert defaultdict to dict so it doesn't appear as 
198-             # "defaultdict(<CLASS 'dict'>, ..." in query logging. 
199-             pipeline .append ({"$addFields" : dict (projected_fields )})
212+             projected_fields  =  self ._fold_columns (ids )
213+             pipeline .append ({"$addFields" : projected_fields })
200214            if  "_id"  not  in projected_fields :
201215                pipeline .append ({"$unset" : "_id" })
202216        return  pipeline 
@@ -453,8 +467,7 @@ def get_combinator_queries(self):
453467                parts .append ((compiler_ .build_query (columns ), compiler_ .collection_name ))
454468
455469            except  EmptyResultSet :
456-                 # Omit the empty queryset with UNION and with DIFFERENCE if the 
457-                 # first queryset is nonempty. 
470+                 # Omit the empty queryset with UNION. 
458471                if  self .query .combinator  ==  "union" :
459472                    continue 
460473                raise 
@@ -470,25 +483,14 @@ def get_combinator_queries(self):
470483                if  not  self .query .combinator_all :
471484                    ids  =  {}
472485                    for  alias , expr  in  main_query_columns :
473-                         collection  =  expr .alias  if  isinstance (expr , Col ) else  None 
474-                         if  collection  and  collection  !=  self .collection_name :
475-                             ids [
476-                                 f"{ expr .alias } { self .GROUP_SEPARATOR } { expr .target .column }  
477-                             ] =  expr .as_mql (self , self .connection )
486+                         # Unfold foreign fields. 
487+                         if  isinstance (expr , Col ) and  expr .alias  !=  self .collection_name :
488+                             ids [self ._unfold_column (expr )] =  expr .as_mql (self , self .connection )
478489                        else :
479490                            ids [alias ] =  f"${ alias }  
480491                    combinator_pipeline .append ({"$group" : {"_id" : ids }})
481-                     projected_fields  =  defaultdict (dict )
482-                     for  key  in  ids :
483-                         value  =  f"$_id.{ key }  
484-                         if  self .GROUP_SEPARATOR  in  key :
485-                             table , field  =  key .split (self .GROUP_SEPARATOR )
486-                             projected_fields [table ][field ] =  value 
487-                         else :
488-                             projected_fields [key ] =  value 
489-                     # Convert defaultdict to dict so it doesn't appear as 
490-                     # "defaultdict(<CLASS 'dict'>, ..." in query logging. 
491-                     combinator_pipeline .append ({"$addFields" : dict (projected_fields )})
492+                     projected_fields  =  self ._fold_columns (ids )
493+                     combinator_pipeline .append ({"$addFields" : projected_fields })
492494                    if  "_id"  not  in projected_fields :
493495                        combinator_pipeline .append ({"$unset" : "_id" })
494496        else :
0 commit comments