@@ -140,9 +140,8 @@ def _send_heartbeat(self) -> None:
140
140
def _pong_not_received (self ) -> None :
141
141
if self ._req is not None and self ._req .transport is not None :
142
142
self ._closed = True
143
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
143
+ self ._set_code_close_transport ( WSCloseCode .ABNORMAL_CLOSURE )
144
144
self ._exception = asyncio .TimeoutError ()
145
- self ._req .transport .close ()
146
145
147
146
async def prepare (self , request : BaseRequest ) -> AbstractStreamWriter :
148
147
# make pre-check to don't hide it by do_handshake() exceptions
@@ -360,7 +359,10 @@ async def write_eof(self) -> None: # type: ignore[override]
360
359
await self .close ()
361
360
self ._eof_sent = True
362
361
363
- async def close (self , * , code : int = WSCloseCode .OK , message : bytes = b"" ) -> bool :
362
+ async def close (
363
+ self , * , code : int = WSCloseCode .OK , message : bytes = b"" , drain : bool = True
364
+ ) -> bool :
365
+ """Close websocket connection."""
364
366
if self ._writer is None :
365
367
raise RuntimeError ("Call .prepare() first" )
366
368
@@ -374,46 +376,53 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
374
376
reader .feed_data (WS_CLOSING_MESSAGE , 0 )
375
377
await self ._waiting
376
378
377
- if not self ._closed :
378
- self ._closed = True
379
- try :
380
- await self ._writer .close (code , message )
381
- writer = self ._payload_writer
382
- assert writer is not None
383
- await writer .drain ()
384
- except (asyncio .CancelledError , asyncio .TimeoutError ):
385
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
386
- raise
387
- except Exception as exc :
388
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
389
- self ._exception = exc
390
- return True
379
+ if self ._closed :
380
+ return False
391
381
392
- if self ._closing :
393
- return True
382
+ self ._closed = True
383
+ try :
384
+ await self ._writer .close (code , message )
385
+ writer = self ._payload_writer
386
+ assert writer is not None
387
+ if drain :
388
+ await writer .drain ()
389
+ except (asyncio .CancelledError , asyncio .TimeoutError ):
390
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
391
+ raise
392
+ except Exception as exc :
393
+ self ._exception = exc
394
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
395
+ return True
394
396
395
- reader = self ._reader
396
- assert reader is not None
397
- try :
398
- async with async_timeout .timeout (self ._timeout ):
399
- msg = await reader .read ()
400
- except asyncio .CancelledError :
401
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
402
- raise
403
- except Exception as exc :
404
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
405
- self ._exception = exc
406
- return True
397
+ if self ._closing :
398
+ return True
407
399
408
- if msg .type == WSMsgType .CLOSE :
409
- self ._close_code = msg .data
410
- return True
400
+ reader = self ._reader
401
+ assert reader is not None
402
+ try :
403
+ async with async_timeout .timeout (self ._timeout ):
404
+ msg = await reader .read ()
405
+ except asyncio .CancelledError :
406
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
407
+ raise
408
+ except Exception as exc :
409
+ self ._exception = exc
410
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
411
+ return True
411
412
412
- self . _close_code = WSCloseCode . ABNORMAL_CLOSURE
413
- self ._exception = asyncio . TimeoutError ( )
413
+ if msg . type == WSMsgType . CLOSE :
414
+ self ._set_code_close_transport ( msg . data )
414
415
return True
415
- else :
416
- return False
416
+
417
+ self ._set_code_close_transport (WSCloseCode .ABNORMAL_CLOSURE )
418
+ self ._exception = asyncio .TimeoutError ()
419
+ return True
420
+
421
+ def _set_code_close_transport (self , code : WSCloseCode ) -> None :
422
+ """Set the close code and close the transport."""
423
+ self ._close_code = code
424
+ if self ._req is not None and self ._req .transport is not None :
425
+ self ._req .transport .close ()
417
426
418
427
async def receive (self , timeout : Optional [float ] = None ) -> WSMessage :
419
428
if self ._reader is None :
@@ -444,7 +453,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
444
453
set_result (waiter , True )
445
454
self ._waiting = None
446
455
except (asyncio .CancelledError , asyncio .TimeoutError ):
447
- self ._close_code = WSCloseCode .ABNORMAL_CLOSURE
456
+ self ._set_code_close_transport ( WSCloseCode .ABNORMAL_CLOSURE )
448
457
raise
449
458
except EofStream :
450
459
self ._close_code = WSCloseCode .OK
@@ -464,8 +473,13 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
464
473
if msg .type == WSMsgType .CLOSE :
465
474
self ._closing = True
466
475
self ._close_code = msg .data
467
- if not self ._closed and self ._autoclose :
468
- await self .close ()
476
+ # Could be closed while awaiting reader.
477
+ if not self ._closed and self ._autoclose : # type: ignore[redundant-expr]
478
+ # The client is likely going to close the
479
+ # connection out from under us so we do not
480
+ # want to drain any pending writes as it will
481
+ # likely result writing to a broken pipe.
482
+ await self .close (drain = False )
469
483
elif msg .type == WSMsgType .CLOSING :
470
484
self ._closing = True
471
485
elif msg .type == WSMsgType .PING and self ._autoping :
0 commit comments