Skip to content

Commit

Permalink
Fix no_dereference thread safetyness + force ComplexField's field arg…
Browse files Browse the repository at this point in the history
… to be a BaseField instance + make max_length a keyword only arg on ListField
  • Loading branch information
bagerard committed Aug 21, 2024
1 parent 6f7f7b7 commit 562293d
Show file tree
Hide file tree
Showing 13 changed files with 318 additions and 127 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Development
===========
- (Fill this out as you fix issues and develop your features).
- Switch tox to use pytest instead of legacy `python setup.py test` #2804
- improve ReferenceField wrong usage detection
- Fix no_dereference thread-safetyness #2830
- BREAKING CHANGE: max_length in ListField is now keyword only on ListField signature
- BREAKING CHANGE: Force `field` argument of ListField/DictField to be a field instance (e.g ListField(StringField()) instead of ListField(StringField)

Changes in 0.28.2
=================
Expand Down
7 changes: 6 additions & 1 deletion mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,10 +806,15 @@ def _from_son(cls, son, _auto_dereference=True, created=False):

fields = cls._fields
if not _auto_dereference:
# if auto_deref is turned off, we copy the fields so
# we can mutate the auto_dereference of the fields
fields = copy.deepcopy(fields)

# Apply field-name / db-field conversion
for field_name, field in fields.items():
field._auto_dereference = _auto_dereference
field.set_auto_dereferencing(
_auto_dereference
) # align the field's auto-dereferencing with the document's
if field.db_field in data:
value = data[field.db_field]
try:
Expand Down
52 changes: 50 additions & 2 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import contextlib
import operator
import threading
import weakref

import pymongo
Expand All @@ -16,6 +18,19 @@
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")


@contextlib.contextmanager
def _no_dereference_for_fields(*fields):
"""Context manager for temporarily disabling a Field's auto-dereferencing
(meant to be used from no_dereference context manager)"""
try:
for field in fields:
field._incr_no_dereference_context()
yield None
finally:
for field in fields:
field._decr_no_dereference_context()


class BaseField:
"""A base class for fields in a MongoDB document. Instances of this class
may be added to subclasses of `Document` to define a document's schema.
Expand All @@ -24,7 +39,7 @@ class BaseField:
name = None # set in TopLevelDocumentMetaclass
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
_auto_dereference = True
_thread_local_storage = threading.local()

# These track each time a Field instance is created. Used to retain order.
# The auto_creation_counter is used for fields that MongoEngine implicitly
Expand Down Expand Up @@ -85,6 +100,8 @@ def __init__(
self.sparse = sparse
self._owner_document = None

self.__auto_dereference = True

# Make sure db_field is a string (if it's explicitly defined).
if self.db_field is not None and not isinstance(self.db_field, str):
raise TypeError("db_field should be a string.")
Expand Down Expand Up @@ -120,6 +137,33 @@ def __init__(
self.creation_counter = BaseField.creation_counter
BaseField.creation_counter += 1

def set_auto_dereferencing(self, value):
self.__auto_dereference = value

@property
def _no_dereference_context_local(self):
if not hasattr(self._thread_local_storage, "no_dereference_context"):
self._thread_local_storage.no_dereference_context = 0
return self._thread_local_storage.no_dereference_context

@property
def _no_dereference_context_is_set(self):
return self._no_dereference_context_local > 0

def _incr_no_dereference_context(self):
self._thread_local_storage.no_dereference_context = (
self._no_dereference_context_local + 1
)

def _decr_no_dereference_context(self):
self._thread_local_storage.no_dereference_context = (
self._no_dereference_context_local - 1
)

@property
def _auto_dereference(self):
return self.__auto_dereference and not self._no_dereference_context_is_set

def __get__(self, instance, owner):
"""Descriptor for retrieving a value from a field in a document."""
if instance is None:
Expand Down Expand Up @@ -268,6 +312,10 @@ class ComplexBaseField(BaseField):
"""

def __init__(self, field=None, **kwargs):
if field is not None and not isinstance(field, BaseField):
raise TypeError(
f"field argument must be a Field instance (e.g {self.__class__.__name__}(StringField()))"
)
self.field = field
super().__init__(**kwargs)

Expand Down Expand Up @@ -375,7 +423,7 @@ def to_python(self, value):
return value

if self.field:
self.field._auto_dereference = self._auto_dereference
self.field.set_auto_dereferencing(self._auto_dereference)
value_dict = {
key: self.field.to_python(item) for key, item in value.items()
}
Expand Down
46 changes: 20 additions & 26 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import contextlib
import threading
from contextlib import contextmanager

from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern

from mongoengine.base.fields import _no_dereference_for_fields
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.pymongo_support import count_documents
Expand All @@ -22,6 +24,7 @@

class MyThreadLocals(threading.local):
def __init__(self):
# {DocCls: count} keeping track of classes with an active no_dereference context
self.no_dereferencing_class = {}


Expand Down Expand Up @@ -126,46 +129,37 @@ def __exit__(self, t, value, traceback):
self.cls._get_collection_name = self.ori_get_collection_name


class no_dereference:
@contextlib.contextmanager
def no_dereference(cls):
"""no_dereference context manager.
Turns off all dereferencing in Documents for the duration of the context
manager::
with no_dereference(Group):
Group.objects.find()
Group.objects()
"""

def __init__(self, cls):
"""Construct the no_dereference context manager.
:param cls: the class to turn dereferencing off on
"""
self.cls = cls
try:
cls = cls

ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")
ComplexBaseField = _import_class("ComplexBaseField")

self.deref_fields = [
k
for k, v in self.cls._fields.items()
if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
deref_fields = [
field
for name, field in cls._fields.items()
if isinstance(
field, (ReferenceField, GenericReferenceField, ComplexBaseField)
)
]

def __enter__(self):
"""Change the objects default and _auto_dereference values."""
_register_no_dereferencing_for_class(self.cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = False

def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values."""
_unregister_no_dereferencing_for_class(self.cls)
_register_no_dereferencing_for_class(cls)

for field in self.deref_fields:
self.cls._fields[field]._auto_dereference = True
with _no_dereference_for_fields(*deref_fields):
yield None
finally:
_unregister_no_dereferencing_for_class(cls)


class no_sub_classes:
Expand All @@ -180,7 +174,7 @@ class no_sub_classes:
def __init__(self, cls):
"""Construct the no_sub_classes context manager.
:param cls: the class to turn querying sub classes on
:param cls: the class to turn querying subclasses on
"""
self.cls = cls
self.cls_initial_subclasses = None
Expand Down
16 changes: 8 additions & 8 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import socket
import time
import uuid
from inspect import isclass
from io import BytesIO
from operator import itemgetter

Expand Down Expand Up @@ -707,7 +708,6 @@ class EmbeddedDocumentField(BaseField):
"""

def __init__(self, document_type, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not (
isinstance(document_type, str)
or issubclass(document_type, EmbeddedDocument)
Expand Down Expand Up @@ -910,9 +910,9 @@ class ListField(ComplexBaseField):
Required means it cannot be empty - as the default for ListFields is []
"""

def __init__(self, field=None, max_length=None, **kwargs):
def __init__(self, field=None, *, max_length=None, **kwargs):
self.max_length = max_length
kwargs.setdefault("default", lambda: [])
kwargs.setdefault("default", list)
super().__init__(field=field, **kwargs)

def __get__(self, instance, owner):
Expand Down Expand Up @@ -1035,10 +1035,9 @@ class DictField(ComplexBaseField):
"""

def __init__(self, field=None, *args, **kwargs):
self._auto_dereference = False

kwargs.setdefault("default", lambda: {})
kwargs.setdefault("default", dict)
super().__init__(*args, field=field, **kwargs)
self.set_auto_dereferencing(False)

def validate(self, value):
"""Make sure that a list of valid fields is being used."""
Expand Down Expand Up @@ -1151,8 +1150,9 @@ def __init__(
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(document_type, str) and not issubclass(
document_type, Document
if not (
isinstance(document_type, str)
or (isclass(document_type) and issubclass(document_type, Document))
):
self.error(
"Argument to ReferenceField constructor must be a "
Expand Down
54 changes: 0 additions & 54 deletions python-mongoengine.spec

This file was deleted.

10 changes: 10 additions & 0 deletions tests/fields/test_complex_base_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from mongoengine.base import ComplexBaseField
from tests.utils import MongoDBTestCase


class TestComplexBaseField(MongoDBTestCase):
def test_field_validation(self):
with pytest.raises(TypeError, match="field argument must be a Field instance"):
ComplexBaseField("test")
Loading

0 comments on commit 562293d

Please sign in to comment.