Skip to content

Commit

Permalink
Support strict_exception_groups=True (#188)
Browse files Browse the repository at this point in the history
Fixes #132 and #187

* Changes `open_websocket` to only raise a single exception, even when
running under `strict_exception_groups=True`
  * [ ] Should maybe introduce special handling for `KeyboardInterrupt`s
* If multiple non-Cancelled-exceptions are encountered, then it will
raise `TrioWebSocketInternalError` with the exceptiongroup as its
`__cause__`. This should only be possible if the background task and the
user context both raise exceptions. This would previously raise a
`MultiError` with both Exceptions.
* other alternatives could include throwing out the exception from the
background task, raising an ExceptionGroup with both errors, or trying
to do something fancy with `__cause__` or `__context__`.
* `WebSocketServer.run` and `WebSocketServer._handle_connection` are the
other two places that opens a nursery. I've opted not to change these,
since I don't think user code should expect any special exceptions from
it, and it seems less obscure that it might contain an internal nursery.
  * [ ] Update docstrings to mention existence of internal nursery.
  • Loading branch information
jakkdl authored Sep 21, 2024
1 parent 3294748 commit f5fd6d7
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ['3.12']
python-version: ['3.13-dev']
steps:
- uses: actions/checkout@v3
- name: Setup Python
Expand Down
5 changes: 2 additions & 3 deletions requirements-dev-full.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ build==1.2.1
# via pip-tools
certifi==2024.6.2
# via requests
cffi==1.16.0
cffi==1.17.0
# via cryptography
charset-normalizer==3.3.2
# via requests
Expand Down Expand Up @@ -194,9 +194,8 @@ tomli==2.0.1
# pytest
tomlkit==0.12.5
# via pylint
trio==0.24.0
trio==0.25.1
# via
# -r requirements-dev.in
# pytest-trio
# trio-websocket (setup.py)
trustme==1.1.0
Expand Down
1 change: 0 additions & 1 deletion requirements-dev.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ pip-tools>=5.5.0
pytest>=4.6
pytest-cov
pytest-trio>=0.5.0
trio<0.25
trustme
3 changes: 1 addition & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ tomli==2.0.1
# coverage
# pip-tools
# pytest
trio==0.24.0
trio==0.25.1
# via
# -r requirements-dev.in
# pytest-trio
# trio-websocket (setup.py)
trustme==1.1.0
Expand Down
117 changes: 116 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from __future__ import annotations

from functools import partial, wraps
import re
import ssl
import sys
from unittest.mock import patch

import attr
Expand All @@ -48,6 +50,13 @@
except ImportError:
from trio.hazmat import current_task # type: ignore # pylint: disable=ungrouped-imports


# only available on trio>=0.25, we don't use it when testing lower versions
try:
from trio.testing import RaisesGroup
except ImportError:
pass

from trio_websocket import (
connect_websocket,
connect_websocket_url,
Expand All @@ -60,12 +69,18 @@
open_websocket,
open_websocket_url,
serve_websocket,
WebSocketConnection,
WebSocketServer,
WebSocketRequest,
wrap_client_stream,
wrap_server_stream
)

from trio_websocket._impl import _TRIO_EXC_GROUP_TYPE

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin

WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.')))

HOST = '127.0.0.1'
Expand Down Expand Up @@ -428,6 +443,92 @@ async def handler(request):
assert header_value == b'My test header'




@fail_after(5)
async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock):
"""_reader_task._handle_ping_event triggers KeyboardInterrupt.
user code also raises exception.
Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup
"""
async def ki_raising_ping_handler(*args, **kwargs) -> None:
print("raising ki")
raise KeyboardInterrupt
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler)
async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with pytest.raises(KeyboardInterrupt) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
with trio.fail_after(1) as cs:
cs.shield = True
await trio.sleep(2)

e_cause = exc_info.value.__cause__
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions)

