Skip to content

Commit 477b237

Browse files
authored
Fix websocket connection leak (#7978) (#7980)
1 parent f1cee99 commit 477b237

File tree

4 files changed

+93
-41
lines changed

4 files changed

+93
-41
lines changed

CHANGES/7978.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix websocket connection leak

aiohttp/web_ws.py

+54-40
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,8 @@ def _send_heartbeat(self) -> None:
140140
def _pong_not_received(self) -> None:
141141
if self._req is not None and self._req.transport is not None:
142142
self._closed = True
143-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
143+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
144144
self._exception = asyncio.TimeoutError()
145-
self._req.transport.close()
146145

147146
async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
148147
# make pre-check to don't hide it by do_handshake() exceptions
@@ -360,7 +359,10 @@ async def write_eof(self) -> None: # type: ignore[override]
360359
await self.close()
361360
self._eof_sent = True
362361

363-
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
362+
async def close(
363+
self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
364+
) -> bool:
365+
"""Close websocket connection."""
364366
if self._writer is None:
365367
raise RuntimeError("Call .prepare() first")
366368

@@ -374,46 +376,53 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
374376
reader.feed_data(WS_CLOSING_MESSAGE, 0)
375377
await self._waiting
376378

377-
if not self._closed:
378-
self._closed = True
379-
try:
380-
await self._writer.close(code, message)
381-
writer = self._payload_writer
382-
assert writer is not None
383-
await writer.drain()
384-
except (asyncio.CancelledError, asyncio.TimeoutError):
385-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
386-
raise
387-
except Exception as exc:
388-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
389-
self._exception = exc
390-
return True
379+
if self._closed:
380+
return False
391381

392-
if self._closing:
393-
return True
382+
self._closed = True
383+
try:
384+
await self._writer.close(code, message)
385+
writer = self._payload_writer
386+
assert writer is not None
387+
if drain:
388+
await writer.drain()
389+
except (asyncio.CancelledError, asyncio.TimeoutError):
390+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
391+
raise
392+
except Exception as exc:
393+
self._exception = exc
394+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
395+
return True
394396

395-
reader = self._reader
396-
assert reader is not None
397-
try:
398-
async with async_timeout.timeout(self._timeout):
399-
msg = await reader.read()
400-
except asyncio.CancelledError:
401-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
402-
raise
403-
except Exception as exc:
404-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
405-
self._exception = exc
406-
return True
397+
if self._closing:
398+
return True
407399

408-
if msg.type == WSMsgType.CLOSE:
409-
self._close_code = msg.data
410-
return True
400+
reader = self._reader
401+
assert reader is not None
402+
try:
403+
async with async_timeout.timeout(self._timeout):
404+
msg = await reader.read()
405+
except asyncio.CancelledError:
406+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
407+
raise
408+
except Exception as exc:
409+
self._exception = exc
410+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
411+
return True
411412

412-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
413-
self._exception = asyncio.TimeoutError()
413+
if msg.type == WSMsgType.CLOSE:
414+
self._set_code_close_transport(msg.data)
414415
return True
415-
else:
416-
return False
416+
417+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
418+
self._exception = asyncio.TimeoutError()
419+
return True
420+
421+
def _set_code_close_transport(self, code: WSCloseCode) -> None:
422+
"""Set the close code and close the transport."""
423+
self._close_code = code
424+
if self._req is not None and self._req.transport is not None:
425+
self._req.transport.close()
417426

418427
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
419428
if self._reader is None:
@@ -444,7 +453,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
444453
set_result(waiter, True)
445454
self._waiting = None
446455
except (asyncio.CancelledError, asyncio.TimeoutError):
447-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
456+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
448457
raise
449458
except EofStream:
450459
self._close_code = WSCloseCode.OK
@@ -464,8 +473,13 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
464473
if msg.type == WSMsgType.CLOSE:
465474
self._closing = True
466475
self._close_code = msg.data
476+
# Could be closed while awaiting reader.
467477
if not self._closed and self._autoclose:
468-
await self.close()
478+
# The client is likely going to close the
479+
# connection out from under us so we do not
480+
# want to drain any pending writes as it will
481+
# likely result writing to a broken pipe.
482+
await self.close(drain=False)
469483
elif msg.type == WSMsgType.CLOSING:
470484
self._closing = True
471485
elif msg.type == WSMsgType.PING and self._autoping:

docs/web_reference.rst

+11-1
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,14 @@ and :ref:`aiohttp-web-signals` handlers::
970970

971971
.. versionadded:: 3.3
972972

973+
:param bool autoclose: Close connection when the client sends
974+
a :const:`~aiohttp.WSMsgType.CLOSE` message,
975+
``True`` by default. If set to ``False``,
976+
the connection is not closed and the
977+
caller is responsible for calling
978+
``request.transport.close()`` to avoid
979+
leaking resources.
980+
973981

974982
The class supports ``async for`` statement for iterating over
975983
incoming messages::
@@ -1146,7 +1154,7 @@ and :ref:`aiohttp-web-signals` handlers::
11461154
The method is converted into :term:`coroutine`,
11471155
*compress* parameter added.
11481156

1149-
.. method:: close(*, code=WSCloseCode.OK, message=b'')
1157+
.. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True)
11501158
:async:
11511159

