@@ -32,6 +32,33 @@ def __init__(self, *args, **kwargs):
3232 # A list of OrderBy objects for this query.
3333 self .order_by_objs = None
3434
35+ def _unfold_column (self , col ):
36+ """
37+ Flattens a field by returning its target or by replacing dots with GROUP_SEPARATOR
38+ for foreign fields.
39+ """
40+ if self .collection_name == col .alias :
41+ return col .target .column
42+ # If this is a foreign field, replace the normal dot (.) with
43+ # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44+ return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } "
45+
46+ def _fold_columns (self , unfold_columns ):
47+ """
48+ Converts flat columns into a nested dictionary, grouping fields by table names.
49+ """
50+ result = defaultdict (dict )
51+ for key in unfold_columns :
52+ value = f"$_id.{ key } "
53+ if self .GROUP_SEPARATOR in key :
54+ table , field = key .split (self .GROUP_SEPARATOR )
55+ result [table ][field ] = value
56+ else :
57+ result [key ] = value
58+ # Convert defaultdict to dict so it doesn't appear as
59+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
60+ return dict (result )
61+
3562 def _get_group_alias_column (self , expr , annotation_group_idx ):
3663 """Generate a dummy field for use in the ids fields in $group."""
3764 replacement = None
@@ -42,11 +69,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4269 alias = f"__annotation_group{ next (annotation_group_idx )} "
4370 col = self ._get_column_from_expression (expr , alias )
4471 replacement = col
45- if self .collection_name == col .alias :
46- return col .target .column , replacement
47- # If this is a foreign field, replace the normal dot (.) with
48- # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49- return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } " , replacement
72+ return self ._unfold_column (col ), replacement
5073
5174 def _get_column_from_expression (self , expr , alias ):
5275 """
@@ -186,17 +209,8 @@ def _build_aggregation_pipeline(self, ids, group):
186209 else :
187210 group ["_id" ] = ids
188211 pipeline .append ({"$group" : group })
189- projected_fields = defaultdict (dict )
190- for key in ids :
191- value = f"$_id.{ key } "
192- if self .GROUP_SEPARATOR in key :
193- table , field = key .split (self .GROUP_SEPARATOR )
194- projected_fields [table ][field ] = value
195- else :
196- projected_fields [key ] = value
197- # Convert defaultdict to dict so it doesn't appear as
198- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
199- pipeline .append ({"$addFields" : dict (projected_fields )})
212+ projected_fields = self ._fold_columns (ids )
213+ pipeline .append ({"$addFields" : projected_fields })
200214 if "_id" not in projected_fields :
201215 pipeline .append ({"$unset" : "_id" })
202216 return pipeline
@@ -453,8 +467,7 @@ def get_combinator_queries(self):
453467 parts .append ((compiler_ .build_query (columns ), compiler_ .collection_name ))
454468
455469 except EmptyResultSet :
456- # Omit the empty queryset with UNION and with DIFFERENCE if the
457- # first queryset is nonempty.
470+ # Omit the empty queryset with UNION.
458471 if self .query .combinator == "union" :
459472 continue
460473 raise
@@ -470,25 +483,14 @@ def get_combinator_queries(self):
470483 if not self .query .combinator_all :
471484 ids = {}
472485 for alias , expr in main_query_columns :
473- collection = expr .alias if isinstance (expr , Col ) else None
474- if collection and collection != self .collection_name :
475- ids [
476- f"{ expr .alias } { self .GROUP_SEPARATOR } { expr .target .column } "
477- ] = expr .as_mql (self , self .connection )
486+ # Unfold foreign fields.
487+ if isinstance (expr , Col ) and expr .alias != self .collection_name :
488+ ids [self ._unfold_column (expr )] = expr .as_mql (self , self .connection )
478489 else :
479490 ids [alias ] = f"${ alias } "
480491 combinator_pipeline .append ({"$group" : {"_id" : ids }})
481- projected_fields = defaultdict (dict )
482- for key in ids :
483- value = f"$_id.{ key } "
484- if self .GROUP_SEPARATOR in key :
485- table , field = key .split (self .GROUP_SEPARATOR )
486- projected_fields [table ][field ] = value
487- else :
488- projected_fields [key ] = value
489- # Convert defaultdict to dict so it doesn't appear as
490- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
491- combinator_pipeline .append ({"$addFields" : dict (projected_fields )})
492+ projected_fields = self ._fold_columns (ids )
493+ combinator_pipeline .append ({"$addFields" : projected_fields })
492494 if "_id" not in projected_fields :
493495 combinator_pipeline .append ({"$unset" : "_id" })
494496 else :
0 commit comments