@fail_after(5)
async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock):
"""_reader_task._handle_ping_event triggers ValueError.
user code also raises exception.
internal exception is in __cause__ exceptiongroup and user exc is delivered
"""
my_value_error = ValueError()
async def raising_ping_event(*args, **kwargs) -> None:
raise my_value_error

monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event)
async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with pytest.raises(trio.TooSlowError) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
with trio.fail_after(1) as cs:
cs.shield = True
await trio.sleep(2)

e_cause = exc_info.value.__cause__
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
assert my_value_error in e_cause.exceptions

@fail_after(5)
async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock):
"""Both user code and _reader_task raise Cancellation.
Check that open_websocket reraises the one from user code for traceback reasons.
"""


async def sleeping_ping_event(*args, **kwargs) -> None:
await trio.sleep_forever()

# We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually
# raise Cancelled upon being cancelled. For some reason it doesn't otherwise.
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event)
async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")
user_cancelled = None

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with trio.move_on_after(2):
with pytest.raises(trio.Cancelled) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
try:
await trio.sleep_forever()
except trio.Cancelled as e:
user_cancelled = e
raise
assert exc_info.value is user_cancelled

def _trio_default_non_strict_exception_groups() -> bool:
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
return int(trio.__version__[2:4]) < 25

@fail_after(1)
async def test_handshake_exception_before_accept() -> None:
''' In #107, a request handler that throws an exception before finishing the
Expand All @@ -436,14 +537,28 @@ async def test_handshake_exception_before_accept() -> None:
async def handler(request):
raise ValueError()

with pytest.raises(ValueError):
# pylint fails to resolve that BaseExceptionGroup will always be available
with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment
async with trio.open_nursery() as nursery:
server = await nursery.start(serve_websocket, handler, HOST, 0,
None)
async with open_websocket(HOST, server.port, RESOURCE,
use_ssl=False):
pass

if _trio_default_non_strict_exception_groups():
assert isinstance(exc.value, ValueError)
else:
# there's 4 levels of nurseries opened, leading to 4 nested groups:
# 1. this test
# 2. WebSocketServer.run
# 3. trio.serve_listeners
# 4. WebSocketServer._handle_connection
assert RaisesGroup(
RaisesGroup(
RaisesGroup(
RaisesGroup(ValueError)))).matches(exc.value)


@fail_after(1)
async def test_reject_handshake(nursery):
Expand Down
129 changes: 118 additions & 11 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import urllib.parse
from typing import Iterable, List, Optional, Union

import outcome
import trio
import trio.abc
from wsproto import ConnectionType, WSConnection
Expand All @@ -35,7 +36,12 @@
# pylint doesn't care about the version_info check, so need to ignore the warning
from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin

_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22)
_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22)

if _IS_TRIO_MULTI_ERROR:
_TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member
else:
_TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment

CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds
MESSAGE_QUEUE_SIZE = 1
Expand All @@ -44,6 +50,13 @@
logger = logging.getLogger('trio-websocket')


class TrioWebsocketInternalError(Exception):
"""Raised as a fallback when open_websocket is unable to unwind an exceptiongroup
into a single preferred exception. This should never happen, if it does then
underlying assumptions about the internal code are incorrect.
"""


def _ignore_cancel(exc):
return None if isinstance(exc, trio.Cancelled) else exc

Expand All @@ -70,7 +83,7 @@ def __exit__(self, ty, value, tb):
if value is None or not self._armed:
return False

if _TRIO_MULTI_ERROR: # pragma: no cover
if _IS_TRIO_MULTI_ERROR: # pragma: no cover
filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member
elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment
filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled))
Expand Down Expand Up @@ -125,10 +138,33 @@ async def open_websocket(
client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
or server rejection (:exc:`ConnectionRejected`) during handshakes.
'''
async with trio.open_nursery() as new_nursery:

