diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index 7796e40d63..fd31c7ac3c 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -2,6 +2,8 @@ from enum import Enum from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast +import functools +import inspect from litestar._layers.utils import narrow_response_cookies, narrow_response_headers from litestar.connection import Request @@ -61,6 +63,14 @@ __all__ = ("HTTPRouteHandler", "route") +def _wrap_layered_hooks(hooks: list[Callable]) -> AsyncAnyCallable | None: + """Given a list of callables, starting from the end, """ + if not hooks: + return None + if 'parent' in inspect.signature(hooks[-1]).parameters: + return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1])) + return hooks[-1] + class ResponseHandlerMap(TypedDict): default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType @@ -400,7 +410,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None: """ if self._resolved_before_request is Empty: before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request] - self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None + self._resolved_before_request = _wrap_layered_hooks(before_request_handlers) return cast("AsyncAnyCallable | None", self._resolved_before_request) def resolve_after_response(self) -> AsyncAnyCallable | None: @@ -418,7 +428,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None: for layer in self.ownership_layers if layer.after_response ] - self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None + self._resolved_after_response = _wrap_layered_hooks(after_response_handlers) return cast("AsyncAnyCallable | None", self._resolved_after_response) diff --git a/litestar/handlers/http_handlers/decorators.py b/litestar/handlers/http_handlers/decorators.py index 593a1a7d19..227e34b858 100644 --- a/litestar/handlers/http_handlers/decorators.py +++ b/litestar/handlers/http_handlers/decorators.py @@ -121,8 +121,9 @@ def __init__( :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. Defaults to ``None``. before_request: A sync or async function called immediately before calling the route handler. Receives - the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, - bypassing the route handler. + the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the + outer scope's before_request handler if any exists). Any non-``None`` return value is used for the + response, bypassing the route handler. cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number of seconds (e.g. ``120``) to cache the response. cache_control: A ``cache-control`` header of type diff --git a/litestar/types/callable_types.py b/litestar/types/callable_types.py index 36055d7199..1ebbf0f894 100644 --- a/litestar/types/callable_types.py +++ b/litestar/types/callable_types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -23,12 +23,29 @@ AfterRequestHookHandler: TypeAlias = ( "Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]" ) -AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" + +AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" + +class AfterResponseHookHandlerWithParentAsync(Protocol): + async def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None") -> Any: ... +class AfterResponseHookHandlerWithParentSync(Protocol): + def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None") -> Any: ... + +AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParentAsync | AfterResponseHookHandlerWithParentSync" + AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]] AnyCallable: TypeAlias = Callable[..., Any] AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]" BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]" -BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" + +class BeforeRequestHookHandlerWithParentAsync(Protocol): + async def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None") -> Any: ... +class BeforeRequestHookHandlerWithParentSync(Protocol): + def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None") -> Any: ... + +BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" +BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParentSync | BeforeRequestHookHandlerWithParentAsync" + CacheKeyBuilder: TypeAlias = "Callable[[Request], str]" ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]" ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]"