|
| 1 | +import random |
| 2 | +import time |
1 | 3 | import unittest
|
| 4 | +from threading import Thread |
2 | 5 |
|
3 | 6 | import pytest
|
4 | 7 | from bson import DBRef
|
|
18 | 21 | from tests.utils import MongoDBTestCase
|
19 | 22 |
|
20 | 23 |
|
| 24 | +class TestableThread(Thread): |
| 25 | + """ |
| 26 | + Wrapper around `threading.Thread` that propagates exceptions. |
| 27 | +
|
| 28 | + REF: https://gist.github.com/sbrugman/59b3535ebcd5aa0e2598293cfa58b6ab |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__(self, *args, **kwargs): |
| 32 | + super().__init__(*args, **kwargs) |
| 33 | + self.exc = None |
| 34 | + |
| 35 | + def run(self): |
| 36 | + try: |
| 37 | + super().run() |
| 38 | + except BaseException as e: |
| 39 | + self.exc = e |
| 40 | + |
| 41 | + def join(self, timeout=None): |
| 42 | + super().join(timeout) |
| 43 | + if self.exc: |
| 44 | + raise self.exc |
| 45 | + |
| 46 | + |
21 | 47 | class TestContextManagers(MongoDBTestCase):
|
22 | 48 | def test_set_write_concern(self):
|
23 | 49 | class User(Document):
|
@@ -172,13 +198,27 @@ class Group(Document):
|
172 | 198 | group = Group.objects.first()
|
173 | 199 | assert isinstance(group.ref, DBRef)
|
174 | 200 |
|
175 |
| - # make sure its still off here |
| 201 | + # make sure it's still off here |
176 | 202 | group = Group.objects.first()
|
177 | 203 | assert isinstance(group.ref, DBRef)
|
178 | 204 |
|
179 | 205 | group = Group.objects.first()
|
180 | 206 | assert isinstance(group.ref, User)
|
181 | 207 |
|
| 208 | + def run_in_thread(id): |
| 209 | + time.sleep(random.uniform(0.1, 0.5)) # Force desync of threads |
| 210 | + if id % 2 == 0: |
| 211 | + with no_dereference(Group): |
| 212 | + group = Group.objects.first() |
| 213 | + assert isinstance(group.ref, DBRef) |
| 214 | + else: |
| 215 | + group = Group.objects.first() |
| 216 | + assert isinstance(group.ref, User) |
| 217 | + |
| 218 | + threads = [TestableThread(target=run_in_thread, args=(id,)) for id in range(10)] |
| 219 | + _ = [th.start() for th in threads] |
| 220 | + _ = [th.join() for th in threads] |
| 221 | + |
182 | 222 | def test_no_dereference_context_manager_dbref(self):
|
183 | 223 | """Ensure that DBRef items in ListFields aren't dereferenced"""
|
184 | 224 |
|
|
0 commit comments