diff --git a/grpclib/client.py b/grpclib/client.py index 2f788e6..0fe146c 100644 --- a/grpclib/client.py +++ b/grpclib/client.py @@ -62,7 +62,10 @@ class Handler(AbstractHandler): - connection_lost = False + closing = False + + def connection_made(self, connection: Any) -> None: + pass def accept(self, stream: Any, headers: Any, release_stream: Any) -> None: raise NotImplementedError('Client connection can not accept requests') @@ -71,7 +74,7 @@ def cancel(self, stream: Any) -> None: pass def close(self) -> None: - self.connection_lost = True + self.closing = True class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]): @@ -737,7 +740,7 @@ async def _create_connection(self) -> H2Protocol: @property def _connected(self) -> bool: return (self._protocol is not None - and not self._protocol.handler.connection_lost) + and not cast(Handler, self._protocol.handler).closing) async def __connect__(self) -> H2Protocol: if not self._connected: diff --git a/grpclib/protocol.py b/grpclib/protocol.py index 66f66c1..011eee6 100644 --- a/grpclib/protocol.py +++ b/grpclib/protocol.py @@ -488,6 +488,10 @@ def closable(self) -> bool: class AbstractHandler(ABC): + @abstractmethod + def connection_made(self, connection: Connection) -> None: + pass + @abstractmethod def accept( self, @@ -709,6 +713,7 @@ def connection_made(self, transport: BaseTransport) -> None: self.connection.flush() self.connection.initialize() + self.handler.connection_made(self.connection) self.processor = EventsProcessor(self.handler, self.connection) def data_received(self, data: bytes) -> None: diff --git a/grpclib/server.py b/grpclib/server.py index b652a75..504c149 100644 --- a/grpclib/server.py +++ b/grpclib/server.py @@ -4,6 +4,7 @@ import logging import asyncio import warnings +from functools import partial from types import TracebackType from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast @@ -12,6 +13,7 @@ import h2.config import h2.exceptions +from h2.errors import ErrorCodes from multidict import MultiDict @@ -24,7 +26,7 @@ from .metadata import Deadline, encode_grpc_message, _Metadata from .metadata import encode_metadata, decode_metadata, _MetadataLike from .metadata import _STATUS_DETAILS_KEY, encode_bin_value -from .protocol import H2Protocol, AbstractHandler +from .protocol import H2Protocol, AbstractHandler, Connection from .exceptions import GRPCError, ProtocolError, StreamTerminatedError from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec @@ -496,6 +498,7 @@ def __gc_step__(self) -> None: class Handler(_GC, AbstractHandler): __gc_interval__ = 10 + connection: Connection closing = False def __init__( @@ -511,13 +514,17 @@ def __init__( self.dispatch = dispatch self.loop = asyncio.get_event_loop() self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {} - self._cancelled: Set['asyncio.Task[None]'] = set() def __gc_collect__(self) -> None: - self._tasks = {s: t for s, t in self._tasks.items() - if not t.done()} - self._cancelled = {t for t in self._cancelled - if not t.done()} + self._tasks = {s: t for s, t in self._tasks.items() if not t.done()} + + def connection_made(self, connection: Connection) -> None: + self.connection = connection + + def handler_done(self, stream: 'protocol.Stream', _: Any) -> None: + self._tasks.pop(stream, None) + if not self._tasks: + self.connection.close() def accept( self, @@ -525,30 +532,36 @@ def accept( headers: _Headers, release_stream: Callable[[], Any], ) -> None: - self.__gc_step__() - self._tasks[stream] = self.loop.create_task(request_handler( - self.mapping, stream, headers, self.codec, - self.status_details_codec, self.dispatch, release_stream, - )) + if self.closing: + stream.reset_nowait(ErrorCodes.REFUSED_STREAM) + release_stream() + else: + self.__gc_step__() + self._tasks[stream] = self.loop.create_task(request_handler( + self.mapping, stream, headers, self.codec, + self.status_details_codec, self.dispatch, release_stream, + )) def cancel(self, stream: 'protocol.Stream') -> None: - task = self._tasks.pop(stream) - task.cancel() - self._cancelled.add(task) + self._tasks[stream].cancel() def close(self) -> None: - for task in self._tasks.values(): + self.__gc_collect__() + for stream, task in self._tasks.items(): + task.add_done_callback(partial(self.handler_done, stream)) task.cancel() - self._cancelled.update(self._tasks.values()) self.closing = True async def wait_closed(self) -> None: - if self._cancelled: - await asyncio.wait(self._cancelled) + self.__gc_collect__() + if self._tasks: + await asyncio.wait(self._tasks.values()) + else: + self.connection.close() def check_closed(self) -> bool: self.__gc_collect__() - return not self._tasks and not self._cancelled + return not self._tasks class Server(_GC): @@ -737,11 +750,11 @@ async def wait_closed(self) -> None: if self._server is None or self._server_closed_fut is None: raise RuntimeError('Server is not started') await self._server_closed_fut - await self._server.wait_closed() if self._handlers: await asyncio.wait({ self._loop.create_task(h.wait_closed()) for h in self._handlers }) + await self._server.wait_closed() async def __aenter__(self) -> 'Server': return self diff --git a/tests/stubs.py b/tests/stubs.py index e2eb753..c4ac09d 100644 --- a/tests/stubs.py +++ b/tests/stubs.py @@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler): headers = None release_stream = None + def connection_made(self, connection): + pass + def accept(self, stream, headers, release_stream): self.stream = stream self.headers = headers