Skip to content

Commit 95fc3d2

Browse files
committed
fix GenericReference iterable query (i.e. __in)
This change adds the ``_ref`` or ``_ref.$id`` prefix to a query if all values in an iterable query (i.e. ``__in``) are ``ObjectId``s or ``DBRef``s and raises an error for a mixed query which will only work for documents. These could possibly be compiled into an ``{$or: ...}`` query, but the automatic expansion can be added as necessary.
1 parent 11fd30b commit 95fc3d2

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

mongoengine/queryset/transform.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ def query(_doc_cls=None, **kwargs):
129129

130130
singular_ops = [None, "ne", "gt", "gte", "lt", "lte", "not"]
131131
singular_ops += STRING_OPERATORS
132+
is_iterable = False
132133
if op in singular_ops:
133134
value = field.prepare_query_value(op, value)
134135

135136
if isinstance(field, CachedReferenceField) and value:
136137
value = value["_id"]
137138

138139
elif op in ("in", "nin", "all", "near") and not isinstance(value, dict):
140+
is_iterable = True
139141
# Raise an error if the in/nin/all/near param is not iterable.
140142
value = _prepare_query_for_iterable(field, op, value)
141143

@@ -144,10 +146,24 @@ def query(_doc_cls=None, **kwargs):
144146
# * If the value is a DBRef, the key should be "field_name._ref".
145147
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
146148
if isinstance(field, GenericReferenceField):
147-
if isinstance(value, DBRef):
149+
if isinstance(value, DBRef) or (
150+
is_iterable and all(isinstance(v, DBRef) for v in value)
151+
):
148152
parts[-1] += "._ref"
149-
elif isinstance(value, ObjectId):
153+
elif isinstance(value, ObjectId) or (
154+
is_iterable and all(isinstance(v, ObjectId) for v in value)
155+
):
150156
parts[-1] += "._ref.$id"
157+
elif (
158+
is_iterable
159+
and any(isinstance(v, DBRef) for v in value)
160+
and any(isinstance(v, ObjectId) for v in value)
161+
):
162+
raise ValueError(
163+
"The `in`, `nin`, `all`, or `near`-operators cannot "
164+
"be applied to mixed queries of DBRef/ObjectId/%s"
165+
% _doc_cls.__name__
166+
)
151167

152168
# if op and op not in COMPARISON_OPERATORS:
153169
if op:

tests/queryset/test_transform.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,36 @@ class Shop(Document):
396396

397397
Shop.drop_collection()
398398

399+
def test_transform_generic_reference_field(self):
400+
class Object(Document):
401+
field = GenericReferenceField()
402+
403+
Object.drop_collection()
404+
objects = Object.objects.insert([Object() for _ in range(8)])
405+
# singular queries
406+
assert transform.query(Object, field=objects[0].pk) == {
407+
"field._ref.$id": objects[0].pk
408+
}
409+
assert transform.query(Object, field=objects[1].to_dbref()) == {
410+
"field._ref": objects[1].pk
411+
}
412+
413+
# iterable queries
414+
assert transform.query(
415+
Object, field__in=[objects[2].pk, objects[3].pk]
416+
) == {"field._ref.$id": {"$in": [objects[2].pk, objects[3].pk]}}
417+
assert transform.query(
418+
Object, field__in=[objects[4].to_dbref(), objects[5].to_dbref()]
419+
) == {"field._ref": {"$in": [objects[4].pk, objects[5].pk]}}
420+
421+
# invalid query
422+
with pytest.raises(match="cannot be applied to mixed queries"):
423+
transform.query(
424+
Object, field__in=[objects[6].pk, objects[7].to_dbref()]
425+
)
426+
427+
Object.drop_collection()
428+
399429

400430
if __name__ == "__main__":
401431
unittest.main()

0 commit comments

Comments
 (0)