Skip to content

Commit 4130665

Browse files
committed
refactor.
1 parent 19be82a commit 4130665

File tree

5 files changed

+183
-114
lines changed

5 files changed

+183
-114
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Exists,
1515
ExpressionList,
1616
ExpressionWrapper,
17-
Func,
1817
NegatedExpression,
1918
OrderBy,
2019
RawSQL,
@@ -25,12 +24,9 @@
2524
Value,
2625
When,
2726
)
28-
from django.db.models.fields.json import KeyTransform
2927
from django.db.models.sql import Query
3028

31-
from django_mongodb_backend.fields.array import Array
32-
33-
from ..query_utils import is_direct_value, process_lhs
29+
from ..query_utils import process_lhs
3430

3531

3632
def case(self, compiler, connection, as_path=False):
@@ -235,34 +231,14 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001
235231
return value
236232

237233

238-
@staticmethod
239-
def _is_constant_value(value):
240-
if isinstance(value, list | Array):
241-
iterable = value.get_source_expressions() if isinstance(value, Array) else value
242-
return all(_is_constant_value(e) for e in iterable)
243-
if is_direct_value(value):
244-
return True
245-
return isinstance(value, Func | Value) and not (
246-
value.contains_aggregate
247-
or value.contains_over_clause
248-
or value.contains_column_references
249-
or value.contains_subquery
250-
)
251-
252-
253-
@staticmethod
254-
def _is_simple_column(lhs):
255-
while isinstance(lhs, KeyTransform):
256-
if "." in getattr(lhs, "key_name", ""):
257-
return False
258-
lhs = lhs.lhs
259-
col = lhs.source if isinstance(lhs, Ref) else lhs
260-
# Foreign columns from parent cannot be addressed as single match
261-
return isinstance(col, Col) and col.alias is not None
262-
234+
def base_expression(self, compiler, connection, as_path=False):
235+
if as_path and getattr(self, "is_simple_expression", lambda: False)():
236+
self.is_simple_expression()
237+
if hasattr(self, "as_mql_path"):
238+
return self.as_mql_path(compiler, connection)
263239

264-
def _is_simple_expression(self):
265-
return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs)
240+
expr = self.as_mql_expr(compiler, connection)
241+
return {"$expr": expr} if as_path else expr
266242

267243

268244
def register_expressions():
@@ -283,6 +259,4 @@ def register_expressions():
283259
Subquery.as_mql = subquery
284260
When.as_mql = when
285261
Value.as_mql = value
286-
BaseExpression.is_simple_expression = _is_simple_expression
287-
BaseExpression.is_simple_column = _is_simple_column
288-
BaseExpression.is_constant_value = _is_constant_value
262+
BaseExpression.as_mql = base_expression

django_mongodb_backend/fields/json.py

Lines changed: 95 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
KeyTransformNumericLookupMixin,
1717
)
1818

19-
from ..lookups import builtin_lookup
20-
from ..query_utils import process_lhs, process_rhs
19+
from ..lookups import builtin_lookup_expr, builtin_lookup_path
20+
from ..query_utils import is_simple_column, is_simple_expression, process_lhs, process_rhs
2121

2222

2323
def build_json_mql_path(lhs, key_transforms, as_path=False):
@@ -71,28 +71,35 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False):
7171
return result
7272

7373

74+
def has_key_check_simple_expression(self):
75+
return is_simple_expression(self) and all("." not in v for v in self.rhs)
76+
77+
7478
def has_key_lookup(self, compiler, connection, as_path=False):
7579
"""Return MQL to check for the existence of a key."""
7680
rhs = self.rhs
7781
if not isinstance(rhs, (list, tuple)):
7882
rhs = [rhs]
79-
as_path = as_path and self.is_simple_expression() and all("." not in v for v in rhs)
8083
lhs = process_lhs(self, compiler, connection, as_path=as_path)
8184
paths = []
8285
# Transform any "raw" keys into KeyTransforms to allow consistent handling
8386
# in the code that follows.
84-
8587
for key in rhs:
8688
rhs_json_path = key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs)
8789
paths.append(rhs_json_path.as_mql(compiler, connection, as_path=as_path))
8890
keys = []
8991
for path in paths:
9092
keys.append(_has_key_predicate(path, lhs, as_path=as_path))
9193

92-
result = keys[0] if self.mongo_operator is None else {self.mongo_operator: keys}
93-
if not as_path:
94-
result = {"$expr": result}
95-
return result
94+
return keys[0] if self.mongo_operator is None else {self.mongo_operator: keys}
95+
96+
97+
def has_key_lookup_path(self, compiler, connection):
98+
return has_key_lookup(self, compiler, connection, as_path=True)
99+
100+
101+
def has_key_lookup_expr(self, compiler, connection):
102+
return has_key_lookup(self, compiler, connection, as_path=False)
96103

97104

