Skip to content

Commit

Permalink
fix(ASGI mounts): Prevent accidental scope overrides by mounted ASGI …
Browse files Browse the repository at this point in the history
…apps (#3945)
  • Loading branch information
provinzkraut authored Jan 11, 2025
1 parent a814224 commit 2db1f4d
Show file tree
Hide file tree
Showing 24 changed files with 135 additions and 30 deletions.
2 changes: 1 addition & 1 deletion docs/examples/application_hooks/after_exception_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def my_handler() -> None:

async def after_exception_handler(exc: Exception, scope: "Scope") -> None:
"""Hook function that will be invoked after each exception."""
state = scope["app"].state
state = Litestar.from_scope(scope).state
if not hasattr(state, "error_count"):
state.error_count = 1
else:
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/application_hooks/before_send_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def before_send_hook_handler(message: Message, scope: Scope) -> None:
"""
if message["type"] == "http.response.start":
headers = MutableScopeHeaders.from_message(message=message)
headers["My Header"] = scope["app"].state.message
headers["My Header"] = Litestar.from_scope(scope).state.message


def on_startup(app: Litestar) -> None:
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/application_state/using_application_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def middleware_factory(*, app: "ASGIApp") -> "ASGIApp":
"""A middleware can access application state via `scope`."""

async def my_middleware(scope: "Scope", receive: "Receive", send: "Send") -> None:
state = scope["app"].state
state = Litestar.from_scope(scope).state
logger.info("state value in middleware: %s", state.value)
await app(scope, receive, send)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/routing/mount_custom_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from litestar.types import Receive, Scope, Send


@asgi("/some/sub-path", is_mount=True)
@asgi("/some/sub-path", is_mount=True, copy_scope=True)
async def my_asgi_app(scope: "Scope", receive: "Receive", send: "Send") -> None:
"""
Args:
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/routing/mounting_starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def index(request: "Request") -> JSONResponse:
return JSONResponse({"forwarded_path": request.url.path})


starlette_app = asgi(path="/some/sub-path", is_mount=True)(
starlette_app = asgi(path="/some/sub-path", is_mount=True, copy_scope=True)(
Starlette(
routes=[
Route("/", index),
Expand Down
8 changes: 4 additions & 4 deletions docs/usage/applications.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ is accessible.
:ref:`reserved keyword arguments <usage/routing/handlers:"reserved" keyword arguments>`.

It is important to understand in this context that the application instance is injected into the ASGI ``scope`` mapping
for each connection (i.e. request or websocket connection) as ``scope["app"]``. This makes the application
accessible wherever the scope mapping is available, e.g. in middleware, on :class:`~.connection.request.Request` and
:class:`~.connection.websocket.WebSocket` instances (accessible as ``request.app`` / ``socket.app``), and many
other places.
for each connection (i.e. request or websocket connection) as ``scope["litestar_app"]``, and can be retrieved using
:meth:`~.Litestar.from_scope`. This makes the application accessible wherever the scope mapping is available,
e.g. in middleware, on :class:`~.connection.request.Request` and :class:`~.connection.websocket.WebSocket` instances
(accessible as ``request.app`` / ``socket.app``), and many other places.

Therefore, :paramref:`~.app.Litestar.state` offers an easy way to share contextual data between disparate parts
of the application, as seen below:
Expand Down
7 changes: 6 additions & 1 deletion litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,15 @@ async def __call__(
await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type]
return

scope["app"] = self
scope["app"] = scope["litestar_app"] = self
scope.setdefault("state", {})
await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type]

@classmethod
def from_scope(cls, scope: Scope) -> Litestar:
"""Retrieve the Litestar application from the current ASGI scope"""
return scope["litestar_app"]

async def _call_lifespan_hook(self, hook: LifespanHook) -> None:
ret = hook(self) if inspect.signature(hook).parameters else hook() # type: ignore[call-arg]

Expand Down
6 changes: 3 additions & 3 deletions litestar/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def app(self) -> Litestar:
Returns:
The :class:`Litestar <litestar.app.Litestar>` application instance
"""
return self.scope["app"]
return self.scope["litestar_app"]

@property
def route_handler(self) -> HandlerT:
Expand Down Expand Up @@ -321,7 +321,7 @@ def url_for(self, name: str, **path_parameters: Any) -> str:
Returns:
A string representing the absolute url of the route handler.
"""
litestar_instance = self.scope["app"]
litestar_instance = self.scope["litestar_app"]
url_path = litestar_instance.route_reverse(name, **path_parameters)

return make_absolute_url(url_path, self.base_url)
Expand All @@ -339,7 +339,7 @@ def url_for_static_asset(self, name: str, file_path: str) -> str:
Returns:
A string representing absolute url to the asset.
"""
litestar_instance = self.scope["app"]
litestar_instance = self.scope["litestar_app"]
url_path = litestar_instance.url_for_static_asset(name, file_path)

return make_absolute_url(url_path, self.base_url)
22 changes: 21 additions & 1 deletion litestar/handlers/asgi_handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Mapping, Sequence

from litestar.exceptions import ImproperlyConfiguredException
Expand All @@ -11,6 +12,7 @@


if TYPE_CHECKING:
from litestar import Litestar
from litestar.types import (
ExceptionHandlersMap,
Guard,
Expand All @@ -24,7 +26,7 @@ class ASGIRouteHandler(BaseRouteHandler):
Use this decorator to decorate ASGI applications.
"""

__slots__ = ("is_mount", "is_static")
__slots__ = ("copy_scope", "is_mount", "is_static")

def __init__(
self,
Expand All @@ -37,6 +39,7 @@ def __init__(
is_mount: bool = False,
is_static: bool = False,
signature_namespace: Mapping[str, Any] | None = None,
copy_scope: bool | None = None,
**kwargs: Any,
) -> None:
"""Initialize ``ASGIRouteHandler``.
Expand All @@ -58,10 +61,14 @@ def __init__(
are used to deliver static files.
signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling.
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
copy_scope: Copy the ASGI 'scope' before calling the mounted application. Should be set to 'True' unless
side effects via scope mutations by the mounted ASGI application are intentional
**kwargs: Any additional kwarg - will be set in the opt dictionary.
"""
self.is_mount = is_mount or is_static
self.is_static = is_static
self.copy_scope = copy_scope

super().__init__(
path,
exception_handlers=exception_handlers,
Expand All @@ -72,6 +79,19 @@ def __init__(
**kwargs,
)

def on_registration(self, app: Litestar) -> None:
super().on_registration(app)

if self.copy_scope is None:
warnings.warn(
f"{self}: 'copy_scope' not set for ASGI handler. Leaving 'copy_scope' unset will warn about mounted "
"ASGI applications modifying the scope. Set 'copy_scope=True' to ensure calling into mounted ASGI apps "
"does not cause any side effects via scope mutations, or set 'copy_scope=False' if those mutations are "
"desired. 'copy'scope' will default to 'True' in Litestar 3",
category=DeprecationWarning,
stacklevel=1,
)

def _validate_handler_function(self) -> None:
"""Validate the route handler function once it's set by inspecting its return annotations."""
super()._validate_handler_function()
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/_internal/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
origin = headers.get("origin")

if scope["type"] == ScopeType.HTTP and scope["method"] == HttpMethod.OPTIONS and origin:
request = scope["app"].request_class(scope=scope, receive=receive, send=send)
request = scope["litestar_app"].request_class(scope=scope, receive=receive, send=send)
asgi_response = self._create_preflight_response(origin=origin, request_headers=headers).to_asgi_response(
app=None, request=request
)
Expand Down
4 changes: 2 additions & 2 deletions litestar/middleware/_internal/exceptions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(

@staticmethod
def _get_debug_scope(scope: Scope) -> bool:
return scope["app"].debug
return scope["litestar_app"].debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI-callable.
Expand All @@ -161,7 +161,7 @@ async def capture_response_started(event: Message) -> None:
if scope_state.response_started:
raise LitestarException("Exception caught after response started") from e

litestar_app = scope["app"]
litestar_app = scope["litestar_app"]

if litestar_app.logging_config and (logger := litestar_app.logger):
self.handle_exception_logging(logger=logger, logging_config=litestar_app.logging_config, scope=scope)
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request: Request[Any, Any, Any] = scope["app"].request_class(scope=scope, receive=receive)
request: Request[Any, Any, Any] = scope["litestar_app"].request_class(scope=scope, receive=receive)
content_type, _ = request.content_type
csrf_cookie = request.cookies.get(self.config.cookie_name)
existing_csrf_token = request.headers.get(self.config.header_name)
Expand Down
4 changes: 2 additions & 2 deletions litestar/middleware/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
None
"""
if not hasattr(self, "logger"):
self.logger = scope["app"].get_logger(self.config.logger_name)
self.logger = scope["litestar_app"].get_logger(self.config.logger_name)
self.is_struct_logger = structlog_installed and repr(self.logger).startswith("<BoundLoggerLazyProxy")

if self.config.response_log_fields:
Expand All @@ -121,7 +121,7 @@ async def log_request(self, scope: Scope, receive: Receive) -> None:
Returns:
None
"""
extracted_data = await self.extract_request_data(request=scope["app"].request_class(scope, receive))
extracted_data = await self.extract_request_data(request=scope["litestar_app"].request_class(scope, receive))
self.log_message(values=extracted_data)

def log_response(self, scope: Scope) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Returns:
None
"""
app = scope["app"]
app = scope["litestar_app"]
request: Request[Any, Any, Any] = app.request_class(scope)
store = self.config.get_store_from_app(app)
if await self.should_check_request(request=request):
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/response_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def wrapped_send(message: Message) -> None:

if messages and message["type"] == HTTP_RESPONSE_BODY and not message.get("more_body"):
key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope))
store = self.config.get_store_from_app(scope["app"])
store = self.config.get_store_from_app(scope["litestar_app"])
await store.set(key, encode_msgpack(messages), expires_in=expires_in)
await send(message)

Expand Down
21 changes: 20 additions & 1 deletion litestar/routes/asgi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

from litestar.connection import ASGIConnection
from litestar.enums import ScopeType
from litestar.exceptions import LitestarWarning
from litestar.routes.base import BaseRoute

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,4 +53,21 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
connection = ASGIConnection["ASGIRouteHandler", Any, Any, Any](scope=scope, receive=receive)
await self.route_handler.authorize_connection(connection=connection)

await self.route_handler.fn(scope=scope, receive=receive, send=send)
handler_scope = scope.copy()
copy_scope = self.route_handler.copy_scope

await self.route_handler.fn(
scope=handler_scope if copy_scope is True else scope,
receive=receive,
send=send,
)

if copy_scope is None and handler_scope != scope:
warnings.warn(
f"{self.route_handler}: Mounted ASGI app {self.route_handler.fn} modified 'scope' with 'copy_scope' "
"set to 'None'. Set 'copy_scope=True' to avoid mutating the original scope or set 'copy_scope=False' "
"if mutating the scope from within the mounted ASGI app is intentional. Note: 'copy_scope' will "
"default to 'True' by default in Litestar 3",
category=LitestarWarning,
stacklevel=1,
)
4 changes: 3 additions & 1 deletion litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ async def _call_handler_function(
route_handler=route_handler, parameter_model=parameter_model, request=request
)

response: ASGIApp = await route_handler.to_response(app=scope["app"], data=response_data, request=request)
response: ASGIApp = await route_handler.to_response(
app=scope["litestar_app"], data=response_data, request=request
)

if cleanup_group:
await cleanup_group.cleanup()
Expand Down
1 change: 1 addition & 0 deletions litestar/testing/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def fake_asgi_connection(app: ASGIApp, cookies: dict[str, str]) -> ASGIConnectio
"http_version": "1.1",
"extensions": {"http.response.template": {}},
"app": app, # type: ignore[typeddict-item]
"litestar_app": app,
"state": {},
"path_params": {},
"route_handler": None,
Expand Down
3 changes: 2 additions & 1 deletion litestar/testing/request_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
"""Initialize ``RequestFactory``
Args:
app: An instance of :class:`Litestar <litestar.app.Litestar>` to set as ``request.scope["app"]``.
app: An instance of :class:`Litestar <litestar.app.Litestar>` to set as ``request.scope["litestar_app"]``.
server: The server's domain.
port: The server's port.
root_path: Root path for the server.
Expand Down Expand Up @@ -175,6 +175,7 @@ def _create_scope(
path=path,
headers=[],
app=self.app,
litestar_app=self.app,
session=session,
user=user,
auth=auth,
Expand Down
3 changes: 2 additions & 1 deletion litestar/types/asgi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ class HeaderScope(TypedDict):
class BaseScope(HeaderScope):
"""Base ASGI-scope."""

app: Litestar
app: Litestar # deprecated
litestar_app: Litestar
asgi: ASGIVersion
auth: Any
client: tuple[str, int] | None
Expand Down
2 changes: 1 addition & 1 deletion litestar/utils/scope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_serializer_from_scope(scope: Scope) -> Serializer:
A serializer function
"""
route_handler = scope["route_handler"]
app = scope["app"]
app = scope["litestar_app"]

if hasattr(route_handler, "resolve_type_encoders"):
type_encoders = route_handler.resolve_type_encoders()
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def inner(
) -> Scope:
scope = {
"app": app,
"litestar_app": app,
"asgi": asgi or {"spec_version": "2.0", "version": "3.0"},
"auth": auth,
"type": type,
Expand Down
18 changes: 17 additions & 1 deletion tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def handler() -> Dict[str, str]:
async def before_send_hook_handler(message: Message, scope: Scope) -> None:
if message["type"] == "http.response.start":
headers = MutableScopeHeaders(message)
headers.add("My Header", scope["app"].state.message)
headers.add("My Header", Litestar.from_scope(scope).state.message)

def on_startup(app: Litestar) -> None:
app.state.message = "value injected during send"
Expand Down Expand Up @@ -466,3 +466,19 @@ def my_route_handler() -> None: ...
with create_test_client(my_route_handler, path="/abc") as client:
response = client.get("/abc")
assert response.status_code == HTTP_200_OK


def test_from_scope() -> None:
mock = MagicMock()

@get()
def handler(scope: Scope) -> None:
mock(Litestar.from_scope(scope))
return

app = Litestar(route_handlers=[handler])

with TestClient(app) as client:
client.get("/")

mock.assert_called_once_with(app)
Loading

0 comments on commit 2db1f4d

Please sign in to comment.