Skip to content

Commit bacff6d

Browse files
committed
Add AggregateFilter, StringgAgg.as_mql() as per
django/django@4b977a5
1 parent 3616b7a commit bacff6d

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
2-
from django.db.models.expressions import Case, Value, When
1+
from django.db import NotSupportedError
2+
from django.db.models.aggregates import (
3+
Aggregate,
4+
AggregateFilter,
5+
Count,
6+
StdDev,
7+
StringAgg,
8+
Variance,
9+
)
10+
from django.db.models.expressions import Case, Col, Value, When
311
from django.db.models.lookups import IsNull
412

513
from .query_utils import process_lhs
@@ -9,7 +17,11 @@
917

1018

1119
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
12-
if self.filter:
20+
# TODO: isinstance(self.filter, Col) works around failure of
21+
# aggregation.tests.AggregateTestCase.test_distinct_on_aggregate. Is this
22+
# correct?
23+
if self.filter is not None and not isinstance(self.filter, Col):
24+
# Generate a CASE statement for this aggregate.
1325
node = self.copy()
1426
node.filter = None
1527
source_expressions = node.get_source_expressions()
@@ -24,6 +36,10 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
2436
return {f"${operator}": lhs_mql}
2537

2638

39+
def aggregate_filter(self, compiler, connection):
40+
return self.condition.as_mql(compiler, connection)
41+
42+
2743
def count(self, compiler, connection, resolve_inner_expression=False):
2844
"""
2945
When resolve_inner_expression=True, return the MQL that resolves as a
@@ -65,8 +81,14 @@ def stddev_variance(self, compiler, connection):
6581
return aggregate(self, compiler, connection, operator=operator)
6682

6783

84+
def string_agg(self, compiler, connection): # noqa: ARG001
85+
raise NotSupportedError("StringAgg is not supported.")
86+
87+
6888
def register_aggregates():
6989
Aggregate.as_mql_expr = aggregate
90+
AggregateFilter.as_mql_expr = aggregate_filter
7091
Count.as_mql_expr = count
7192
StdDev.as_mql_expr = stddev_variance
93+
StringAgg.as_mql_expr = string_agg
7294
Variance.as_mql_expr = stddev_variance

0 commit comments

Comments
 (0)