@@ -162,9 +162,8 @@ def _send_heartbeat(self) -> None:
162
162
def _pong_not_received (self ) -> None :
163
163
if self ._req is not None and self ._req .transport is not None :
164
164
self ._closed = True
165
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
165
+ self ._set_code_close_transport ( WSCloseCode .ABNORMAL_CLOSURE )
166
166
self ._exception = asyncio .TimeoutError ()
167
- self ._req .transport .close ()
168
167
169
168
async def prepare (self , request : BaseRequest ) -> AbstractStreamWriter :
170
169
# make pre-check to don't hide it by do_handshake() exceptions
@@ -382,7 +381,10 @@ async def write_eof(self) -> None: # type: ignore[override]
382
381
await self .close ()
383
382
self ._eof_sent = True
384
383
385
- async def close (self , * , code : int = WSCloseCode .OK , message : bytes = b"" ) -> bool :
384
+ async def close (
385
+ self , * , code : int = WSCloseCode .OK , message : bytes = b"" , drain : bool = True
386
+ ) -> bool :
387
+ """Close websocket connection."""
386
388
if self ._writer is None :
387
389
raise RuntimeError ("Call .prepare() first" )
388
390
@@ -396,46 +398,53 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
396
398
reader .feed_data (WS_CLOSING_MESSAGE , 0 )
397
399
await self ._waiting
398
400
399
- if not self ._closed :
400
- self ._closed = True
401
- try :
402
- await self ._writer .close (code , message )
403
- writer = self ._payload_writer
404
- assert writer is not None
405
- await writer .drain ()
406
- except (asyncio .CancelledError , asyncio .TimeoutError ):
407
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
408
- raise
409
- except Exception as exc :
410
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
411
- self ._exception = exc
412
- return True
401
+ if self ._closed :
402
+ return False
413
403
414
- if self ._closing :
415
- return True
404
+ self ._closed = True
405
+ try :
406
+ await self ._writer .close (code , message )
407
+ writer = self ._payload_writer
408
+ assert writer is not None
409
+ if drain :
410
+ await writer .drain ()
411
+ except (asyncio .CancelledError , asyncio .TimeoutError ):
412
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
413
+ raise
414
+ except Exception as exc :
415
+ self ._exception = exc
416
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
417
+ return True
416
418
417
- reader = self ._reader
418
- assert reader is not None
419
- try :
420
- async with async_timeout .timeout (self ._timeout ):
421
- msg = await reader .read ()
422
- except asyncio .CancelledError :
423
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
424
- raise
425
- except Exception as exc :
426
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
427
- self ._exception = exc
428
- return True
419
+ if self ._closing :
420
+ return True
429
421
430
- if msg .type == WSMsgType .CLOSE :
431
- self ._close_code = msg .data
432
- return True
422
+ reader = self ._reader
423
+ assert reader is not None
424
+ try :
425
+ async with async_timeout .timeout (self ._timeout ):
426
+ msg = await reader .read ()
427
+ except asyncio .CancelledError :
428
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
429
+ raise
430
+ except Exception as exc :
431
+ self ._exception = exc
432
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
433
+ return True
433
434
434
- self . _close_code = WSCloseCode . ABNORMAL_CLOSURE
435
- self ._exception = asyncio . TimeoutError ( )
435
+ if msg . type == WSMsgType . CLOSE :
436
+ self ._set_code_close_transport ( msg . data )
436
437
return True
437
- else :
438
- return False
438
+
439
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
440
+ self ._exception = asyncio .TimeoutError ()
441
+ return True
442
+
443
+ def _set_code_close_transport (self , code : WSCloseCode ) -> None :
444
+ """Set the close code and close the transport."""
445
+ self ._close_code = code
446
+ if self ._req is not None and self ._req .transport is not None :
447
+ self ._req .transport .close ()
439
448
440
449
async def receive (self , timeout : Optional [float ] = None ) -> WSMessage :
441
450
if self ._reader is None :
@@ -466,7 +475,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
466
475
set_result (waiter , True )
467
476
self ._waiting = None
468
477
except (asyncio .CancelledError , asyncio .TimeoutError ):
469
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
478
+ self ._set_code_close_transport ( WSCloseCode .ABNORMAL_CLOSURE )
470
479
raise
471
480
except EofStream :
472
481
self ._close_code = WSCloseCode .OK
@@ -488,7 +497,11 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
488
497
self ._close_code = msg .data
489
498
# Could be closed while awaiting reader.
490
499
if not self ._closed and self ._autoclose : # type: ignore[redundant-expr]
491
- await self .close ()
500
+ # The client is likely going to close the
501
+ # connection out from under us so we do not
502
+ # want to drain any pending writes as it will
503
+ # likely result writing to a broken pipe.
504
+ await self .close (drain = False )
492
505
elif msg .type == WSMsgType .CLOSING :
493
506
self ._closing = True
494
507
elif msg .type == WSMsgType .PING and self ._autoping :
0 commit comments