diff --git a/docs/changelog.rst b/docs/changelog.rst index 65aef227d..b8f000174 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,10 @@ Development - Allow gt/gte/lt/lte/ne operators to be used with a list as value on ListField #2813 - Switch tox to use pytest instead of legacy `python setup.py test` #2804 - Add support for timeseries collection #2661 +- improve ReferenceField wrong usage detection +- Fix no_dereference thread-safetyness #2830 +- BREAKING CHANGE: max_length in ListField is now keyword only on ListField signature +- BREAKING CHANGE: Force `field` argument of ListField/DictField to be a field instance (e.g ListField(StringField()) instead of ListField(StringField) Changes in 0.28.2 ================= diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 053242234..595690829 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -806,10 +806,15 @@ def _from_son(cls, son, _auto_dereference=True, created=False): fields = cls._fields if not _auto_dereference: + # if auto_deref is turned off, we copy the fields so + # we can mutate the auto_dereference of the fields fields = copy.deepcopy(fields) + # Apply field-name / db-field conversion for field_name, field in fields.items(): - field._auto_dereference = _auto_dereference + field.set_auto_dereferencing( + _auto_dereference + ) # align the field's auto-dereferencing with the document's if field.db_field in data: value = data[field.db_field] try: diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 037e916ff..cead14449 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -1,4 +1,6 @@ +import contextlib import operator +import threading import weakref import pymongo @@ -16,6 +18,19 @@ __all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") +@contextlib.contextmanager +def _no_dereference_for_fields(*fields): + """Context manager for temporarily disabling a Field's auto-dereferencing + (meant to be used from no_dereference context manager)""" + try: + for field in fields: + field._incr_no_dereference_context() + yield None + finally: + for field in fields: + field._decr_no_dereference_context() + + class BaseField: """A base class for fields in a MongoDB document. Instances of this class may be added to subclasses of `Document` to define a document's schema. @@ -24,7 +39,7 @@ class BaseField: name = None # set in TopLevelDocumentMetaclass _geo_index = False _auto_gen = False # Call `generate` to generate a value - _auto_dereference = True + _thread_local_storage = threading.local() # These track each time a Field instance is created. Used to retain order. # The auto_creation_counter is used for fields that MongoEngine implicitly @@ -85,6 +100,8 @@ def __init__( self.sparse = sparse self._owner_document = None + self.__auto_dereference = True + # Make sure db_field is a string (if it's explicitly defined). if self.db_field is not None and not isinstance(self.db_field, str): raise TypeError("db_field should be a string.") @@ -120,6 +137,33 @@ def __init__( self.creation_counter = BaseField.creation_counter BaseField.creation_counter += 1 + def set_auto_dereferencing(self, value): + self.__auto_dereference = value + + @property + def _no_dereference_context_local(self): + if not hasattr(self._thread_local_storage, "no_dereference_context"): + self._thread_local_storage.no_dereference_context = 0 + return self._thread_local_storage.no_dereference_context + + @property + def _no_dereference_context_is_set(self): + return self._no_dereference_context_local > 0 + + def _incr_no_dereference_context(self): + self._thread_local_storage.no_dereference_context = ( + self._no_dereference_context_local + 1 + ) + + def _decr_no_dereference_context(self): + self._thread_local_storage.no_dereference_context = ( + self._no_dereference_context_local - 1 + ) + + @property + def _auto_dereference(self): + return self.__auto_dereference and not self._no_dereference_context_is_set + def __get__(self, instance, owner): """Descriptor for retrieving a value from a field in a document.""" if instance is None: @@ -268,6 +312,10 @@ class ComplexBaseField(BaseField): """ def __init__(self, field=None, **kwargs): + if field is not None and not isinstance(field, BaseField): + raise TypeError( + f"field argument must be a Field instance (e.g {self.__class__.__name__}(StringField()))" + ) self.field = field super().__init__(**kwargs) @@ -375,7 +423,7 @@ def to_python(self, value): return value if self.field: - self.field._auto_dereference = self._auto_dereference + self.field.set_auto_dereferencing(self._auto_dereference) value_dict = { key: self.field.to_python(item) for key, item in value.items() } diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index f16753eea..e8eee1e38 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -1,9 +1,11 @@ +import contextlib import threading from contextlib import contextmanager from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern +from mongoengine.base.fields import _no_dereference_for_fields from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.pymongo_support import count_documents @@ -22,6 +24,7 @@ class MyThreadLocals(threading.local): def __init__(self): + # {DocCls: count} keeping track of classes with an active no_dereference context self.no_dereferencing_class = {} @@ -126,46 +129,37 @@ def __exit__(self, t, value, traceback): self.cls._get_collection_name = self.ori_get_collection_name -class no_dereference: +@contextlib.contextmanager +def no_dereference(cls): """no_dereference context manager. Turns off all dereferencing in Documents for the duration of the context manager:: with no_dereference(Group): - Group.objects.find() + Group.objects() """ - - def __init__(self, cls): - """Construct the no_dereference context manager. - - :param cls: the class to turn dereferencing off on - """ - self.cls = cls + try: + cls = cls ReferenceField = _import_class("ReferenceField") GenericReferenceField = _import_class("GenericReferenceField") ComplexBaseField = _import_class("ComplexBaseField") - self.deref_fields = [ - k - for k, v in self.cls._fields.items() - if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField)) + deref_fields = [ + field + for name, field in cls._fields.items() + if isinstance( + field, (ReferenceField, GenericReferenceField, ComplexBaseField) + ) ] - def __enter__(self): - """Change the objects default and _auto_dereference values.""" - _register_no_dereferencing_for_class(self.cls) - - for field in self.deref_fields: - self.cls._fields[field]._auto_dereference = False - - def __exit__(self, t, value, traceback): - """Reset the default and _auto_dereference values.""" - _unregister_no_dereferencing_for_class(self.cls) + _register_no_dereferencing_for_class(cls) - for field in self.deref_fields: - self.cls._fields[field]._auto_dereference = True + with _no_dereference_for_fields(*deref_fields): + yield None + finally: + _unregister_no_dereferencing_for_class(cls) class no_sub_classes: @@ -180,7 +174,7 @@ class no_sub_classes: def __init__(self, cls): """Construct the no_sub_classes context manager. - :param cls: the class to turn querying sub classes on + :param cls: the class to turn querying subclasses on """ self.cls = cls self.cls_initial_subclasses = None diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 27a0826a4..5d8c1395e 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -6,6 +6,7 @@ import socket import time import uuid +from inspect import isclass from io import BytesIO from operator import itemgetter @@ -707,7 +708,6 @@ class EmbeddedDocumentField(BaseField): """ def __init__(self, document_type, **kwargs): - # XXX ValidationError raised outside of the "validate" method. if not ( isinstance(document_type, str) or issubclass(document_type, EmbeddedDocument) @@ -910,9 +910,9 @@ class ListField(ComplexBaseField): Required means it cannot be empty - as the default for ListFields is [] """ - def __init__(self, field=None, max_length=None, **kwargs): + def __init__(self, field=None, *, max_length=None, **kwargs): self.max_length = max_length - kwargs.setdefault("default", lambda: []) + kwargs.setdefault("default", list) super().__init__(field=field, **kwargs) def __get__(self, instance, owner): @@ -1035,10 +1035,9 @@ class DictField(ComplexBaseField): """ def __init__(self, field=None, *args, **kwargs): - self._auto_dereference = False - - kwargs.setdefault("default", lambda: {}) + kwargs.setdefault("default", dict) super().__init__(*args, field=field, **kwargs) + self.set_auto_dereferencing(False) def validate(self, value): """Make sure that a list of valid fields is being used.""" @@ -1151,8 +1150,9 @@ def __init__( :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. """ # XXX ValidationError raised outside of the "validate" method. - if not isinstance(document_type, str) and not issubclass( - document_type, Document + if not ( + isinstance(document_type, str) + or (isclass(document_type) and issubclass(document_type, Document)) ): self.error( "Argument to ReferenceField constructor must be a " diff --git a/python-mongoengine.spec b/python-mongoengine.spec deleted file mode 100644 index 635c779fd..000000000 --- a/python-mongoengine.spec +++ /dev/null @@ -1,54 +0,0 @@ -# sitelib for noarch packages, sitearch for others (remove the unneeded one) -%{!?python_sitelib: %global python_sitelib %(%{__python} -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")} -%{!?python_sitearch: %global python_sitearch %(%{__python} -c "from distutils.sysconfig import get_python_lib; print(get_python_lib(1))")} - -%define srcname mongoengine - -Name: python-%{srcname} -Version: 0.8.7 -Release: 1%{?dist} -Summary: A Python Document-Object Mapper for working with MongoDB - -Group: Development/Libraries -License: MIT -URL: https://github.com/MongoEngine/mongoengine -Source0: %{srcname}-%{version}.tar.bz2 - -BuildRequires: python-devel -BuildRequires: python-setuptools - -Requires: mongodb -Requires: pymongo -Requires: python-blinker -Requires: python-imaging - - -%description -MongoEngine is an ORM-like layer on top of PyMongo. - -%prep -%setup -q -n %{srcname}-%{version} - - -%build -# Remove CFLAGS=... for noarch packages (unneeded) -CFLAGS="$RPM_OPT_FLAGS" %{__python} setup.py build - - -%install -rm -rf $RPM_BUILD_ROOT -%{__python} setup.py install -O1 --skip-build --root $RPM_BUILD_ROOT - -%clean -rm -rf $RPM_BUILD_ROOT - -%files -%defattr(-,root,root,-) -%doc docs AUTHORS LICENSE README.rst -# For noarch packages: sitelib - %{python_sitelib}/* -# For arch-specific packages: sitearch -# %{python_sitearch}/* - -%changelog -* See: http://docs.mongoengine.org/en/latest/changelog.html diff --git a/tests/fields/test_complex_base_field.py b/tests/fields/test_complex_base_field.py new file mode 100644 index 000000000..accda2f78 --- /dev/null +++ b/tests/fields/test_complex_base_field.py @@ -0,0 +1,10 @@ +import pytest + +from mongoengine.base import ComplexBaseField +from tests.utils import MongoDBTestCase + + +class TestComplexBaseField(MongoDBTestCase): + def test_field_validation(self): + with pytest.raises(TypeError, match="field argument must be a Field instance"): + ComplexBaseField("test") diff --git a/tests/fields/test_dict_field.py b/tests/fields/test_dict_field.py index 4da29d9a3..c2c6ea1fd 100644 --- a/tests/fields/test_dict_field.py +++ b/tests/fields/test_dict_field.py @@ -116,39 +116,35 @@ class BlogPost(Document): post.reload() assert post.info["authors"] == [] - def test_dictfield_dump_document(self): + def test_dictfield_dump_document_with_inheritance__cls(self): """Ensure a DictField can handle another document's dump.""" class Doc(Document): field = DictField() - class ToEmbed(Document): - id = IntField(primary_key=True, default=1) - recursive = DictField() - class ToEmbedParent(Document): - id = IntField(primary_key=True, default=1) + id = IntField(primary_key=True) recursive = DictField() meta = {"allow_inheritance": True} class ToEmbedChild(ToEmbedParent): - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - to_embed_recursive = ToEmbed(id=1).save() - to_embed = ToEmbed( - id=2, recursive=to_embed_recursive.to_mongo().to_dict() - ).save() - doc = Doc(field=to_embed.to_mongo().to_dict()) - doc.save() - assert isinstance(doc.field, dict) - assert doc.field == {"_id": 2, "recursive": {"_id": 1, "recursive": {}}} - # Same thing with a Document with a _cls field + Doc.drop_collection() + ToEmbedParent.drop_collection() + + # with a Document with a _cls field to_embed_recursive = ToEmbedChild(id=1).save() to_embed_child = ToEmbedChild( id=2, recursive=to_embed_recursive.to_mongo().to_dict() ).save() - doc = Doc(field=to_embed_child.to_mongo().to_dict()) + + doc_dump_as_dict = to_embed_child.to_mongo().to_dict() + doc = Doc(field=doc_dump_as_dict) + assert Doc.field._auto_dereference is False + assert isinstance(doc.field, dict) # depends on auto_dereference doc.save() assert isinstance(doc.field, dict) expected = { @@ -162,6 +158,30 @@ class ToEmbedChild(ToEmbedParent): } assert doc.field == expected + # _ = Doc.objects.first() + # assert Doc.field._auto_dereference is False # Fails, bug #2831 + # doc = Doc(field=doc_dump_as_dict) + # assert isinstance(doc.field, dict) # Fails, bug #2831 + + def test_dictfield_dump_document_no_inheritance(self): + """Ensure a DictField can handle another document's dump.""" + + class Doc(Document): + field = DictField() + + class ToEmbed(Document): + id = IntField(primary_key=True) + recursive = DictField() + + to_embed_recursive = ToEmbed(id=1).save() + to_embed = ToEmbed( + id=2, recursive=to_embed_recursive.to_mongo().to_dict() + ).save() + doc = Doc(field=to_embed.to_mongo().to_dict()) + doc.save() + assert isinstance(doc.field, dict) + assert doc.field == {"_id": 2, "recursive": {"_id": 1, "recursive": {}}} + def test_dictfield_strict(self): """Ensure that dict field handles validation if provided a strict field type.""" diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index d95e2fce0..5655f12e2 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -39,6 +39,7 @@ EmbeddedDocumentList, _document_registry, ) +from mongoengine.base.fields import _no_dereference_for_fields from mongoengine.errors import DeprecatedError from tests.utils import MongoDBTestCase @@ -1373,17 +1374,19 @@ class Bar(Document): # Reference is no longer valid foo.delete() bar = Bar.objects.get() + with pytest.raises(DoesNotExist): bar.ref + with pytest.raises(DoesNotExist): bar.generic_ref # When auto_dereference is disabled, there is no trouble returning DBRef bar = Bar.objects.get() expected = foo.to_dbref() - bar._fields["ref"]._auto_dereference = False + bar._fields["ref"].set_auto_dereferencing(False) assert bar.ref == expected - bar._fields["generic_ref"]._auto_dereference = False + bar._fields["generic_ref"].set_auto_dereferencing(False) assert bar.generic_ref == {"_ref": expected, "_cls": "Foo"} def test_list_item_dereference(self): @@ -2732,5 +2735,33 @@ class CustomData(Document): assert custom_data["a"] == CustomData.c_field.custom_data["a"] +class TestUtils(MongoDBTestCase): + def test__no_dereference_for_fields(self): + class User(Document): + name = StringField() + + class Group(Document): + member = ReferenceField(User) + + User.drop_collection() + Group.drop_collection() + + user1 = User(name="user1") + user1.save() + + group = Group(member=user1) + group.save() + + # Test all inside the context mgr, from class field + with _no_dereference_for_fields(Group.member): + group = Group.objects.first() + assert isinstance(group.member, DBRef) + + # Test instance fetched outside context mgr, patch on instance field + group = Group.objects.first() + with _no_dereference_for_fields(group._fields["member"]): + assert isinstance(group.member, DBRef) + + if __name__ == "__main__": unittest.main() diff --git a/tests/fields/test_reference_field.py b/tests/fields/test_reference_field.py index afbf50045..94869f2ea 100644 --- a/tests/fields/test_reference_field.py +++ b/tests/fields/test_reference_field.py @@ -6,6 +6,26 @@ class TestReferenceField(MongoDBTestCase): + def test_reference_field_fails_init_wrong_document_type(self): + class User(Document): + name = StringField() + + ERROR_MSG = "Argument to ReferenceField constructor must be a document class or a string" + # fails if given an instance + with pytest.raises(ValidationError, match=ERROR_MSG): + + class Test(Document): + author = ReferenceField(User()) + + class NonDocumentSubClass: + pass + + # fails if given an non Document subclass + with pytest.raises(ValidationError, match=ERROR_MSG): + + class Test(Document): # noqa: F811 + author = ReferenceField(NonDocumentSubClass) + def test_reference_validation(self): """Ensure that invalid document objects cannot be assigned to reference fields. diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index f5634d77a..7b41a894b 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -5326,7 +5326,8 @@ class User(Document): assert isinstance(qs.first().organization, Organization) - assert isinstance(qs.no_dereference().first().organization, DBRef) + user = qs.no_dereference().first() + assert isinstance(user.organization, DBRef) assert isinstance(qs_user.organization, Organization) assert isinstance(qs.first().organization, Organization) diff --git a/tests/test_ci.py b/tests/test_ci.py index 04a800ebf..00068f776 100644 --- a/tests/test_ci.py +++ b/tests/test_ci.py @@ -3,7 +3,7 @@ def test_ci_placeholder(): # setup the tox venv without running the test suite # if we simply skip all test with pytest -k=wrong_pattern # pytest command would return with exit_code=5 (i.e "no tests run") - # making travis fail + # making pipeline fail # this empty test is the recommended way to handle this # as described in https://github.com/pytest-dev/pytest/issues/2393 pass diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 69c8931dd..287ce7df5 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -172,8 +172,44 @@ class Group(Document): assert isinstance(group.ref, User) assert isinstance(group.generic, User) + def test_no_dereference_context_manager_thread_safe(self): + """Ensure no_dereference context manager works in threaded condition""" + + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=False) + + User.drop_collection() + Group.drop_collection() + + user = User(name="user 1").save() + Group(ref=user).save() + + def run_in_thread(id): + time.sleep(random.uniform(0.1, 0.5)) # Force desync of threads + if id % 2 == 0: + with no_dereference(Group): + for i in range(20): + time.sleep(random.uniform(0.1, 0.5)) + assert Group.ref._auto_dereference is False + group = Group.objects.first() + assert isinstance(group.ref, DBRef) + else: + for i in range(20): + time.sleep(random.uniform(0.1, 0.5)) + assert Group.ref._auto_dereference is True + group = Group.objects.first() + assert isinstance(group.ref, User) + + threads = [ + TestableThread(target=run_in_thread, args=(id,)) for id in range(100) + ] + _ = [th.start() for th in threads] + _ = [th.join() for th in threads] + def test_no_dereference_context_manager_nested(self): - """Ensure that DBRef items in ListFields aren't dereferenced.""" class User(Document): name = StringField() @@ -205,20 +241,6 @@ class Group(Document): group = Group.objects.first() assert isinstance(group.ref, User) - def run_in_thread(id): - time.sleep(random.uniform(0.1, 0.5)) # Force desync of threads - if id % 2 == 0: - with no_dereference(Group): - group = Group.objects.first() - assert isinstance(group.ref, DBRef) - else: - group = Group.objects.first() - assert isinstance(group.ref, User) - - threads = [TestableThread(target=run_in_thread, args=(id,)) for id in range(10)] - _ = [th.start() for th in threads] - _ = [th.join() for th in threads] - def test_no_dereference_context_manager_dbref(self): """Ensure that DBRef items in ListFields aren't dereferenced"""