Skip to content

Commit 19be82a

Browse files
committed
Refactor.
1 parent 08ef1f3 commit 19be82a

File tree

4 files changed

+32
-35
lines changed

4 files changed

+32
-35
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
from ..query_utils import is_direct_value, process_lhs
3434

3535

36-
# EXTRA IS TOTALLY IGNORED
37-
# shall check if we could optimize match here
3836
def case(self, compiler, connection, as_path=False):
3937
case_parts = []
4038
for case in self.cases:
@@ -94,21 +92,20 @@ def col_pairs(self, compiler, connection, as_path=False):
9492
return cols[0].as_mql(compiler, connection, as_path=as_path)
9593

9694

97-
def combined_expression(self, compiler, connection, **extra):
95+
def combined_expression(self, compiler, connection, as_path=False):
9896
expressions = [
99-
self.lhs.as_mql(compiler, connection, **extra),
100-
self.rhs.as_mql(compiler, connection, **extra),
97+
self.lhs.as_mql(compiler, connection, as_path=as_path),
98+
self.rhs.as_mql(compiler, connection, as_path=as_path),
10199
]
102100
return connection.ops.combine_expression(self.connector, expressions)
103101

104102

105-
def expression_wrapper(self, compiler, connection, **extra):
106-
return self.expression.as_mql(compiler, connection, **extra)
103+
def expression_wrapper(self, compiler, connection, as_path=False):
104+
return self.expression.as_mql(compiler, connection, as_path=as_path)
107105

108106

109-
def negated_expression(self, compiler, connection, **extra):
110-
# review
111-
return {"$not": expression_wrapper(self, compiler, connection, **extra)}
107+
def negated_expression(self, compiler, connection, as_path=False):
108+
return {"$not": expression_wrapper(self, compiler, connection, as_path=as_path)}
112109

113110

114111
def order_by(self, compiler, connection):
@@ -194,15 +191,14 @@ def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_path=Fal
194191
return expr
195192

196193

197-
def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False, **extra):
194+
def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
198195
try:
199196
lhs_mql = subquery(
200197
self,
201198
compiler,
202199
connection,
203200
get_wrapping_pipeline=get_wrapping_pipeline,
204201
as_path=as_path,
205-
**extra,
206202
)
207203
except EmptyResultSet:
208204
return Value(False).as_mql(compiler, connection)
@@ -211,8 +207,8 @@ def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False
211207
return connection.mongo_operators_expr["isnull"](lhs_mql, False)
212208

213209

214-
def when(self, compiler, connection, **extra):
215-
return self.condition.as_mql(compiler, connection, **extra)
210+
def when(self, compiler, connection, as_path=False):
211+
return self.condition.as_mql(compiler, connection, as_path=as_path)
216212

217213

218214
def value(self, compiler, connection, as_path=False): # noqa: ARG001
@@ -244,9 +240,8 @@ def _is_constant_value(value):
244240
if isinstance(value, list | Array):
245241
iterable = value.get_source_expressions() if isinstance(value, Array) else value
246242
return all(_is_constant_value(e) for e in iterable)
247-
if isinstance(value, Value) or is_direct_value(value):
248-
v = value.value if isinstance(value, Value) else value
249-
return not isinstance(v, str) or "." not in v
243+
if is_direct_value(value):
244+
return True
250245
return isinstance(value, Func | Value) and not (
251246
value.contains_aggregate
252247
or value.contains_over_clause
@@ -266,7 +261,7 @@ def _is_simple_column(lhs):
266261
return isinstance(col, Col) and col.alias is not None
267262

268263

269-
def is_simple_expression(self):
264+
def _is_simple_expression(self):
270265
return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs)
271266

272267

@@ -288,6 +283,6 @@ def register_expressions():
288283
Subquery.as_mql = subquery
289284
When.as_mql = when
290285
Value.as_mql = value
291-
BaseExpression.is_simple_expression = is_simple_expression
286+
BaseExpression.is_simple_expression = _is_simple_expression
292287
BaseExpression.is_simple_column = _is_simple_column
293288
BaseExpression.is_constant_value = _is_constant_value

django_mongodb_backend/fields/array.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,14 @@ def __init__(self, index, base_field, *args, **kwargs):
399399
self.base_field = base_field
400400

