diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index a21443dff1b..4edc54ba31b 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -382,7 +382,10 @@ async def write_eof(self) -> None: # type: ignore[override] await self.close() self._eof_sent = True - async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: + async def close( + self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True + ) -> bool: + """Close websocket connection.""" if self._writer is None: raise RuntimeError("Call .prepare() first") @@ -396,46 +399,52 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo reader.feed_data(WS_CLOSING_MESSAGE, 0) await self._waiting - if not self._closed: - self._closed = True - try: - await self._writer.close(code, message) - writer = self._payload_writer - assert writer is not None - await writer.drain() - except (asyncio.CancelledError, asyncio.TimeoutError): - self._close_code = WSCloseCode.ABNORMAL_CLOSURE - raise - except Exception as exc: - self._close_code = WSCloseCode.ABNORMAL_CLOSURE - self._exception = exc - return True + if self._closed: + return False - if self._closing: - return True + self._closed = True + try: + await self._writer.close(code, message) + writer = self._payload_writer + assert writer is not None + if drain: + await writer.drain() + except (asyncio.CancelledError, asyncio.TimeoutError): + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + raise + except Exception as exc: + self._exception = exc + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + return True - reader = self._reader - assert reader is not None - try: - async with async_timeout.timeout(self._timeout): - msg = await reader.read() - except asyncio.CancelledError: - self._close_code = WSCloseCode.ABNORMAL_CLOSURE - raise - except Exception as exc: - self._close_code = WSCloseCode.ABNORMAL_CLOSURE - self._exception = exc - return True + if self._closing: + return True - if msg.type == WSMsgType.CLOSE: - self._close_code = msg.data - return True + reader = self._reader + assert reader is not None + try: + async with async_timeout.timeout(self._timeout): + msg = await reader.read() + except asyncio.CancelledError: + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + raise + except Exception as exc: + self._exception = exc + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + return True - self._close_code = WSCloseCode.ABNORMAL_CLOSURE - self._exception = asyncio.TimeoutError() + if msg.type == WSMsgType.CLOSE: + self._set_code_close_transport(msg.data) return True - else: - return False + + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + self._exception = asyncio.TimeoutError() + return True + + def _set_code_close_transport(self, code: WSCloseCode) -> None: + """Set the close code and close the transport.""" + self._close_code = code + self._writer.transport.close() async def receive(self, timeout: Optional[float] = None) -> WSMessage: if self._reader is None: @@ -488,7 +497,11 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: self._close_code = msg.data # Could be closed while awaiting reader. if not self._closed and self._autoclose: # type: ignore[redundant-expr] - await self.close() + # The client is going to close the connection + # out from under us so we do not want to drain + # any pending writes as it will likely result + # writing to a broken pipe. + await self.close(drain=False) elif msg.type == WSMsgType.CLOSING: self._closing = True elif msg.type == WSMsgType.PING and self._autoping: