Skip to content

Commit 9ca437d

Browse files
committed
fix GenericReference iterable query (i.e. __in)
This change adds the ``_ref`` or ``_ref.$id`` prefix to a query if all values in an iterable query (i.e. ``__in``) are ``ObjectId``s or ``DBRef``s and raises an error for a mixed query which will only work for documents. These could possibly be compiled into an ``{$or: ...}`` query, but the automatic expansion can be added as necessary.
1 parent e51ee40 commit 9ca437d

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

mongoengine/queryset/transform.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ def query(_doc_cls=None, **kwargs):
129129

130130
singular_ops = [None, "ne", "gt", "gte", "lt", "lte", "not"]
131131
singular_ops += STRING_OPERATORS
132+
is_iterable = False
132133
if op in singular_ops:
133134
value = field.prepare_query_value(op, value)
134135

135136
if isinstance(field, CachedReferenceField) and value:
136137
value = value["_id"]
137138

138139
elif op in ("in", "nin", "all", "near") and not isinstance(value, dict):
140+
is_iterable = True
139141
# Raise an error if the in/nin/all/near param is not iterable.
140142
value = _prepare_query_for_iterable(field, op, value)
141143

@@ -144,10 +146,26 @@ def query(_doc_cls=None, **kwargs):
144146
# * If the value is a DBRef, the key should be "field_name._ref".
145147
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
146148
if isinstance(field, GenericReferenceField):
147-
if isinstance(value, DBRef):
149+
if (
150+
isinstance(value, DBRef)
151+
or (is_iterable and all(isinstance(v, DBRef) for v in value))
152+
):
148153
parts[-1] += "._ref"
149-
elif isinstance(value, ObjectId):
154+
elif (
155+
isinstance(value, ObjectId):
156+
or (is_iterable and all(isinstance(v, ObjectId) for v in value))
157+
):
150158
parts[-1] += "._ref.$id"
159+
elif (
160+
is_iterable
161+
and any(isinstance(v, DBRef) for v in value)
162+
and any(isinstance(v, ObjectId) for v in value)
163+
):
164+
raise ValueError(
165+
"The `in`, `nin`, `all`, or `near`-operators cannot "
166+
"be applied to mixed queries of DBRef/ObjectId/%s"
167+
% _doc_cls.__name__
168+
)
151169

152170
# if op and op not in COMPARISON_OPERATORS:
153171
if op:

0 commit comments

Comments
 (0)