Skip to content

Commit cfc942f

Browse files
authored
Merge pull request #2830 from bagerard/fix_no_dereference_thread_safetyness
fix_no_dereference_thread_safetyness
2 parents 889096c + e844138 commit cfc942f

13 files changed

+228
-127
lines changed

docs/changelog.rst

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ Development
1010
- Allow gt/gte/lt/lte/ne operators to be used with a list as value on ListField #2813
1111
- Switch tox to use pytest instead of legacy `python setup.py test` #2804
1212
- Add support for timeseries collection #2661
13+
- improve ReferenceField wrong usage detection
14+
- Fix no_dereference thread-safetyness #2830
15+
- BREAKING CHANGE: max_length in ListField is now keyword only on ListField signature
16+
- BREAKING CHANGE: Force `field` argument of ListField/DictField to be a field instance (e.g ListField(StringField()) instead of ListField(StringField)
1317

1418
Changes in 0.28.2
1519
=================

mongoengine/base/document.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -806,10 +806,15 @@ def _from_son(cls, son, _auto_dereference=True, created=False):
806806

807807
fields = cls._fields
808808
if not _auto_dereference:
809+
# if auto_deref is turned off, we copy the fields so
810+
# we can mutate the auto_dereference of the fields
809811
fields = copy.deepcopy(fields)
810812

813+
# Apply field-name / db-field conversion
811814
for field_name, field in fields.items():
812-
field._auto_dereference = _auto_dereference
815+
field.set_auto_dereferencing(
816+
_auto_dereference
817+
) # align the field's auto-dereferencing with the document's
813818
if field.db_field in data:
814819
value = data[field.db_field]
815820
try:

mongoengine/base/fields.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import contextlib
12
import operator
3+
import threading
24
import weakref
35

46
import pymongo
@@ -16,6 +18,19 @@
1618
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
1719

1820

21+
@contextlib.contextmanager
22+
def _no_dereference_for_fields(*fields):
23+
"""Context manager for temporarily disabling a Field's auto-dereferencing
24+
(meant to be used from no_dereference context manager)"""
25+
try:
26+
for field in fields:
27+
field._incr_no_dereference_context()
28+
yield None
29+
finally:
30+
for field in fields:
31+
field._decr_no_dereference_context()
32+
33+
1934
class BaseField:
2035
"""A base class for fields in a MongoDB document. Instances of this class
2136
may be added to subclasses of `Document` to define a document's schema.
@@ -24,7 +39,7 @@ class BaseField:
2439
name = None # set in TopLevelDocumentMetaclass
2540
_geo_index = False
2641
_auto_gen = False # Call `generate` to generate a value
27-
_auto_dereference = True
42+
_thread_local_storage = threading.local()
2843

2944
# These track each time a Field instance is created. Used to retain order.
3045
# The auto_creation_counter is used for fields that MongoEngine implicitly
@@ -85,6 +100,8 @@ def __init__(
85100
self.sparse = sparse
86101
self._owner_document = None
87102

103+
self.__auto_dereference = True
104+
88105
# Make sure db_field is a string (if it's explicitly defined).
89106
if self.db_field is not None and not isinstance(self.db_field, str):
90107
raise TypeError("db_field should be a string.")
@@ -120,6 +137,33 @@ def __init__(
120137
self.creation_counter = BaseField.creation_counter
121138
BaseField.creation_counter += 1
122139

140+
def set_auto_dereferencing(self, value):
141+
self.__auto_dereference = value
142+
143+
@property
144+
def _no_dereference_context_local(self):
145+
if not hasattr(self._thread_local_storage, "no_dereference_context"):
146+
self._thread_local_storage.no_dereference_context = 0
147+
return self._thread_local_storage.no_dereference_context
148+
149+
@property
150+
def _no_dereference_context_is_set(self):
151+
return self._no_dereference_context_local > 0
152+
153+
def _incr_no_dereference_context(self):
154+
self._thread_local_storage.no_dereference_context = (
155+
self._no_dereference_context_local + 1
156+
)
157+
158+
def _decr_no_dereference_context(self):
159+
self._thread_local_storage.no_dereference_context = (
160+
self._no_dereference_context_local - 1
161+
)
162+
163+
@property
164+
def _auto_dereference(self):
165+
return self.__auto_dereference and not self._no_dereference_context_is_set
166+
123167
def __get__(self, instance, owner):
124168
"""Descriptor for retrieving a value from a field in a document."""
125169
if instance is None:
@@ -268,6 +312,10 @@ class ComplexBaseField(BaseField):
268312
"""
269313

270314
def __init__(self, field=None, **kwargs):
315+
if field is not None and not isinstance(field, BaseField):
316+
raise TypeError(
317+
f"field argument must be a Field instance (e.g {self.__class__.__name__}(StringField()))"
318+
)
271319
self.field = field
272320
super().__init__(**kwargs)
273321

@@ -375,7 +423,7 @@ def to_python(self, value):
375423
return value
376424

377425
if self.field:
378-
self.field._auto_dereference = self._auto_dereference
426+
self.field.set_auto_dereferencing(self._auto_dereference)
379427
value_dict = {
380428
key: self.field.to_python(item) for key, item in value.items()
381429
}

mongoengine/context_managers.py

+20-26
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import contextlib
12
import threading
23
from contextlib import contextmanager
34

45
from pymongo.read_concern import ReadConcern
56
from pymongo.write_concern import WriteConcern
67

8+
from mongoengine.base.fields import _no_dereference_for_fields
79
from mongoengine.common import _import_class
810
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
911
from mongoengine.pymongo_support import count_documents
@@ -22,6 +24,7 @@
2224

2325
class MyThreadLocals(threading.local):
2426
def __init__(self):
27+
# {DocCls: count} keeping track of classes with an active no_dereference context
2528
self.no_dereferencing_class = {}
2629

2730

@@ -126,46 +129,37 @@ def __exit__(self, t, value, traceback):
126129
self.cls._get_collection_name = self.ori_get_collection_name
127130

128131

129-
class no_dereference:
132+
@contextlib.contextmanager
133+
def no_dereference(cls):
130134
"""no_dereference context manager.
131135
132136
Turns off all dereferencing in Documents for the duration of the context
133137
manager::
134138
135139
with no_dereference(Group):
136-
Group.objects.find()
140+
Group.objects()
137141
"""
138-
139-
def __init__(self, cls):
140-
"""Construct the no_dereference context manager.
141-
142-
:param cls: the class to turn dereferencing off on
143-
"""
144-
self.cls = cls
142+
try:
143+
cls = cls
145144

146145
ReferenceField = _import_class("ReferenceField")
147146
GenericReferenceField = _import_class("GenericReferenceField")
148147
ComplexBaseField = _import_class("ComplexBaseField")
149148

150-
self.deref_fields = [
151-
k
152-
for k, v in self.cls._fields.items()
153-
if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
149+
deref_fields = [
150+
field
151+
for name, field in cls._fields.items()
152+
if isinstance(
153+
field, (ReferenceField, GenericReferenceField, ComplexBaseField)
154+
)
154155
]
155156

156-
def __enter__(self):
157-
"""Change the objects default and _auto_dereference values."""
158-
_register_no_dereferencing_for_class(self.cls)
159-
160-
for field in self.deref_fields:
161-
self.cls._fields[field]._auto_dereference = False
162-
163-
def __exit__(self, t, value, traceback):
164-
"""Reset the default and _auto_dereference values."""
165-
_unregister_no_dereferencing_for_class(self.cls)
157+
_register_no_dereferencing_for_class(cls)
166158

167-
for field in self.deref_fields:
168-
self.cls._fields[field]._auto_dereference = True
159+
with _no_dereference_for_fields(*deref_fields):
160+
yield None
161+
finally:
162+
_unregister_no_dereferencing_for_class(cls)
169163

170164

171165
class no_sub_classes:
@@ -180,7 +174,7 @@ class no_sub_classes:
180174
def __init__(self, cls):
181175
"""Construct the no_sub_classes context manager.
182176
183-
:param cls: the class to turn querying sub classes on
177+
:param cls: the class to turn querying subclasses on
184178
"""
185179
self.cls = cls
186180
self.cls_initial_subclasses = None

mongoengine/fields.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import socket
77
import time
88
import uuid
9+
from inspect import isclass
910
from io import BytesIO
1011
from operator import itemgetter
1112

@@ -707,7 +708,6 @@ class EmbeddedDocumentField(BaseField):
707708
"""
708709

709710
def __init__(self, document_type, **kwargs):
710-
# XXX ValidationError raised outside of the "validate" method.
711711
if not (
712712
isinstance(document_type, str)
713713
or issubclass(document_type, EmbeddedDocument)
@@ -910,9 +910,9 @@ class ListField(ComplexBaseField):
910910
Required means it cannot be empty - as the default for ListFields is []
911911
"""
912912

913-
def __init__(self, field=None, max_length=None, **kwargs):
913+
def __init__(self, field=None, *, max_length=None, **kwargs):
914914
self.max_length = max_length
915-
kwargs.setdefault("default", lambda: [])
915+
kwargs.setdefault("default", list)
916916
super().__init__(field=field, **kwargs)
917917

918918
def __get__(self, instance, owner):
@@ -1035,10 +1035,9 @@ class DictField(ComplexBaseField):
10351035
"""
10361036

10371037
def __init__(self, field=None, *args, **kwargs):
1038-
self._auto_dereference = False
1039-
1040-
kwargs.setdefault("default", lambda: {})
1038+
kwargs.setdefault("default", dict)
10411039
super().__init__(*args, field=field, **kwargs)
1040+
self.set_auto_dereferencing(False)
10421041

10431042
def validate(self, value):
10441043
"""Make sure that a list of valid fields is being used."""
@@ -1151,8 +1150,9 @@ def __init__(
11511150
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
11521151
"""
11531152
# XXX ValidationError raised outside of the "validate" method.
1154-
if not isinstance(document_type, str) and not issubclass(
1155-
document_type, Document
1153+
if not (
1154+
isinstance(document_type, str)
1155+
or (isclass(document_type) and issubclass(document_type, Document))
11561156
):
11571157
self.error(
11581158
"Argument to ReferenceField constructor must be a "

python-mongoengine.spec

-54
This file was deleted.
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import pytest
2+
3+
from mongoengine.base import ComplexBaseField
4+
from tests.utils import MongoDBTestCase
5+
6+
7+
class TestComplexBaseField(MongoDBTestCase):
8+
def test_field_validation(self):
9+
with pytest.raises(TypeError, match="field argument must be a Field instance"):
10+
ComplexBaseField("test")

0 commit comments

Comments
 (0)