Skip to content

Commit ab70090

Browse files
WaVEVtimgraham
authored andcommitted
INTPYTHON-751 Make query generation omit $expr unless required
1 parent 03b7c93 commit ab70090

30 files changed

+1515
-1123
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
1717
node.set_source_expressions([Case(condition), *source_expressions[1:]])
1818
else:
1919
node = self
20-
lhs_mql = process_lhs(node, compiler, connection)
20+
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
2121
if resolve_inner_expression:
2222
return lhs_mql
2323
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
@@ -39,9 +39,9 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3939
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
4040
)
4141
node.set_source_expressions([Case(condition), *source_expressions[1:]])
42-
inner_expression = process_lhs(node, compiler, connection)
42+
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
4343
else:
44-
lhs_mql = process_lhs(self, compiler, connection)
44+
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
4545
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
4646
inner_expression = {
4747
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
@@ -51,7 +51,7 @@ def count(self, compiler, connection, resolve_inner_expression=False):
5151
return {"$sum": inner_expression}
5252
# If distinct=True or resolve_inner_expression=False, sum the size of the
5353
# set.
54-
lhs_mql = process_lhs(self, compiler, connection)
54+
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
5555
# None shouldn't be counted, so subtract 1 if it's present.
5656
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
5757
return {"$add": [{"$size": lhs_mql}, exits_null]}
@@ -66,7 +66,7 @@ def stddev_variance(self, compiler, connection):
6666

6767

6868
def register_aggregates():
69-
Aggregate.as_mql = aggregate
70-
Count.as_mql = count
71-
StdDev.as_mql = stddev_variance
72-
Variance.as_mql = stddev_variance
69+
Aggregate.as_mql_expr = aggregate
70+
Count.as_mql_expr = count
71+
StdDev.as_mql_expr = stddev_variance
72+
Variance.as_mql_expr = stddev_variance

django_mongodb_backend/base.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import logging
33
import os
44

5-
from django.core.exceptions import ImproperlyConfigured
5+
from bson import Decimal128
6+
from django.core.exceptions import EmptyResultSet, FullResultSet, ImproperlyConfigured
67
from django.db import DEFAULT_DB_ALIAS
78
from django.db.backends.base.base import BaseDatabaseWrapper
89
from django.db.backends.utils import debug_transaction
@@ -96,6 +97,58 @@ class DatabaseWrapper(BaseDatabaseWrapper):
9697
}
9798
_connection_pools = {}
9899

