Skip to content

Commit 6f1c608

Browse files
authored
Fix websocket connection leak (#7978)
1 parent 5e44ba4 commit 6f1c608

File tree

4 files changed

+91
-41
lines changed

4 files changed

+91
-41
lines changed

CHANGES/7978.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix websocket connection leak

aiohttp/web_ws.py

+53-40
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,8 @@ def _send_heartbeat(self) -> None:
162162
def _pong_not_received(self) -> None:
163163
if self._req is not None and self._req.transport is not None:
164164
self._closed = True
165-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
165+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
166166
self._exception = asyncio.TimeoutError()
167-
self._req.transport.close()
168167

169168
async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
170169
# make pre-check to don't hide it by do_handshake() exceptions
@@ -382,7 +381,10 @@ async def write_eof(self) -> None: # type: ignore[override]
382381
await self.close()
383382
self._eof_sent = True
384383

385-
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
384+
async def close(
385+
self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
386+
) -> bool:
387+
"""Close websocket connection."""
386388
if self._writer is None:
387389
raise RuntimeError("Call .prepare() first")
388390

@@ -396,46 +398,53 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
396398
reader.feed_data(WS_CLOSING_MESSAGE, 0)
397399
await self._waiting
398400

399-
if not self._closed:
400-
self._closed = True
401-
try:
402-
await self._writer.close(code, message)
403-
writer = self._payload_writer
404-
assert writer is not None
405-
await writer.drain()
406-
except (asyncio.CancelledError, asyncio.TimeoutError):
407-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
408-
raise
409-
except Exception as exc:
410-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
411-
self._exception = exc
412-
return True
401+
if self._closed:
402+
return False
413403

414-
if self._closing:
415-
return True
404+
self._closed = True
405+
try:
406+
await self._writer.close(code, message)
407+
writer = self._payload_writer
408+
assert writer is not None
409+
if drain:
410+
await writer.drain()
411+
except (asyncio.CancelledError, asyncio.TimeoutError):
412+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
413+
raise
414+
except Exception as exc:
415+
self._exception = exc
416+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
417+
return True
416418

417-
reader = self._reader
418-
assert reader is not None
419-
try:
420-
async with async_timeout.timeout(self._timeout):
421-
msg = await reader.read()
422-
except asyncio.CancelledError:
423-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
424-
raise
425-
except Exception as exc:
426-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
427-
self._exception = exc
428-
return True
419+
if self._closing:
420+
return True
429421

430-
if msg.type == WSMsgType.CLOSE:
431-
self._close_code = msg.data
432-
return True
422+
reader = self._reader
423+
assert reader is not None
424+
try:
425+
async with async_timeout.timeout(self._timeout):
426+
msg = await reader.read()
427+
except asyncio.CancelledError:
428+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
429+
raise
430+
except Exception as exc:
431+
self._exception = exc
432+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
433+
return True
433434

434-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
435-
self._exception = asyncio.TimeoutError()
435+
if msg.type == WSMsgType.CLOSE:
436+
self._set_code_close_transport(msg.data)
436437
return True
437-
else:
438-
return False
438+
439+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
440+
self._exception = asyncio.TimeoutError()
441+
return True
442+
443+
def _set_code_close_transport(self, code: WSCloseCode) -> None:
444+
"""Set the close code and close the transport."""
445+
self._close_code = code
446+
if self._req is not None and self._req.transport is not None:
447+
self._req.transport.close()
439448

440449
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
441450
if self._reader is None:
@@ -466,7 +475,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
466475
set_result(waiter, True)
467476
self._waiting = None
468477
except (asyncio.CancelledError, asyncio.TimeoutError):
469-
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
478+
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
470479
raise
471480
except EofStream:
472481
self._close_code = WSCloseCode.OK
@@ -488,7 +497,11 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
488497
self._close_code = msg.data
489498
# Could be closed while awaiting reader.
490499
if not self._closed and self._autoclose: # type: ignore[redundant-expr]
491-
await self.close()
500+
# The client is likely going to close the
501+
# connection out from under us so we do not
502+
# want to drain any pending writes as it will
503+
# likely result writing to a broken pipe.
504+
await self.close(drain=False)
492505
elif msg.type == WSMsgType.CLOSING:
493506
self._closing = True
494507
elif msg.type == WSMsgType.PING and self._autoping:

docs/web_reference.rst

+11-1
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,14 @@ and :ref:`aiohttp-web-signals` handlers::
988988

989989
.. versionadded:: 3.3
990990

991+
:param bool autoclose: Close connection when the client sends
992+
a :const:`~aiohttp.WSMsgType.CLOSE` message,
993+
``True`` by default. If set to ``False``,
994+
the connection is not closed and the
995+
caller is responsible for calling
996+
``request.transport.close()`` to avoid
997+
leaking resources.
998+
991999

9921000
The class supports ``async for`` statement for iterating over
9931001
incoming messages::
@@ -1164,7 +1172,7 @@ and :ref:`aiohttp-web-signals` handlers::
11641172
The method is converted into :term:`coroutine`,
11651173
*compress* parameter added.
11661174

1167-
.. method:: close(*, code=WSCloseCode.OK, message=b'')
1175+
.. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True)
11681176
:async:
11691177

