forked from python-trio/pytest-trio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplugin.py
143 lines (111 loc) · 4.71 KB
/
plugin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""pytest-trio implementation."""
import contextlib
import inspect
import socket
from functools import partial
import pytest
import trio
def pytest_configure(config):
    """Register the ``trio`` marker and its help text with pytest.

    Args:
        config: the pytest configuration object; only its
            ``addinivalue_line`` method is used.
    """
    # The help text must name trio: tests carrying this marker are run by
    # trio, not by an asyncio event loop.
    config.addinivalue_line(
        "markers",
        "trio: "
        "mark the test as a coroutine, it will be "
        "run using trio",
    )
def _trio_test_runner_factory(item, clock):
testfunc = item.function
async def _bootstrap_fixture_and_run_test(**kwargs):
kwargs = await _resolve_coroutine_fixtures_in(kwargs)
await testfunc(**kwargs)
def run_test_in_trio(**kwargs):
trio._core.run(partial(_bootstrap_fixture_and_run_test, **kwargs), clock=clock)
return run_test_in_trio
async def _resolve_coroutine_fixtures_in(deps):
    """Return a copy of *deps* with every ``CoroutineFixture`` value resolved.

    Values that are not ``CoroutineFixture`` instances are copied through
    untouched. The coroutine fixtures are resolved as sibling tasks of a
    single nursery, and the function only returns once all of them are done.
    The mapping passed in is never modified.
    """
    resolved = dict(deps)

    async def _store_result(resolver, target, key):
        target[key] = await resolver()

    # Pick out the entries that still need resolving before spawning tasks.
    pending = [
        (name, value)
        for name, value in resolved.items()
        if isinstance(value, CoroutineFixture)
    ]
    async with trio.open_nursery() as nursery:
        for name, fixture in pending:
            nursery.start_soon(_store_result, fixture.resolve, resolved, name)
    return resolved
class CoroutineFixture:
    """
    Represent a fixture that needs to be run in a trio context to be resolved.

    Can be an async function fixture or a synchronous fixture with async
    dependency fixtures.
    """

    # Sentinel meaning "not resolved yet"; it lets a fixture result of None
    # be cached like any other value.
    NOTSET = object()

    def __init__(self, fixturefunc, fixturedef, deps=None):
        """Store what is needed to resolve the fixture later.

        Args:
            fixturefunc: the fixture function, either sync or async.
            fixturedef: the pytest fixture definition this object stands in
                for. It is only stored; this class never reads it.
            deps: mapping of argument name to value (values may themselves
                be ``CoroutineFixture`` instances). Defaults to a new empty
                dict.
        """
        self.fixturefunc = fixturefunc
        self.fixturedef = fixturedef
        # Create the default per instance: a ``{}`` default in the signature
        # would be one dict object shared by every CoroutineFixture.
        self.deps = {} if deps is None else deps
        self._ret = self.NOTSET

    async def resolve(self):
        """Return the fixture value, computing it on the first call only.

        Dependencies are resolved first, then the fixture function is called
        with them (and awaited when it is a coroutine function). The result
        is cached and returned as-is by later calls.
        """
        # NOTE(review): two tasks resolving the same instance concurrently
        # could both see NOTSET and both run the fixture function - confirm
        # whether shared dependencies need guarding.
        if self._ret is self.NOTSET:
            resolved_deps = await _resolve_coroutine_fixtures_in(self.deps)
            if inspect.iscoroutinefunction(self.fixturefunc):
                self._ret = await self.fixturefunc(**resolved_deps)
            else:
                self._ret = self.fixturefunc(**resolved_deps)
        return self._ret
def _install_coroutine_fixture_if_needed(fixturedef, request):
deps = {dep: request.getfixturevalue(dep) for dep in fixturedef.argnames}
corofix = None
if not deps and inspect.iscoroutinefunction(fixturedef.func):
# Top level async coroutine
corofix = CoroutineFixture(fixturedef.func, fixturedef)
elif any(dep for dep in deps.values() if isinstance(dep, CoroutineFixture)):
# Fixture with coroutine fixture dependencies
corofix = CoroutineFixture(fixturedef.func, fixturedef, deps)
# The coroutine fixture must be evaluated from within the trio context
# which is spawed in the function test's trio decorator.
# The trick is to make pytest's fixture call return the CoroutineFixture
# object which will be actully resolved just before we run the test.
if corofix:
fixturedef.func = lambda **kwargs: corofix
@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    """Give fixtures requested by ``trio`` tests a chance to be wrapped.

    Requests whose keywords do not contain ``trio`` are ignored. The hook
    returns nothing in either case.
    """
    if 'trio' not in request.keywords:
        return
    _install_coroutine_fixture_if_needed(fixturedef, request)
@pytest.hookimpl(tryfirst=True)
def pytest_collection_modifyitems(session, config, items):
    """Swap every ``trio``-marked test for a synchronous trio runner.

    For each item whose keywords contain ``trio``:

    * fail if the test function is not a coroutine function;
    * look through ``item.funcargs`` for ``trio.abc.Clock`` instances, using
      the single one found (or ``None`` when there is none) and failing when
      there are several;
    * replace ``item.obj`` with the runner built for that item and clock.

    Items without the ``trio`` keyword are left alone.
    """
    # Lazy filter, so items are still examined one at a time, in order.
    marked = (candidate for candidate in items if 'trio' in candidate.keywords)
    for item in marked:
        if not inspect.iscoroutinefunction(item.function):
            pytest.fail('test function `%r` is marked trio but is not async' % item)

        # Extract the clock fixture if provided
        found = [
            value for value in item.funcargs.values()
            if isinstance(value, trio.abc.Clock)
        ]
        if len(found) > 1:
            raise pytest.fail("too many clocks spoil the broth!")
        clock = found[0] if found else None

        item.obj = _trio_test_runner_factory(item, clock)
@pytest.hookimpl(tryfirst=True)
def pytest_exception_interact(node, call, report):
    """Replace the report text of ``trio.MultiError`` failures.

    When the raised exception type is a subclass of ``trio.MultiError``,
    ``report.longrepr`` is set to the text produced by
    ``trio.format_exception``; any other exception leaves the report as is.
    """
    excinfo = call.excinfo
    if not issubclass(excinfo.type, trio.MultiError):
        return
    # TODO: not really elegant (pytest cannot output color with this hack)
    formatted = trio.format_exception(*excinfo._excinfo)
    report.longrepr = ''.join(formatted)
@pytest.fixture
def unused_tcp_port():
    """Find an unused localhost TCP port from 1024-65535 and return it."""
    probe = socket.socket()
    with contextlib.closing(probe):
        # Binding to port 0 makes the OS choose a currently free port.
        probe.bind(('127.0.0.1', 0))
        _host, port = probe.getsockname()
    return port
def _find_unused_tcp_port():
    """Bind a throwaway socket to port 0 on localhost; return the port bound."""
    with contextlib.closing(socket.socket()) as sock:
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]


@pytest.fixture
def unused_tcp_port_factory():
    """A factory function, producing different unused TCP ports.

    Returns:
        A zero-argument callable. Each call returns a port number that this
        factory instance has not returned before.
    """
    produced = set()

    def factory():
        """Return an unused port."""
        # Use a plain helper here: ``unused_tcp_port`` is a pytest fixture
        # and pytest does not allow fixture functions to be called directly.
        port = _find_unused_tcp_port()
        while port in produced:
            port = _find_unused_tcp_port()
        produced.add(port)
        return port

    return factory