Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix_no_dereference_thread_safetyness #2830

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Development
- Allow gt/gte/lt/lte/ne operators to be used with a list as value on ListField #2813
- Switch tox to use pytest instead of legacy `python setup.py test` #2804
- Add support for timeseries collection #2661
- improve ReferenceField wrong usage detection
- Fix no_dereference thread-safetyness #2830
- BREAKING CHANGE: max_length in ListField is now keyword only on ListField signature
- BREAKING CHANGE: Force `field` argument of ListField/DictField to be a field instance (e.g ListField(StringField()) instead of ListField(StringField)

Changes in 0.28.2
=================
Expand Down
7 changes: 6 additions & 1 deletion mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,10 +806,15 @@ def _from_son(cls, son, _auto_dereference=True, created=False):

fields = cls._fields
if not _auto_dereference:
# if auto_deref is turned off, we copy the fields so
# we can mutate the auto_dereference of the fields
fields = copy.deepcopy(fields)

# Apply field-name / db-field conversion
for field_name, field in fields.items():
field._auto_dereference = _auto_dereference
field.set_auto_dereferencing(
_auto_dereference
) # align the field's auto-dereferencing with the document's
if field.db_field in data:
value = data[field.db_field]
try:
Expand Down
52 changes: 50 additions & 2 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import contextlib
import operator
import threading
import weakref

import pymongo
Expand All @@ -16,6 +18,19 @@
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")


@contextlib.contextmanager
def _no_dereference_for_fields(*fields):
"""Context manager for temporarily disabling a Field's auto-dereferencing
(meant to be used from no_dereference context manager)"""
try:
for field in fields:
field._incr_no_dereference_context()
yield None
finally:
for field in fields:
field._decr_no_dereference_context()


class BaseField:
"""A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema.
Expand All @@ -24,7 +39,7 @@ class BaseField:
name = None # set in TopLevelDocumentMetaclass
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
_auto_dereference = True
_thread_local_storage = threading.local()

# These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly
Expand Down Expand Up @@ -85,6 +100,8 @@ def __init__(
self.sparse = sparse
self._owner_document = None

self.__auto_dereference = True

# Make sure db_field is a string (if it's explicitly defined).
if self.db_field is not None and not isinstance(self.db_field, str):
raise TypeError("db_field should be a string.")
Expand Down Expand Up @@ -120,6 +137,33 @@ def __init__(
self.creation_counter = BaseField.creation_counter
BaseField.creation_counter += 1

def set_auto_dereferencing(self, value):
self.__auto_dereference = value

@property
def _no_dereference_context_local(self):
if not hasattr(self._thread_local_storage, "no_dereference_context"):
self._thread_local_storage.no_dereference_context = 0
return self._thread_local_storage.no_dereference_context

@property
def _no_dereference_context_is_set(self):
return self._no_dereference_context_local > 0

def _incr_no_dereference_context(self):
self._thread_local_storage.no_dereference_context = (
self._no_dereference_context_local + 1
)

def _decr_no_dereference_context(self):
self._thread_local_storage.no_dereference_context = (
self._no_dereference_context_local - 1
)

@property
def _auto_dereference(self):
return self.__auto_dereference and not self._no_dereference_context_is_set

def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document."""
if instance is None:
Expand Down Expand Up @@ -268,6 +312,10 @@ class ComplexBaseField(BaseField):
"""

def __init__(self, field=None, **kwargs):
if field is not None and not isinstance(field, BaseField):
raise TypeError(
f"field argument must be a Field instance (e.g {self.__class__.__name__}(StringField()))"
)
self.field = field
super().__init__(**kwargs)

Expand Down Expand Up @@ -375,7 +423,7 @@ def to_python(self, value):
return value

if self.field:
self.field._auto_dereference = self._auto_dereference
self.field.set_auto_dereferencing(self._auto_dereference)
value_dict = {
key: self.field.to_python(item) for key, item in value.items()
}
Expand Down
46 changes: 20 additions & 26 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import contextlib
import threading
from contextlib import contextmanager

from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern

from mongoengine.base.fields import _no_dereference_for_fields
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.pymongo_support import count_documents
Expand All @@ -22,6 +24,7 @@

class MyThreadLocals(threading.local):
def __init__(self):
# {DocCls: count} keeping track of classes with an active no_dereference context
self.no_dereferencing_class = {}


Expand Down Expand Up @@ -126,46 +129,37 @@ def __exit__(self, t, value, traceback):
self.cls._get_collection_name = self.ori_get_collection_name


class no_dereference:
@contextlib.contextmanager
def no_dereference(cls):
"""no_dereference context manager.

Turns off all dereferencing in Documents for the duration of the context
manager::

with no_dereference(Group):
Group.objects.find()
Group.objects()
"""

def __init__(self, cls):
"""Construct the no_dereference context manager.

:param cls: the class to turn dereferencing off on
"""
self.cls = cls
try:
cls = cls

ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")
ComplexBaseField = _import_class("ComplexBaseField")

self.deref_fields = [
k
for k, v in self.cls._fields.items()
if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
deref_fields = [
field
for name, field in cls._fields.items()
if isinstance(
field, (ReferenceField, GenericReferenceField, ComplexBaseField)
)
]

def __enter__(self):
"""Change the objects default and _auto_dereference values."""
_register_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False

def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values."""
_unregister_no_dereferencing_for_class(self.cls)
_register_no_dereferencing_for_class(cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
with _no_dereference_for_fields(*deref_fields):
yield None
finally:
_unregister_no_dereferencing_for_class(cls)


class no_sub_classes:
Expand All @@ -180,7 +174,7 @@ class no_sub_classes:
def __init__(self, cls):
"""Construct the no_sub_classes context manager.

:param cls: the class to turn querying sub classes on
:param cls: the class to turn querying subclasses on
"""
self.cls = cls
self.cls_initial_subclasses = None
Expand Down
16 changes: 8 additions & 8 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import socket
import time
import uuid
from inspect import isclass
from io import BytesIO
from operator import itemgetter

Expand Down Expand Up @@ -707,7 +708,6 @@ class EmbeddedDocumentField(BaseField):
"""

def __init__(self, document_type, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not (
isinstance(document_type, str)
or issubclass(document_type, EmbeddedDocument)
Expand Down Expand Up @@ -910,9 +910,9 @@ class ListField(ComplexBaseField):
Required means it cannot be empty - as the default for ListFields is []
"""

def __init__(self, field=None, max_length=None, **kwargs):
def __init__(self, field=None, *, max_length=None, **kwargs):
self.max_length = max_length
kwargs.setdefault("default", lambda: [])
kwargs.setdefault("default", list)
super().__init__(field=field, **kwargs)

def __get__(self, instance, owner):
Expand Down Expand Up @@ -1035,10 +1035,9 @@ class DictField(ComplexBaseField):
"""

def __init__(self, field=None, *args, **kwargs):
self._auto_dereference = False

kwargs.setdefault("default", lambda: {})
kwargs.setdefault("default", dict)
super().__init__(*args, field=field, **kwargs)
self.set_auto_dereferencing(False)

def validate(self, value):
"""Make sure that a list of valid fields is being used."""
Expand Down Expand Up @@ -1151,8 +1150,9 @@ def __init__(
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(document_type, str) and not issubclass(
document_type, Document
if not (
isinstance(document_type, str)
or (isclass(document_type) and issubclass(document_type, Document))
):
self.error(
"Argument to ReferenceField constructor must be a "
Expand Down
54 changes: 0 additions & 54 deletions python-mongoengine.spec

This file was deleted.

10 changes: 10 additions & 0 deletions tests/fields/test_complex_base_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from mongoengine.base import ComplexBaseField
from tests.utils import MongoDBTestCase


class TestComplexBaseField(MongoDBTestCase):
def test_field_validation(self):
with pytest.raises(TypeError, match="field argument must be a Field instance"):
ComplexBaseField("test")
Loading
Loading