100+
def _isnull_operator(field, is_null):
101+
if is_null:
102+
return {"$or": [{field: {"$exists": False}}, {field: None}]}
103+
return {"$and": [{field: {"$exists": True}}, {field: {"$ne": None}}]}
104+
105+
def _range_operator(a, b):
106+
conditions = []
107+
start, end = b
108+
if start is not None:
109+
conditions.append({a: {"$gte": b[0]}})
110+
if end is not None:
111+
conditions.append({a: {"$lte": b[1]}})
112+
if not conditions:
113+
raise FullResultSet
114+
if start is not None and end is not None:
115+
# Decimal128 can't be natively compared.
116+
if isinstance(start, Decimal128):
117+
start = start.to_decimal()
118+
if isinstance(end, Decimal128):
119+
end = end.to_decimal()
120+
if start > end:
121+
raise EmptyResultSet
122+
return {"$and": conditions}
123+
124+
def _regex_operator(field, regex, insensitive=False):
125+
options = "i" if insensitive else ""
126+
return {field: {"$regex": regex, "$options": options}}
127+
128+
mongo_operators = {
129+
"exact": lambda a, b: {a: b},
130+
"gt": lambda a, b: {a: {"$gt": b}},
131+
"gte": lambda a, b: {a: {"$gte": b}},
132+
# MongoDB considers null less than zero. Exclude null values to match
133+
# SQL behavior.
134+
"lt": lambda a, b: {"$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator(a, False)]},
135+
"lte": lambda a, b: {
136+
"$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator(a, False)]
137+
},
138+
"in": lambda a, b: {a: {"$in": tuple(b)}},
139+
"isnull": _isnull_operator,
140+
"range": _range_operator,
141+
"iexact": lambda a, b: DatabaseWrapper._regex_operator(a, f"^{b}$", insensitive=True),
142+
"startswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"^{b}"),
143+
"istartswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"^{b}", insensitive=True),
144+
"endswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"{b}$"),
145+
"iendswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"{b}$", insensitive=True),
146+
"contains": lambda a, b: DatabaseWrapper._regex_operator(a, b),
147+
"icontains": lambda a, b: DatabaseWrapper._regex_operator(a, b, insensitive=True),
148+
"regex": lambda a, b: DatabaseWrapper._regex_operator(a, b),
149+
"iregex": lambda a, b: DatabaseWrapper._regex_operator(a, b, insensitive=True),
150+
}
151+
99152
def _isnull_expr(field, is_null):
100153
mql = {
101154
"$or": [
@@ -112,7 +165,7 @@ def _regex_expr(field, regex_vals, insensitive=False):
112165
options = "i" if insensitive else ""
113166
return {"$regexMatch": {"input": field, "regex": regex, "options": options}}
114167

115-
mongo_operators = {
168+
mongo_expr_operators = {
116169
"exact": lambda a, b: {"$eq": [a, b]},
117170
"gt": lambda a, b: {"$gt": [a, b]},
118171
"gte": lambda a, b: {"$gte": [a, b]},

django_mongodb_backend/compiler.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,14 @@ def _get_replace_expr(self, sub_expr, group, alias):
6969
if getattr(sub_expr, "distinct", False):
7070
# If the expression should return distinct values, use $addToSet to
7171
# deduplicate.
72-
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
72+
rhs = sub_expr.as_mql(
73+
self, self.connection, resolve_inner_expression=True, as_expr=True
74+
)
7375
group[alias] = {"$addToSet": rhs}
7476
replacing_expr = sub_expr.copy()
7577
replacing_expr.set_source_expressions([inner_column, None])
7678
else:
77-
group[alias] = sub_expr.as_mql(self, self.connection)
79+
group[alias] = sub_expr.as_mql(self, self.connection, as_expr=True)
7880
replacing_expr = inner_column
7981
# Count must return 0 rather than null.
8082
if isinstance(sub_expr, Count):
@@ -302,9 +304,7 @@ def _compound_searches_queries(self, search_replacements):
302304
search.as_mql(self, self.connection),
303305
{
304306
"$addFields": {
305-
result_col.as_mql(self, self.connection, as_path=True): {
306-
"$meta": score_function
307-
}
307+
result_col.as_mql(self, self.connection): {"$meta": score_function}
308308
}
309309
},
310310
]
@@ -334,7 +334,7 @@ def pre_sql_setup(self, with_col_aliases=False):
334334
pipeline.extend(query.get_pipeline())
335335
# Remove the added subqueries.
336336
self.subqueries = []
337-
pipeline.append({"$match": {"$expr": having}})
337+
pipeline.append({"$match": having})
338338
self.aggregation_pipeline = pipeline
339339
self.annotations = {
340340
target: expr.replace_expressions(all_replacements)
@@ -481,11 +481,11 @@ def build_query(self, columns=None):
481481
query.lookup_pipeline = self.get_lookup_pipeline()
482482
where = self.get_where()
483483
try:
484-
expr = where.as_mql(self, self.connection) if where else {}
484+
match_mql = where.as_mql(self, self.connection) if where else {}
485485
except FullResultSet:
486486
query.match_mql = {}
487487
else:
488-
query.match_mql = {"$expr": expr}
488+
query.match_mql = match_mql
489489
if extra_fields:
490490
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
491491
query.subqueries = self.subqueries
@@ -643,7 +643,9 @@ def get_combinator_queries(self):
643643
for alias, expr in self.columns:
644644
# Unfold foreign fields.
645645
if isinstance(expr, Col) and expr.alias != self.collection_name:
646-
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
646+
ids[expr.alias][expr.target.column] = expr.as_mql(
647+
self, self.connection, as_expr=True
648+
)
647649
else:
648650
ids[alias] = f"${alias}"
649651
# Convert defaultdict to dict so it doesn't appear as
@@ -707,16 +709,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
707709
# For brevity/simplicity, project {"field_name": 1}
708710
# instead of {"field_name": "$field_name"}.
709711
if isinstance(expr, Col) and name == expr.target.column and not force_expression
710-
else expr.as_mql(self, self.connection)
712+
else expr.as_mql(self, self.connection, as_expr=True)
711713
)
712714
except EmptyResultSet:
713715
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
714716
value = (
715717
False if empty_result_set_value is NotImplemented else empty_result_set_value
716718
)
717-
fields[collection][name] = Value(value).as_mql(self, self.connection)
719+
fields[collection][name] = Value(value).as_mql(self, self.connection, as_expr=True)
718720
except FullResultSet:
719-
fields[collection][name] = Value(True).as_mql(self, self.connection)
721+
fields[collection][name] = Value(True).as_mql(self, self.connection, as_expr=True)
720722
# Annotations (stored in None) and the main collection's fields
721723
# should appear in the top-level of the fields dict.
722724
fields.update(fields.pop(None, {}))
@@ -739,10 +741,10 @@ def _get_ordering(self):
739741
idx = itertools.count(start=1)
740742
for order in self.order_by_objs or []:
741743
if isinstance(order.expression, Col):
742-
field_name = order.as_mql(self, self.connection).removeprefix("$")
744+
field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$")
743745
fields.append((order.expression.target.column, order.expression))
744746
elif isinstance(order.expression, Ref):
745-
field_name = order.as_mql(self, self.connection).removeprefix("$")
747+
field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$")
746748
else:
747749
field_name = f"__order{next(idx)}"
748750
fields.append((field_name, order.expression))
@@ -879,7 +881,7 @@ def execute_sql(self, result_type):
879881
)
880882
prepared = field.get_db_prep_save(value, connection=self.connection)
881883
if hasattr(value, "as_mql"):
882-
prepared = prepared.as_mql(self, self.connection)
884+
prepared = prepared.as_mql(self, self.connection, as_expr=True)
883885
values[field.column] = prepared
884886
try:
885887
criteria = self.build_query().match_mql

0 commit comments

Comments
 (0)