Skip to content

Commit 1f54671

Browse files
committed
Fix no_dereference thread safetyness + force ComplexField's field arg to be a BaseField instance + make max_length a keyword only arg on ListField
1 parent 6f7f7b7 commit 1f54671

13 files changed

+307
-101
lines changed

docs/changelog.rst

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ Development
88
===========
99
- (Fill this out as you fix issues and develop your features).
1010
- Switch tox to use pytest instead of legacy `python setup.py test` #2804
11+
- improve ReferenceField wrong usage detection
12+
- BREAKING CHANGE: max_length in ListField is now keyword only on ListField signature
13+
- BREAKING CHANGE: Force `field` argument of ListField/DictField to be a field instance (e.g ListField(StringField()) instead of ListField(StringField)
1114

1215
Changes in 0.28.2
1316
=================

mongoengine/base/document.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -806,10 +806,15 @@ def _from_son(cls, son, _auto_dereference=True, created=False):
806806

807807
fields = cls._fields
808808
if not _auto_dereference:
809+
# if auto_deref is turned off, we copy the fields so
810+
# we can mutate the auto_dereference of the fields
809811
fields = copy.deepcopy(fields)
810812

813+
# Apply field-name / db-field conversion
811814
for field_name, field in fields.items():
812-
field._auto_dereference = _auto_dereference
815+
field.set_auto_dereferencing(
816+
_auto_dereference
817+
) # align the field's auto-dereferencing with the document's
813818
if field.db_field in data:
814819
value = data[field.db_field]
815820
try:

mongoengine/base/fields.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
import threading
23
import weakref
34

45
import pymongo
@@ -16,6 +17,20 @@
1617
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
1718

1819

20+
class _no_dereference_for_field:
21+
"""Context manager for temporarily disabling a Field's auto-dereferencing
22+
(meant to be used from no_dereference context manager)"""
23+
24+
def __init__(self, field):
25+
self.field = field
26+
27+
def __enter__(self):
28+
self.field._incr_no_dereference_context()
29+
30+
def __exit__(self, exc_type, exc_value, traceback):
31+
self.field._decr_no_dereference_context()
32+
33+
1934
class BaseField:
2035
"""A base class for fields in a MongoDB document. Instances of this class
2136
may be added to subclasses of `Document` to define a document's schema.
@@ -24,7 +39,7 @@ class BaseField:
2439
name = None # set in TopLevelDocumentMetaclass
2540
_geo_index = False
2641
_auto_gen = False # Call `generate` to generate a value
27-
_auto_dereference = True
42+
_thread_local_storage = threading.local()
2843

2944
# These track each time a Field instance is created. Used to retain order.
3045
# The auto_creation_counter is used for fields that MongoEngine implicitly
@@ -85,6 +100,8 @@ def __init__(
85100
self.sparse = sparse
86101
self._owner_document = None
87102

103+
self.__auto_dereference = True
104+
88105
# Make sure db_field is a string (if it's explicitly defined).
89106
if self.db_field is not None and not isinstance(self.db_field, str):
90107
raise TypeError("db_field should be a string.")
@@ -120,6 +137,33 @@ def __init__(
120137
self.creation_counter = BaseField.creation_counter
121138
BaseField.creation_counter += 1
122139

140+
def set_auto_dereferencing(self, value):
141+
self.__auto_dereference = value
142+
143+
@property
144+
def _no_dereference_context_local(self):
145+
if not hasattr(self._thread_local_storage, "no_dereference_context"):
146+
self._thread_local_storage.no_dereference_context = 0
147+
return self._thread_local_storage.no_dereference_context
148+
149+
@property
150+
def _no_dereference_context_is_set(self):
151+
return self._no_dereference_context_local > 0
152+
153+
def _incr_no_dereference_context(self):
154+
self._thread_local_storage.no_dereference_context = (
155+
self._no_dereference_context_local + 1
156+
)
157+
158+
def _decr_no_dereference_context(self):
159+
self._thread_local_storage.no_dereference_context = (
160+
self._no_dereference_context_local - 1
161+
)
162+
163+
@property
164+
def _auto_dereference(self):
165+
return self.__auto_dereference and not self._no_dereference_context_is_set
166+
123167
def __get__(self, instance, owner):
124168
"""Descriptor for retrieving a value from a field in a document."""
125169
if instance is None:
@@ -268,6 +312,10 @@ class ComplexBaseField(BaseField):
268312
"""
269313

270314
def __init__(self, field=None, **kwargs):
315+
if field is not None and not isinstance(field, BaseField):
316+
raise TypeError(
317+
f"field argument must be a Field instance (e.g {self.__class__.__name__}(StringField()))"
318+
)
271319
self.field = field
272320
super().__init__(**kwargs)
273321

@@ -375,7 +423,7 @@ def to_python(self, value):
375423
return value
376424

377425
if self.field:
378-
self.field._auto_dereference = self._auto_dereference
426+
self.field.set_auto_dereferencing(self._auto_dereference)
379427
value_dict = {
380428
key: self.field.to_python(item) for key, item in value.items()
381429
}

mongoengine/context_managers.py

+63-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import contextlib
12
import threading
23
from contextlib import contextmanager
34

45
from pymongo.read_concern import ReadConcern
56
from pymongo.write_concern import WriteConcern
67

8+
from mongoengine.base.fields import _no_dereference_for_field
79
from mongoengine.common import _import_class
810
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
911
from mongoengine.pymongo_support import count_documents
@@ -22,6 +24,7 @@
2224

2325
class MyThreadLocals(threading.local):
2426
def __init__(self):
27+
# {DocCls: count} keeping track of classes with an active no_dereference context
2528
self.no_dereferencing_class = {}
2629

2730

@@ -126,14 +129,55 @@ def __exit__(self, t, value, traceback):
126129
self.cls._get_collection_name = self.ori_get_collection_name
127130

128131

129-
class no_dereference:
132+
@contextlib.contextmanager
133+
def no_dereference(cls):
130134
"""no_dereference context manager.
131135
132136
Turns off all dereferencing in Documents for the duration of the context
133137
manager::
134138
135139
with no_dereference(Group):
136-
Group.objects.find()
140+
Group.objects()
141+
"""
142+
try:
143+
cls = cls
144+
145+
ReferenceField = _import_class("ReferenceField")
146+
GenericReferenceField = _import_class("GenericReferenceField")
147+
ComplexBaseField = _import_class("ComplexBaseField")
148+
149+
deref_fields = [
150+
field
151+
for name, field in cls._fields.items()
152+
if isinstance(
153+
field, (ReferenceField, GenericReferenceField, ComplexBaseField)
154+
)
155+
]
156+
no_deref_for_fields_contexts = [
157+
_no_dereference_for_field(field) for field in deref_fields
158+
]
159+
160+
_register_no_dereferencing_for_class(cls)
161+
162+
# ExitStack is just a fancy way of nesting multiple context managers into 1
163+
with contextlib.ExitStack() as stack:
164+
for mgr in no_deref_for_fields_contexts:
165+
stack.enter_context(mgr)
166+
167+
yield None
168+
169+
finally:
170+
_unregister_no_dereferencing_for_class(cls)
171+
172+
173+
class no_dereference2:
174+
"""no_dereference context manager.
175+
176+
Turns off all dereferencing in Documents for the duration of the context
177+
manager::
178+
179+
with no_dereference(Group):
180+
Group.objects()
137181
"""
138182

139183
def __init__(self, cls):
@@ -148,24 +192,33 @@ def __init__(self, cls):
148192
ComplexBaseField = _import_class("ComplexBaseField")
149193

150194
self.deref_fields = [
151-
k
152-
for k, v in self.cls._fields.items()
153-
if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
195+
field
196+
for name, field in self.cls._fields.items()
197+
if isinstance(
198+
field, (ReferenceField, GenericReferenceField, ComplexBaseField)
199+
)
200+
]
201+
self.no_deref_for_fields_contexts = [
202+
_no_dereference_for_field(field) for field in self.deref_fields
154203
]
155204

156205
def __enter__(self):
157206
"""Change the objects default and _auto_dereference values."""
158207
_register_no_dereferencing_for_class(self.cls)
159208

160-
for field in self.deref_fields:
161-
self.cls._fields[field]._auto_dereference = False
209+
for ndff_context in self.no_deref_for_fields_contexts:
210+
ndff_context.__enter__()
211+
# for field in self.deref_fields:
212+
# self.cls._fields[field]._auto_dereference = False
162213

163214
def __exit__(self, t, value, traceback):
164215
"""Reset the default and _auto_dereference values."""
165216
_unregister_no_dereferencing_for_class(self.cls)
166217

167-
for field in self.deref_fields:
168-
self.cls._fields[field]._auto_dereference = True
218+
for ndff_context in self.no_deref_for_fields_contexts:
219+
ndff_context.__exit__(t, value, traceback)
220+
# for field in self.deref_fields:
221+
# self.cls._fields[field]._auto_dereference = True # should set initial values back
169222

170223

171224
class no_sub_classes:
@@ -180,7 +233,7 @@ class no_sub_classes:
180233
def __init__(self, cls):
181234
"""Construct the no_sub_classes context manager.
182235
183-
:param cls: the class to turn querying sub classes on
236+
:param cls: the class to turn querying subclasses on
184237
"""
185238
self.cls = cls
186239
self.cls_initial_subclasses = None

mongoengine/fields.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import socket
77
import time
88
import uuid
9+
from inspect import isclass
910
from io import BytesIO
1011
from operator import itemgetter
1112

@@ -707,7 +708,6 @@ class EmbeddedDocumentField(BaseField):
707708
"""
708709

709710
def __init__(self, document_type, **kwargs):
710-
# XXX ValidationError raised outside of the "validate" method.
711711
if not (
712712
isinstance(document_type, str)
713713
or issubclass(document_type, EmbeddedDocument)
@@ -910,9 +910,9 @@ class ListField(ComplexBaseField):
910910
Required means it cannot be empty - as the default for ListFields is []
911911
"""
912912

913-
def __init__(self, field=None, max_length=None, **kwargs):
913+
def __init__(self, field=None, *, max_length=None, **kwargs):
914914
self.max_length = max_length
915-
kwargs.setdefault("default", lambda: [])
915+
kwargs.setdefault("default", list)
916916
super().__init__(field=field, **kwargs)
917917

918918
def __get__(self, instance, owner):
@@ -1035,10 +1035,9 @@ class DictField(ComplexBaseField):
10351035
"""
10361036

10371037
def __init__(self, field=None, *args, **kwargs):
1038-
self._auto_dereference = False
1039-
1040-
kwargs.setdefault("default", lambda: {})
1038+
kwargs.setdefault("default", dict)
10411039
super().__init__(*args, field=field, **kwargs)
1040+
self.set_auto_dereferencing(False)
10421041

10431042
def validate(self, value):
10441043
"""Make sure that a list of valid fields is being used."""
@@ -1151,8 +1150,9 @@ def __init__(
11511150
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
11521151
"""
11531152
# XXX ValidationError raised outside of the "validate" method.
1154-
if not isinstance(document_type, str) and not issubclass(
1155-
document_type, Document
1153+
if not (
1154+
isinstance(document_type, str)
1155+
or (isclass(document_type) and issubclass(document_type, Document))
11561156
):
11571157
self.error(
11581158
"Argument to ReferenceField constructor must be a "

python-mongoengine.spec

-54
This file was deleted.
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import pytest
2+
3+
from mongoengine.base import ComplexBaseField
4+
from tests.utils import MongoDBTestCase
5+
6+
7+
class TestComplexBaseField(MongoDBTestCase):
8+
def test_field_validation(self):
9+
with pytest.raises(TypeError, match="field argument must be a Field instance"):
10+
ComplexBaseField("test")

0 commit comments

Comments
 (0)