diff --git a/aiohttp_asgi/resource.py b/aiohttp_asgi/resource.py index e985cb1..ac2eb27 100644 --- a/aiohttp_asgi/resource.py +++ b/aiohttp_asgi/resource.py @@ -8,10 +8,12 @@ from aiohttp import ClientRequest, WSMessage, WSMsgType, hdrs from aiohttp.abc import AbstractMatchInfo, AbstractStreamWriter +from aiohttp.helpers import DEBUG from aiohttp.web import ( AbstractResource, Application, HTTPException, Request, StreamResponse, WebSocketResponse, ) +from urllib.parse import unquote from yarl import URL @@ -76,6 +78,7 @@ def __init__(self, handler: Callable[..., Any]): self._handler = handler self._apps = list() # type: _ApplicationColelctionType self._current_app: Optional[Application] = None + self._frozen = False @property def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: @@ -102,27 +105,48 @@ def apps(self) -> Tuple[Application, ...]: return tuple(self._apps) return self._apps - def add_app(self, app: Application) -> None: - if isinstance(self._apps, tuple): - raise RuntimeError("Frozen resource") - - self._apps.append(app) + @property + def apps(self) -> Tuple["Application", ...]: + return tuple(self._apps) + + def add_app(self, app: "Application") -> None: + if self._frozen: + raise RuntimeError("Cannot change apps stack after .freeze() call") + if self._current_app is None: + self._current_app = app + self._apps.insert(0, app) + + # @contextmanager + # def set_current_app( + # self, + # app: Application, + # ) -> Generator[None, None, None]: + # prev = self._current_app + # self._current_app = app + # try: + # yield + # finally: + # self._current_app = prev - def freeze(self) -> None: - self._apps = tuple(self.apps) - - @contextmanager - def set_current_app( - self, - app: Application, - ) -> Generator[None, None, None]: - prev = self._current_app + @property + def current_app(self) -> "Application": + app = self._current_app + assert app is not None + return app + + @current_app.setter + def current_app(self, app: "Application") -> None: + if DEBUG: # pragma: no cover + if app not in self._apps: + raise RuntimeError( + "Expected one of the following apps {!r}, got {!r}".format( + self._apps, app + ) + ) self._current_app = app - try: - yield - finally: - self._current_app = prev + def freeze(self) -> None: + self._frozen = True _ResponseType = Optional[Union[StreamResponse, WebSocketResponse]] _WriterType = Optional[AbstractStreamWriter] @@ -178,7 +202,15 @@ def scope(self) -> ScopeDict: if self.is_websocket(): result["type"] = "websocket" result["scheme"] = "wss" if self.request.secure else "ws" - result["subprotocols"] = [] + + # Decode websocket subprotocol options + subprotocols = [] + for header, value in result["headers"]: + if header == b"sec-websocket-protocol": + subprotocols = [ + x.strip() for x in unquote(value.decode("ascii")).split(",") + ] + result["subprotocols"] = subprotocols return result @@ -250,7 +282,7 @@ async def on_send(self, payload: Dict[str, Any]) -> None: if self.start_response_event.is_set(): raise asyncio.InvalidStateError - self.response = WebSocketResponse() + self.response = WebSocketResponse(protocols=self.scope["subprotocols"]) self.writer = await self.response.prepare(self.request) return