98105
_process_rhs = JSONExact.process_rhs
@@ -107,7 +114,7 @@ def json_exact_process_rhs(self, compiler, connection):
107114
)
108115

109116

110-
def key_transform(self, compiler, connection, as_path=False):
117+
def key_transform_path(self, compiler, connection):
111118
"""
112119
Return MQL for this KeyTransform (JSON path).
113120
@@ -121,13 +128,19 @@ def key_transform(self, compiler, connection, as_path=False):
121128
while isinstance(previous, KeyTransform):
122129
key_transforms.insert(0, previous.key_name)
123130
previous = previous.lhs
124-
if as_path and self.is_simple_column(self.lhs):
125-
lhs_mql = previous.as_mql(compiler, connection, as_path=True)
126-
return build_json_mql_path(lhs_mql, key_transforms, as_path=True)
131+
# Collect all key transforms in order.
132+
lhs_mql = previous.as_mql(compiler, connection, as_path=True)
133+
return build_json_mql_path(lhs_mql, key_transforms, as_path=True)
134+
135+
136+
def key_transform_expr(self, compiler, connection):
137+
key_transforms = [self.key_name]
138+
previous = self.lhs
139+
while isinstance(previous, KeyTransform):
140+
key_transforms.insert(0, previous.key_name)
141+
previous = previous.lhs
127142
# Collect all key transforms in order.
128143
lhs_mql = previous.as_mql(compiler, connection, as_path=False)
129-
if as_path:
130-
return {"$expr": build_json_mql_path(lhs_mql, key_transforms, as_path=False)}
131144
return build_json_mql_path(lhs_mql, key_transforms, as_path=False)
132145

133146

@@ -137,7 +150,7 @@ def key_transform_in(self, compiler, connection, as_path=False):
137150
set of specified values (rhs).
138151
"""
139152
if as_path and self.is_simple_expression():
140-
return builtin_lookup(self, compiler, connection, as_path=True)
153+
return builtin_lookup_path(self, compiler, connection)
141154

142155
lhs_mql = process_lhs(self, compiler, connection)
143156
# Traverse to the root column.
@@ -154,7 +167,24 @@ def key_transform_in(self, compiler, connection, as_path=False):
154167
return expr
155168

156169

157-
def key_transform_is_null(self, compiler, connection, as_path=False):
170+
def key_transform_in_path(self, compiler, connection):
171+
return builtin_lookup_path(self, compiler, connection)
172+
173+
174+
def key_transform_in_expr(self, compiler, connection):
175+
lhs_mql = process_lhs(self, compiler, connection)
176+
# Traverse to the root column.
177+
previous = self.lhs
178+
while isinstance(previous, KeyTransform):
179+
previous = previous.lhs
180+
root_column = previous.as_mql(compiler, connection)
181+
value = process_rhs(self, compiler, connection)
182+
# Construct the expression to check if lhs_mql values are in rhs values.
183+
expr = connection.mongo_operators_expr[self.lookup_name](lhs_mql, value)
184+
return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]}
185+
186+
187+
def key_transform_is_null_path(self, compiler, connection):
158188
"""
159189
Return MQL to check the nullability of a key.
160190
@@ -164,63 +194,77 @@ def key_transform_is_null(self, compiler, connection, as_path=False):
164194
165195
Reference: https://code.djangoproject.com/ticket/32252
166196
"""
167-
if as_path and self.is_simple_expression():
168-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
169-
rhs_mql = process_rhs(self, compiler, connection)
170-
return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True)
171-
# Get the root column.
197+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
198+
rhs_mql = process_rhs(self, compiler, connection, as_path=True)
199+
return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True)
200+
201+
202+
def key_transform_is_null_expr(self, compiler, connection):
172203
previous = self.lhs
173204
while isinstance(previous, KeyTransform):
174205
previous = previous.lhs
175206
root_column = previous.as_mql(compiler, connection)
176-
expr = _has_key_predicate(lhs_mql, root_column, negated=rhs_mql)
177-
if as_path:
178-
return {"$expr": expr}
179-
return expr
207+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
208+
rhs_mql = process_rhs(self, compiler, connection)
209+
return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql)
180210

181211

182-
def key_transform_numeric_lookup_mixin(self, compiler, connection, as_path=False):
212+
def key_transform_numeric_lookup_mixin_path(self, compiler, connection):
183213
"""
184214
Return MQL to check if the field exists (i.e., is not "missing" or "null")
185215
and that the field matches the given numeric lookup expression.
186216
"""
187-
if as_path and self.is_simple_expression():
188-
return builtin_lookup(self, compiler, connection, as_path=True)
217+
return builtin_lookup_path(self, compiler, connection)
218+
189219