11701178
A :ref:`coroutine<coroutine>` that initiates closing
@@ -1178,6 +1186,8 @@ and :ref:`aiohttp-web-signals` handlers::
11781186
:class:`str` (converted to *UTF-8* encoded bytes)
11791187
or :class:`bytes`.
11801188

1189+
:param bool drain: drain outgoing buffer before closing connection.
1190+
11811191
:raise RuntimeError: if connection is not started
11821192

11831193
.. method:: receive(timeout=None)

tests/test_web_websocket.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# type: ignore
22
import asyncio
3+
import time
34
from typing import Any
45
from unittest import mock
56

@@ -139,6 +140,20 @@ async def test_write_non_prepared() -> None:
139140
await ws.write(b"data")
140141

141142

143+
async def test_heartbeat_timeout(make_request: Any) -> None:
144+
"""Verify the transport is closed when the heartbeat timeout is reached."""
145+
loop = asyncio.get_running_loop()
146+
future = loop.create_future()
147+
req = make_request("GET", "/")
148+
lowest_time = time.get_clock_info("monotonic").resolution
149+
req._protocol._timeout_ceil_threshold = lowest_time
150+
ws = WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time)
151+
await ws.prepare(req)
152+
ws._req.transport.close.side_effect = lambda: future.set_result(None)
153+
await future
154+
assert ws.closed
155+
156+
142157
def test_websocket_ready() -> None:
143158
websocket_ready = WebSocketReady(True, "chat")
144159
assert websocket_ready.ok is True
@@ -207,6 +222,7 @@ async def test_send_str_closed(make_request: Any) -> None:
207222
await ws.prepare(req)
208223
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
209224
await ws.close()
225+
assert len(ws._req.transport.close.mock_calls) == 1
210226

211227
with pytest.raises(ConnectionError):
212228
await ws.send_str("string")
@@ -263,6 +279,8 @@ async def test_close_idempotent(make_request: Any) -> None:
263279
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
264280
assert await ws.close(code=1, message="message1")
265281
assert ws.closed
282+
assert len(ws._req.transport.close.mock_calls) == 1
283+
266284
assert not (await ws.close(code=2, message="message2"))
267285

268286

@@ -296,12 +314,15 @@ async def test_write_eof_idempotent(make_request: Any) -> None:
296314
req = make_request("GET", "/")
297315
ws = WebSocketResponse()
298316
await ws.prepare(req)
317+
assert len(ws._req.transport.close.mock_calls) == 0
318+
299319
ws._reader.feed_data(WS_CLOSED_MESSAGE, 0)
300320
await ws.close()
301321

302322
await ws.write_eof()
303323
await ws.write_eof()
304324
await ws.write_eof()
325+
assert len(ws._req.transport.close.mock_calls) == 1
305326

306327

307328
async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None:
@@ -327,6 +348,7 @@ async def test_receive_timeouterror(make_request: Any, loop: Any) -> None:
327348
req = make_request("GET", "/")
328349
ws = WebSocketResponse()
329350
await ws.prepare(req)
351+
assert len(ws._req.transport.close.mock_calls) == 0
330352

331353
ws._reader = mock.Mock()
332354
res = loop.create_future()
@@ -336,6 +358,8 @@ async def test_receive_timeouterror(make_request: Any, loop: Any) -> None:
336358
with pytest.raises(asyncio.TimeoutError):
337359
await ws.receive()
338360

361+
assert len(ws._req.transport.close.mock_calls) == 1
362+
339363

340364
async def test_multiple_receive_on_close_connection(make_request: Any) -> None:
341365
req = make_request("GET", "/")
@@ -367,13 +391,15 @@ async def test_close_exc(make_request: Any) -> None:
367391
req = make_request("GET", "/")
368392
ws = WebSocketResponse()
369393
await ws.prepare(req)
394+
assert len(ws._req.transport.close.mock_calls) == 0
370395

371396
exc = ValueError()
372397
ws._writer = mock.Mock()
373398
ws._writer.close.side_effect = exc
374399
await ws.close()
375400
assert ws.closed
376401
assert ws.exception() is exc
402+
assert len(ws._req.transport.close.mock_calls) == 1
377403

378404
ws._closed = False
379405
ws._writer.close.side_effect = asyncio.CancelledError()

0 commit comments

Comments
 (0)