# This context manager tries very very hard not to raise an exceptiongroup
# in order to be as transparent as possible for the end user.
# In the trivial case, this means that if user code inside the cm raises
# we make sure that it doesn't get wrapped.

# If opening the connection fails, then we will raise that exception. User
# code is never executed, so we will never have multiple exceptions.

# After opening the connection, we spawn _reader_task in the background and
# yield to user code. If only one of those raise a non-cancelled exception
# we will raise that non-cancelled exception.
# If we get multiple cancelled, we raise the user's cancelled.
# If both raise exceptions, we raise the user code's exception with the entire
# exception group as the __cause__.
# If we somehow get multiple exceptions, but no user exception, then we raise
# TrioWebsocketInternalError.

# If closing the connection fails, then that will be raised as the top
# exception in the last `finally`. If we encountered exceptions in user code
# or in reader task then they will be set as the `__cause__`.


async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
try:
with trio.fail_after(connect_timeout):
connection = await connect_websocket(new_nursery, host, port,
return await connect_websocket(nursery, host, port,
resource, use_ssl=use_ssl, subprotocols=subprotocols,
extra_headers=extra_headers,
message_queue_size=message_queue_size,
Expand All @@ -137,14 +173,85 @@ async def open_websocket(
raise ConnectionTimeout from None
except OSError as e:
raise HandshakeError from e

async def _close_connection(connection: WebSocketConnection) -> None:
try:
yield connection
finally:
try:
with trio.fail_after(disconnect_timeout):
await connection.aclose()
except trio.TooSlowError:
raise DisconnectionTimeout from None
with trio.fail_after(disconnect_timeout):
await connection.aclose()
except trio.TooSlowError:
raise DisconnectionTimeout from None

connection: WebSocketConnection|None=None
close_result: outcome.Maybe[None] | None = None
user_error = None

try:
async with trio.open_nursery() as new_nursery:
result = await outcome.acapture(_open_connection, new_nursery)

if isinstance(result, outcome.Value):
connection = result.unwrap()
try:
yield connection
except BaseException as e:
user_error = e
raise
finally:
close_result = await outcome.acapture(_close_connection, connection)
# This exception handler should only be entered if either:
# 1. The _reader_task started in connect_websocket raises
# 2. User code raises an exception
# I.e. open/close_connection are not included
except _TRIO_EXC_GROUP_TYPE as e:
# user_error, or exception bubbling up from _reader_task
if len(e.exceptions) == 1:
raise e.exceptions[0]

# contains at most 1 non-cancelled exceptions
exception_to_raise: BaseException|None = None
for sub_exc in e.exceptions:
if not isinstance(sub_exc, trio.Cancelled):
if exception_to_raise is not None:
# multiple non-cancelled
break
exception_to_raise = sub_exc
else:
if exception_to_raise is None:
# all exceptions are cancelled
# prefer raising the one from the user, for traceback reasons
if user_error is not None:
# no reason to raise from e, just to include a bunch of extra
# cancelleds.
raise user_error # pylint: disable=raise-missing-from
# multiple internal Cancelled is not possible afaik
raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from
raise exception_to_raise

# if we have any KeyboardInterrupt in the group, make sure to raise it.
for sub_exc in e.exceptions:
if isinstance(sub_exc, KeyboardInterrupt):
raise sub_exc from e

# Both user code and internal code raised non-cancelled exceptions.
# We "hide" the internal exception(s) in the __cause__ and surface
# the user_error.
if user_error is not None:
raise user_error from e

raise TrioWebsocketInternalError(
"The trio-websocket API is not expected to raise multiple exceptions. "
"Please report this as a bug to "
"https://github.com/python-trio/trio-websocket"
) from e # pragma: no cover

finally:
if close_result is not None:
close_result.unwrap()


# error setting up, unwrap that exception
if connection is None:
result.unwrap()


async def connect_websocket(nursery, host, port, resource, *, use_ssl,
Expand Down

0 comments on commit f5fd6d7

Please sign in to comment.