diff --git a/docs/changelog.rst b/docs/changelog.rst index ff2dd38c1..13eb6ccd7 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -16,6 +16,7 @@ Development - BREAKING CHANGE: wrap _document_registry (normally not used by end users) with _DocumentRegistry which acts as a singleton to access the registry - Log a warning in case users creates multiple Document classes with the same name as it can lead to unexpected behavior #1778 - Fix use of $geoNear or $collStats in aggregate #2493 +- Fix use of $search or $vectorSearch in aggregate #2878 - BREAKING CHANGE: Further to the deprecation warning, remove ability to use an unpacked list to `Queryset.aggregate(*pipeline)`, a plain list must be provided instead `Queryset.aggregate(pipeline)`, as it's closer to pymongo interface - BREAKING CHANGE: Further to the deprecation warning, remove `full_response` from `QuerySet.modify` as it wasn't supported with Pymongo 3+ - Fixed stacklevel of many warnings (to point places emitting the warning more accurately) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 2db97ddb7..f077ad83b 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1398,9 +1398,10 @@ def aggregate(self, pipeline, **kwargs): first_step = [] new_user_pipeline = [] for step_step in pipeline: - if "$geoNear" in step_step: - first_step.append(step_step) - elif "$collStats" in step_step: + if any( + el in step_step + for el in ["$geoNear", "$collStats", "$search", "$vectorSearch"] + ): first_step.append(step_step) else: new_user_pipeline.append(step_step) diff --git a/tests/queryset/test_queryset_aggregation.py b/tests/queryset/test_queryset_aggregation.py index 7e390e35a..283680ac0 100644 --- a/tests/queryset/test_queryset_aggregation.py +++ b/tests/queryset/test_queryset_aggregation.py @@ -1,3 +1,6 @@ +import sys +from unittest.mock import patch + import pytest from pymongo.read_preferences import ReadPreference @@ -373,3 +376,30 @@ class SomeDoc(Document): res = list(SomeDoc.objects.aggregate(pipeline)) assert len(res) == 1 assert res[0]["count"] == 2 + + def test_aggregate_search_used_as_initial_step_before_cls_implicit_step(self): + class SearchableDoc(Document): + first_name = StringField() + last_name = StringField() + + privileged_step = { + "$search": { + "index": "default", + "autocomplete": { + "query": "foo", + "path": "first_name", + }, + } + } + pipeline = [privileged_step] + + # Search requires an Atlas instance, so we instead mock the aggregation call to inspect the final pipeline + with patch.object(SearchableDoc, "_collection") as coll_mk: + SearchableDoc.objects(last_name="bar").aggregate(pipeline) + if sys.version_info < (3, 8): + final_pipeline = coll_mk.aggregate.call_args_list[0][0][0] + else: + final_pipeline = coll_mk.aggregate.call_args_list[0].args[0] + assert len(final_pipeline) == 2 + # $search moved before the $match on last_name + assert final_pipeline[0] == privileged_step