Skip to content

Commit cfaa026

Browse files
committed
Refactor aggregates and add coverage unit test
1 parent 3fb7386 commit cfaa026

File tree

4 files changed

+114
-6
lines changed

4 files changed

+114
-6
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
2121
agg_expression, *_ = self.get_source_expressions()
22+
lhs_mql = None
2223
if self.filter is not None:
2324
try:
2425
lhs_mql = self.filter.as_mql(compiler, connection, as_expr=True)
@@ -29,12 +30,13 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
2930
# Skip rows that don't meet the criteria.
3031
default=Remove(),
3132
)
32-
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
3333
except FullResultSet:
34-
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
34+
# TODO: CHECK UNIT TEST REACH HERE.
35+
pass
3536
except EmptyResultSet:
36-
lhs_mql = Value(None).as_mql(compiler, connection, as_expr=True)
37-
else:
37+
# TODO: CHECK UNIT TEST REACH HERE.
38+
agg_expression = Remove()
39+
if lhs_mql is None:
3840
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
3941
if resolve_inner_expression:
4042
return lhs_mql
@@ -50,10 +52,11 @@ def count(self, compiler, connection, resolve_inner_expression=False):
5052
"""
5153
agg_expression, *_ = self.get_source_expressions()
5254
if not self.distinct or resolve_inner_expression:
55+
lhs_mql = None
5356
conditions = [IsNull(agg_expression, False)]
5457
if self.filter:
5558
try:
56-
inner_expression = self.filter.as_mql(compiler, connection, as_expr=True)
59+
lhs_mql = self.filter.as_mql(compiler, connection, as_expr=True)
5760
except NotSupportedError:
5861
conditions.append(self.filter.condition)
5962
condition = When(
@@ -71,7 +74,8 @@ def count(self, compiler, connection, resolve_inner_expression=False):
7174
# Skip rows that don't meet the criteria.
7275
default=Remove(),
7376
)
74-
lhs_mql = inner_expression.as_mql(compiler, connection, as_expr=True)
77+
if lhs_mql is None:
78+
lhs_mql = inner_expression.as_mql(compiler, connection, as_expr=True)
7579
if resolve_inner_expression:
7680
return lhs_mql
7781
return {"$sum": lhs_mql}

tests/aggregation_/__init__.py

Whitespace-only changes.

tests/aggregation_/models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from django.db import models
2+
3+
4+
class Author(models.Model):
5+
name = models.CharField(max_length=100)
6+
age = models.IntegerField()
7+
friends = models.ManyToManyField("self", blank=True)
8+
rating = models.FloatField(null=True)
9+
10+
def __str__(self):
11+
return self.name
12+
13+
14+
class Book(models.Model):
15+
isbn = models.CharField(max_length=9)
16+
name = models.CharField(max_length=255)
17+
pages = models.IntegerField()
18+
rating = models.FloatField()
19+
price = models.DecimalField(decimal_places=2, max_digits=6)
20+
authors = models.ManyToManyField(Author)
21+
contact = models.ForeignKey(Author, models.CASCADE, related_name="book_contact_set")
22+
pubdate = models.DateField()
23+
24+
def __str__(self):
25+
return self.name

tests/aggregation_/tests.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import datetime
2+
from decimal import Decimal
3+
4+
from django.db.models import (
5+
Count,
6+
Max,
7+
Q,
8+
)
9+
from django.test import TestCase
10+
11+
from .models import Author, Book
12+
13+
14+
class FilteredAggregateTests(TestCase):
15+
@classmethod
16+
def setUpTestData(cls):
17+
cls.a1 = Author.objects.create(name="test", age=40)
18+
cls.a2 = Author.objects.create(name="test2", age=60)
19+
cls.a3 = Author.objects.create(name="test3", age=40)
20+
cls.b1 = Book.objects.create(
21+
isbn="159059725",
22+
name="The Definitive Guide to Django: Web Development Done Right",
23+
pages=447,
24+
rating=4.5,
25+
price=Decimal("30.00"),
26+
contact=cls.a1,
27+
pubdate=datetime.date(2007, 12, 6),
28+
)
29+
cls.b2 = Book.objects.create(
30+
isbn="067232959",
31+
name="Sams Teach Yourself Django in 24 Hours",
32+
pages=528,
33+
rating=3.0,
34+
price=Decimal("30.00"),
35+
contact=cls.a2,
36+
pubdate=datetime.date(2008, 3, 3),
37+
)
38+
cls.b3 = Book.objects.create(
39+
isbn="159059996",
40+
name="Practical Django Projects",
41+
pages=600,
42+
rating=40.5,
43+
price=Decimal("30.00"),
44+
contact=cls.a3,
45+
pubdate=datetime.date(2008, 6, 23),
46+
)
47+
cls.a1.friends.add(cls.a2)
48+
cls.a1.friends.add(cls.a3)
49+
cls.b1.authors.add(cls.a1)
50+
cls.b1.authors.add(cls.a3)
51+
cls.b2.authors.add(cls.a2)
52+
cls.b3.authors.add(cls.a3)
53+
54+
def test_filtered_aggregate_empty_condition_distinct(self):
55+
book = Book.objects.annotate(
56+
ages=Count("authors__age", filter=Q(authors__in=[]), distinct=True),
57+
).get(pk=self.b1.pk)
58+
self.assertEqual(book.ages, 0)
59+
aggregate = Book.objects.aggregate(max_rating=Max("rating", filter=Q(rating__in=[])))
60+
self.assertEqual(aggregate, {"max_rating": None})
61+
62+
def test_filtered_aggregate_full_condition(self):
63+
book = Book.objects.annotate(
64+
ages=Count(
65+
"authors__age",
66+
filter=~Q(pk__in=[]),
67+
),
68+
).get(pk=self.b1.pk)
69+
self.assertEqual(book.ages, 2)
70+
aggregate = Book.objects.aggregate(max_rating=Max("rating", filter=~Q(rating__in=[])))
71+
self.assertEqual(aggregate, {"max_rating": 40.5})
72+
73+
def test_filtered_aggregate_full_condition_distinct(self):
74+
book = Book.objects.annotate(
75+
ages=Count("authors__age", filter=~Q(authors__in=[]), distinct=True),
76+
).get(pk=self.b1.pk)
77+
self.assertEqual(book.ages, 1)
78+
aggregate = Book.objects.aggregate(max_rating=Max("rating", filter=~Q(rating__in=[])))
79+
self.assertEqual(aggregate, {"max_rating": 40.5})

0 commit comments

Comments
 (0)