220+
def key_transform_numeric_lookup_mixin_expr(self, compiler, connection):
221+
"""
222+
Return MQL to check if the field exists (i.e., is not "missing" or "null")
223+
and that the field matches the given numeric lookup expression.
224+
"""
190225
lhs = process_lhs(self, compiler, connection, as_path=False)
191-
expr = builtin_lookup(self, compiler, connection, as_path=False)
226+
expr = builtin_lookup_expr(self, compiler, connection)
192227
# Check if the type of lhs is not "missing" or "null".
193228
not_missing_or_null = {"$not": {"$in": [{"$type": lhs}, ["missing", "null"]]}}
194-
expr = {"$and": [expr, not_missing_or_null]}
195-
if as_path:
196-
return {"$expr": expr}
197-
return expr
229+
return {"$and": [expr, not_missing_or_null]}
198230

199231

200-
def key_transform_exact(self, compiler, connection, as_path=False):
201-
if as_path and self.is_simple_expression():
202-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
203-
return {
204-
"$and": [
205-
builtin_lookup(self, compiler, connection, as_path=True),
206-
_has_key_predicate(lhs_mql, None, as_path=True),
207-
]
208-
}
209-
if as_path:
210-
return {"$expr": builtin_lookup(self, compiler, connection, as_path=False)}
211-
return builtin_lookup(self, compiler, connection, as_path=False)
232+
def key_transform_exact_path(self, compiler, connection):
233+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
234+
return {
235+
"$and": [
236+
builtin_lookup_path(self, compiler, connection),
237+
_has_key_predicate(lhs_mql, None, as_path=True),
238+
]
239+
}
240+
241+
242+
def key_transform_exact_expr(self, compiler, connection):
243+
return builtin_lookup_expr(self, compiler, connection)
212244

213245

214246
def register_json_field():
215247
ContainedBy.as_mql = contained_by
216248
DataContains.as_mql = data_contains
217249
HasAnyKeys.mongo_operator = "$or"
218250
HasKey.mongo_operator = None
219-
HasKeyLookup.as_mql = has_key_lookup
251+
# HasKeyLookup.as_mql = has_key_lookup
252+
HasKeyLookup.is_simple_expression = has_key_check_simple_expression
253+
HasKeyLookup.as_mql_path = has_key_lookup_path
254+
HasKeyLookup.as_mql_expr = has_key_lookup_expr
220255
HasKeys.mongo_operator = "$and"
221256
JSONExact.process_rhs = json_exact_process_rhs
222-
KeyTransform.as_mql = key_transform
223-
KeyTransformIn.as_mql = key_transform_in
224-
KeyTransformIsNull.as_mql = key_transform_is_null
225-
KeyTransformNumericLookupMixin.as_mql = key_transform_numeric_lookup_mixin
226-
KeyTransformExact.as_mql = key_transform_exact
257+
KeyTransform.is_simple_expression = is_simple_column
258+
# KeyTransform.as_mql = key_transform
259+
KeyTransform.as_mql_path = key_transform_path
260+
KeyTransform.as_mql_expr = key_transform_expr
261+
# KeyTransformIn.as_mql = key_transform_in
262+
KeyTransformIn.as_mql_path = key_transform_in_path
263+
KeyTransformIn.as_mql_expr = key_transform_in_expr
264+
# KeyTransformIsNull.as_mql = key_transform_is_null
265+
KeyTransformIsNull.as_mql_path = key_transform_is_null_path
266+
KeyTransformIsNull.as_mql_expr = key_transform_is_null_expr
267+
KeyTransformNumericLookupMixin.as_mql_path = key_transform_numeric_lookup_mixin_path
268+
KeyTransformNumericLookupMixin.as_mql_expr = key_transform_numeric_lookup_mixin_expr
269+
KeyTransformExact.as_mql_expr = key_transform_exact_expr
270+
KeyTransformExact.as_mql_path = key_transform_exact_path

django_mongodb_backend/functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ def func(self, compiler, connection, as_path=False):
125125
return {f"${operator}": lhs_mql}
126126

127127

128+
def func_path(self, compiler, connection): # noqa: ARG001
129+
raise NotSupportedError(f"{self} May need an as_mql_path() method.")
130+
131+
132+
def func_expr(self, compiler, connection):
133+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
134+
if self.function is None:
135+
raise NotSupportedError(f"{self} may need an as_mql() method.")
136+
operator = MONGO_OPERATORS.get(self.__class__, self.function.lower())
137+
return {f"${operator}": lhs_mql}
138+
139+
128140
def left(self, compiler, connection, as_path=False):
129141
return self.get_substr().as_mql(compiler, connection, as_path=as_path)
130142

@@ -312,7 +324,8 @@ def register_functions():
312324
ConcatPair.as_mql = concat_pair
313325
Cot.as_mql = cot
314326
Extract.as_mql = extract
315-
Func.as_mql = func
327+
Func.as_mql_path = func_path
328+
Func.as_mql_expr = func_expr
316329
JSONArray.as_mql = process_lhs
317330
Left.as_mql = left
318331
Length.as_mql = length

0 commit comments

Comments
 (0)