11521160
A :ref:`coroutine<coroutine>` that initiates closing
@@ -1160,6 +1168,8 @@ and :ref:`aiohttp-web-signals` handlers::
11601168
:class:`str` (converted to *UTF-8* encoded bytes)
11611169
or :class:`bytes`.
11621170

1171+
:param bool drain: drain outgoing buffer before closing connection.
1172+
11631173
:raise RuntimeError: if connection is not started
11641174

11651175
.. method:: receive(timeout=None)

tests/test_web_websocket.py

+27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import asyncio
2+
import time
3+
from typing import Any
24
from unittest import mock
35

46
import aiosignal
@@ -165,6 +167,20 @@ async def test_write_non_prepared() -> None:
165167
await ws.write(b"data")
166168

167169

170+
async def test_heartbeat_timeout(make_request: Any) -> None:
171+
"""Verify the transport is closed when the heartbeat timeout is reached."""
172+
loop = asyncio.get_running_loop()
173+
future = loop.create_future()
174+
req = make_request("GET", "/")
175+
lowest_time = time.get_clock_info("monotonic").resolution
176+
req._protocol._timeout_ceil_threshold = lowest_time
177+
ws = WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time)
178+
await ws.prepare(req)
179+
ws._req.transport.close.side_effect = lambda: future.set_result(None)
180+
await future
181+
assert ws.closed
182+
183+
168184
def test_websocket_ready() -> None:
169185
websocket_ready = WebSocketReady(True, "chat")
170186
assert websocket_ready.ok is True
@@ -233,6 +249,7 @@ async def test_send_str_closed(make_request) -> None:
233249
await ws.prepare(req)
234250
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
235251
await ws.close()
252+
assert len(ws._req.transport.close.mock_calls) == 1
236253

237254
with pytest.raises(ConnectionError):
238255
await ws.send_str("string")
@@ -289,6 +306,8 @@ async def test_close_idempotent(make_request) -> None:
289306
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
290307
assert await ws.close(code=1, message="message1")
291308
assert ws.closed
309+
assert len(ws._req.transport.close.mock_calls) == 1
310+
292311
assert not (await ws.close(code=2, message="message2"))
293312

294313

@@ -322,12 +341,15 @@ async def test_write_eof_idempotent(make_request) -> None:
322341
req = make_request("GET", "/")
323342
ws = WebSocketResponse()
324343
await ws.prepare(req)
344+
assert len(ws._req.transport.close.mock_calls) == 0
345+
325346
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
326347
await ws.close()
327348

328349
await ws.write_eof()
329350
await ws.write_eof()
330351
await ws.write_eof()
352+
assert len(ws._req.transport.close.mock_calls) == 1
331353

332354

333355
async def test_receive_eofstream_in_reader(make_request, loop) -> None:
@@ -353,6 +375,7 @@ async def test_receive_timeouterror(make_request, loop) -> None:
353375
req = make_request("GET", "/")
354376
ws = WebSocketResponse()
355377
await ws.prepare(req)
378+
assert len(ws._req.transport.close.mock_calls) == 0
356379

357380
ws._reader = mock.Mock()
358381
res = loop.create_future()
@@ -362,6 +385,8 @@ async def test_receive_timeouterror(make_request, loop) -> None:
362385
with pytest.raises(asyncio.TimeoutError):
363386
await ws.receive()
364387

388+
assert len(ws._req.transport.close.mock_calls) == 1
389+
365390

366391
async def test_multiple_receive_on_close_connection(make_request) -> None:
367392
req = make_request("GET", "/")
@@ -394,13 +419,15 @@ async def test_close_exc(make_request) -> None:
394419
req = make_request("GET", "/")
395420
ws = WebSocketResponse()
396421
await ws.prepare(req)
422+
assert len(ws._req.transport.close.mock_calls) == 0
397423

398424
exc = ValueError()
399425
ws._writer = mock.Mock()
400426
ws._writer.close.side_effect = exc
401427
await ws.close()
402428
assert ws.closed
403429
assert ws.exception() is exc
430+
assert len(ws._req.transport.close.mock_calls) == 1
404431

405432
ws._closed = False
406433
ws._writer.close.side_effect = asyncio.CancelledError()

0 commit comments

Comments
 (0)