401401
def as_mql(self, compiler, connection, as_path=False):
402-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
403-
if as_path:
402+
if as_path and self.is_simple_column(self.lhs):
403+
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
404404
return f"{lhs_mql}.{self.index}"
405-
return {"$arrayElemAt": [lhs_mql, self.index]}
405+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
406+
expr = {"$arrayElemAt": [lhs_mql, self.index]}
407+
if as_path:
408+
return {"$expr": expr}
409+
return expr
406410

407411
@property
408412
def output_field(self):

django_mongodb_backend/fields/json.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False):
7373

7474
def has_key_lookup(self, compiler, connection, as_path=False):
7575
"""Return MQL to check for the existence of a key."""
76-
as_path = as_path and self.is_simple_expression()
77-
lhs = process_lhs(self, compiler, connection, as_path=as_path)
7876
rhs = self.rhs
7977
if not isinstance(rhs, (list, tuple)):
8078
rhs = [rhs]
79+
as_path = as_path and self.is_simple_expression() and all("." not in v for v in rhs)
80+
lhs = process_lhs(self, compiler, connection, as_path=as_path)
8181
paths = []
8282
# Transform any "raw" keys into KeyTransforms to allow consistent handling
8383
# in the code that follows.

django_mongodb_backend/functions.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,7 @@
6565
}
6666

6767

68-
# TODO: ALL THOSE FUNCTION MAY CHECK AS_EXPR OR AS_PATH=FALSE. JUST NEED TO REVIEW ALL THE
69-
# TEST THAT HAVE THOSE OPERATOR.
70-
71-
72-
def cast(self, compiler, connection, as_path=False): # noqa: ARG001
68+
def cast(self, compiler, connection, as_path=False):
7369
output_type = connection.data_types[self.output_field.get_internal_type()]
7470
lhs_mql = process_lhs(self, compiler, connection, as_path=False)[0]
7571
if max_length := self.output_field.max_length:
@@ -82,6 +78,8 @@ def cast(self, compiler, connection, as_path=False): # noqa: ARG001
8278
if decimal_places := getattr(self.output_field, "decimal_places", None):
8379
lhs_mql = {"$trunc": [lhs_mql, decimal_places]}
8480

81+
if as_path:
82+
return {"$expr": lhs_mql}
8583
return lhs_mql
8684

8785

@@ -98,7 +96,7 @@ def concat_pair(self, compiler, connection, as_path=False):
9896

9997

10098
def cot(self, compiler, connection, as_path=False):
101-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
99+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
102100
if as_path:
103101
return {"$expr": {"$divide": [1, {"$tan": lhs_mql}]}}
104102
return {"$divide": [1, {"$tan": lhs_mql}]}
@@ -118,7 +116,7 @@ def extract(self, compiler, connection, as_path=False):
118116

119117

120118
def func(self, compiler, connection, as_path=False):
121-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
119+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
122120
if self.function is None:
123121
raise NotSupportedError(f"{self} may need an as_mql() method.")
124122
operator = MONGO_OPERATORS.get(self.__class__, self.function.lower())
@@ -127,8 +125,8 @@ def func(self, compiler, connection, as_path=False):
127125
return {f"${operator}": lhs_mql}
128126

129127

130-
def left(self, compiler, connection, as_path=False): # noqa: ARG001
131-
return self.get_substr().as_mql(compiler, connection, as_path=False)
128+
def left(self, compiler, connection, as_path=False):
129+
return self.get_substr().as_mql(compiler, connection, as_path=as_path)
132130

133131

134132
def length(self, compiler, connection, as_path=False):
@@ -140,11 +138,11 @@ def length(self, compiler, connection, as_path=False):
140138
return expr
141139

142140

143-
def log(self, compiler, connection, as_path=False): # noqa: ARG001
141+
def log(self, compiler, connection, as_path=False):
144142
# This function is usually log(base, num) but on MongoDB it's log(num, base).
145143
clone = self.copy()
146144
clone.set_source_expressions(self.get_source_expressions()[::-1])
147-
return func(clone, compiler, connection)
145+
return func(clone, compiler, connection, as_path=as_path)
148146

149147

150148
def now(self, compiler, connection, as_path=False): # noqa: ARG001

0 commit comments

Comments
 (0)