diff --git a/docs/changelog.rst b/docs/changelog.rst index 014a83991..ff2dd38c1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -23,6 +23,7 @@ Development - BREAKING CHANGE: Remove LongField as it's equivalent to IntField since we drop support to Python2 long time ago (User should simply switch to IntField) #2309 - BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858 - Improve error message in case a document assigned to a ReferenceField wasn't saved yet #1955 +- BugFix - Take `where()` into account when using `.modify()`, as in MyDocument.objects().where("this[field] >= this[otherfield]").modify(field='new') #2044 Changes in 0.29.0 ================= diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index f04ef06c5..2db97ddb7 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -727,6 +727,11 @@ def modify( queryset = self.clone() query = queryset._query + + if self._where_clause: + where_clause = self._sub_js_fields(self._where_clause) + query["$where"] = where_clause + if not remove: update = transform.update(queryset._document, **update) sort = queryset._ordering diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 0b8773c2b..8386249f2 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -28,6 +28,7 @@ from mongoengine.queryset.base import BaseQuerySet from tests.utils import ( db_ops_tracker, + get_as_pymongo, requires_mongodb_gte_42, requires_mongodb_gte_44, requires_mongodb_lt_42, @@ -4456,7 +4457,7 @@ class Comment(Document): ] assert ([("_cls", 1), ("message", 1)], False, False) in info - def test_where(self): + def test_where_query(self): """Ensure that where clauses work.""" class IntPair(Document): @@ -4499,6 +4500,60 @@ class IntPair(Document): with pytest.raises(TypeError): list(IntPair.objects.where(fielda__gte=3)) + def test_where_query_field_name_subs(self): + class DomainObj(Document): + field_1 = StringField(db_field="field_2") + + DomainObj.drop_collection() + + DomainObj(field_1="test").save() + + obj = DomainObj.objects.where("this[~field_1] == 'NOTMATCHING'") + assert not list(obj) + + obj = DomainObj.objects.where("this[~field_1] == 'test'") + assert list(obj) + + def test_where_modify(self): + class DomainObj(Document): + field = StringField() + + DomainObj.drop_collection() + + DomainObj(field="test").save() + + obj = DomainObj.objects.where("this[~field] == 'NOTMATCHING'") + assert not list(obj) + + obj = DomainObj.objects.where("this[~field] == 'test'") + assert list(obj) + + qs = DomainObj.objects.where("this[~field] == 'NOTMATCHING'").modify( + field="new" + ) + assert not qs + + qs = DomainObj.objects.where("this[~field] == 'test'").modify(field="new") + assert qs + + def test_where_modify_field_name_subs(self): + class DomainObj(Document): + field_1 = StringField(db_field="field_2") + + DomainObj.drop_collection() + + DomainObj(field_1="test").save() + + obj = DomainObj.objects.where("this[~field_1] == 'NOTMATCHING'").modify( + field_1="new" + ) + assert not obj + + obj = DomainObj.objects.where("this[~field_1] == 'test'").modify(field_1="new") + assert obj + + assert get_as_pymongo(obj) == {"_id": obj.id, "field_2": "new"} + def test_scalar(self): class Organization(Document): name = StringField()