diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index 7796e40d63..0114e4c9ca 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -1,5 +1,7 @@ from __future__ import annotations +import functools +import inspect from enum import Enum from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast @@ -62,6 +64,15 @@ __all__ = ("HTTPRouteHandler", "route") +def _wrap_layered_hooks(hooks: list[AsyncAnyCallable]) -> AsyncAnyCallable | None: + """Given a list of callables, starting from the end, set the parent= keyword argument of each to default to the preceding hook should any preceding hook exist and should that argument be accepted.""" + if not hooks: + return None + if "parent" in inspect.signature(hooks[-1]).parameters: + return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1])) + return hooks[-1] + + class ResponseHandlerMap(TypedDict): default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType response_type_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType @@ -260,9 +271,9 @@ def __init__( ) self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore - self.after_response = ensure_async_callable(after_response) if after_response else None + self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore self.background = background - self.before_request = ensure_async_callable(before_request) if before_request else None + self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore self.cache = cache self.cache_control = cache_control self.cache_key_builder = cache_key_builder @@ -400,7 +411,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None: """ if self._resolved_before_request is Empty: before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request] - self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None + self._resolved_before_request = _wrap_layered_hooks(before_request_handlers) return cast("AsyncAnyCallable | None", self._resolved_before_request) def resolve_after_response(self) -> AsyncAnyCallable | None: @@ -418,7 +429,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None: for layer in self.ownership_layers if layer.after_response ] - self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None + self._resolved_after_response = _wrap_layered_hooks(after_response_handlers) return cast("AsyncAnyCallable | None", self._resolved_after_response) diff --git a/litestar/handlers/http_handlers/decorators.py b/litestar/handlers/http_handlers/decorators.py index 593a1a7d19..227e34b858 100644 --- a/litestar/handlers/http_handlers/decorators.py +++ b/litestar/handlers/http_handlers/decorators.py @@ -121,8 +121,9 @@ def __init__( :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. Defaults to ``None``. before_request: A sync or async function called immediately before calling the route handler. Receives - the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, - bypassing the route handler. + the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the + outer scope's before_request handler if any exists). Any non-``None`` return value is used for the + response, bypassing the route handler. cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number of seconds (e.g. ``120``) to cache the response. cache_control: A ``cache-control`` header of type diff --git a/litestar/router.py b/litestar/router.py index 88ac0fd567..914170dcc2 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -168,8 +168,8 @@ def __init__( """ self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore - self.after_response = ensure_async_callable(after_response) if after_response else None - self.before_request = ensure_async_callable(before_request) if before_request else None + self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore + self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore self.cache_control = cache_control self.dto = dto self.etag = etag diff --git a/litestar/types/callable_types.py b/litestar/types/callable_types.py index 36055d7199..438bb0db33 100644 --- a/litestar/types/callable_types.py +++ b/litestar/types/callable_types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -23,12 +23,29 @@ AfterRequestHookHandler: TypeAlias = ( "Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]" ) -AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" + +AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" + + +class AfterResponseHookHandlerWithParent(Protocol): + async def __call__(self, request: Request, /, *, parent: AfterResponseHookHandler | None = None) -> None: ... + + +AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParent" + AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]] AnyCallable: TypeAlias = Callable[..., Any] AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]" BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]" -BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" + + +class BeforeRequestHookHandlerWithParent(Protocol): + async def __call__(self, request: Request, /, *, parent: BeforeRequestHookHandler | None = None) -> Any: ... + + +BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" +BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParent" + CacheKeyBuilder: TypeAlias = "Callable[[Request], str]" ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]" ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]" diff --git a/tests/e2e/test_life_cycle_hooks/test_before_request.py b/tests/e2e/test_life_cycle_hooks/test_before_request.py index 47de0a58a8..45c957c6f1 100644 --- a/tests/e2e/test_life_cycle_hooks/test_before_request.py +++ b/tests/e2e/test_life_cycle_hooks/test_before_request.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional +import logging +from typing import Any, Dict, Optional, Union import pytest @@ -7,6 +8,18 @@ from litestar.testing import create_test_client from litestar.types import AnyCallable, BeforeRequestHookHandler +logger = logging.getLogger(__name__) + + +async def async_before_request_handler_with_parent( + request: Request[Any, Any, State], /, *, parent: Optional[BeforeRequestHookHandler] = None +) -> Optional[Dict[str, Union[str, int]]]: + assert isinstance(request, Request) + retval: Dict[str, Union[str, int]] = (None if parent is None else await parent(request)) or {} + retval.setdefault("amended_count", 0) + retval["amended_count"] += 1 # type: ignore + return retval + def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]: assert isinstance(request, Request) @@ -88,6 +101,27 @@ def handler() -> Dict[str, str]: {"hello": "world"}, ], [None, None, None, async_before_request_handler_without_return_value, {"hello": "world"}], + [ + sync_before_request_handler_with_return_value, + None, + None, + async_before_request_handler_with_parent, + {"hello": "moon", "amended_count": 1}, + ], + [ + sync_before_request_handler_with_return_value, + None, + async_before_request_handler_with_parent, + async_before_request_handler_with_parent, + {"hello": "moon", "amended_count": 2}, + ], + [ + sync_before_request_handler_with_return_value, + async_before_request_handler_with_parent, + async_before_request_handler_with_parent, + async_before_request_handler_with_parent, + {"hello": "moon", "amended_count": 3}, + ], ], ) def test_before_request_handler_resolution(