diff --git a/litestar/handlers/http_handlers/base.py b/litestar/handlers/http_handlers/base.py index 7796e40d63..e81c342c6c 100644 --- a/litestar/handlers/http_handlers/base.py +++ b/litestar/handlers/http_handlers/base.py @@ -2,6 +2,8 @@ from enum import Enum from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast +import functools +import inspect from litestar._layers.utils import narrow_response_cookies, narrow_response_headers from litestar.connection import Request @@ -61,6 +63,14 @@ __all__ = ("HTTPRouteHandler", "route") +def _wrap_layered_hooks(hooks: list[AsyncAnyCallable]) -> AsyncAnyCallable | None: + """Given a list of callables, starting from the end, """ + if not hooks: + return None + if 'parent' in inspect.signature(hooks[-1]).parameters: + return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1])) + return hooks[-1] + class ResponseHandlerMap(TypedDict): default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType @@ -260,9 +270,9 @@ def __init__( ) self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore - self.after_response = ensure_async_callable(after_response) if after_response else None + self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore self.background = background - self.before_request = ensure_async_callable(before_request) if before_request else None + self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore self.cache = cache self.cache_control = cache_control self.cache_key_builder = cache_key_builder @@ -400,7 +410,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None: """ if self._resolved_before_request is Empty: before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request] - self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None + self._resolved_before_request = _wrap_layered_hooks(before_request_handlers) return cast("AsyncAnyCallable | None", self._resolved_before_request) def resolve_after_response(self) -> AsyncAnyCallable | None: @@ -418,7 +428,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None: for layer in self.ownership_layers if layer.after_response ] - self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None + self._resolved_after_response = _wrap_layered_hooks(after_response_handlers) return cast("AsyncAnyCallable | None", self._resolved_after_response) diff --git a/litestar/handlers/http_handlers/decorators.py b/litestar/handlers/http_handlers/decorators.py index 593a1a7d19..227e34b858 100644 --- a/litestar/handlers/http_handlers/decorators.py +++ b/litestar/handlers/http_handlers/decorators.py @@ -121,8 +121,9 @@ def __init__( :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished. Defaults to ``None``. before_request: A sync or async function called immediately before calling the route handler. Receives - the :class:`.connection.Request` instance and any non-``None`` return value is used for the response, - bypassing the route handler. + the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the + outer scope's before_request handler if any exists). Any non-``None`` return value is used for the + response, bypassing the route handler. cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number of seconds (e.g. ``120``) to cache the response. cache_control: A ``cache-control`` header of type diff --git a/litestar/router.py b/litestar/router.py index 88ac0fd567..914170dcc2 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -168,8 +168,8 @@ def __init__( """ self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore - self.after_response = ensure_async_callable(after_response) if after_response else None - self.before_request = ensure_async_callable(before_request) if before_request else None + self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore + self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore self.cache_control = cache_control self.dto = dto self.etag = etag diff --git a/litestar/types/callable_types.py b/litestar/types/callable_types.py index 36055d7199..d651a0c790 100644 --- a/litestar/types/callable_types.py +++ b/litestar/types/callable_types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -23,12 +23,25 @@ AfterRequestHookHandler: TypeAlias = ( "Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]" ) -AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" + +AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]" + +class AfterResponseHookHandlerWithParent(Protocol): + async def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None" = None) -> None: ... + +AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParent" + AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]] AnyCallable: TypeAlias = Callable[..., Any] AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]" BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]" -BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" + +class BeforeRequestHookHandlerWithParent(Protocol): + async def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None" = None) -> Any: ... + +BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]" +BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParent" + CacheKeyBuilder: TypeAlias = "Callable[[Request], str]" ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]" ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]" diff --git a/tests/e2e/test_life_cycle_hooks/test_before_request.py b/tests/e2e/test_life_cycle_hooks/test_before_request.py index 47de0a58a8..28bd481e1a 100644 --- a/tests/e2e/test_life_cycle_hooks/test_before_request.py +++ b/tests/e2e/test_life_cycle_hooks/test_before_request.py @@ -1,12 +1,22 @@ from typing import Any, Dict, Optional +import logging import pytest from litestar import Controller, Request, Response, Router, get +from litestar.handlers.http_handlers.base import _wrap_layered_hooks from litestar.datastructures import State from litestar.testing import create_test_client from litestar.types import AnyCallable, BeforeRequestHookHandler +logger = logging.getLogger(__name__) + +async def async_before_request_handler_with_parent(request: Request[Any, Any, State], /, *, parent: Optional[BeforeRequestHookHandler] = None): + assert isinstance(request, Request) + retval = (None if parent is None else await parent(request)) or {} + retval.setdefault('amended_count', 0) + retval['amended_count'] += 1 + return retval def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]: assert isinstance(request, Request) @@ -88,6 +98,9 @@ def handler() -> Dict[str, str]: {"hello": "world"}, ], [None, None, None, async_before_request_handler_without_return_value, {"hello": "world"}], + [sync_before_request_handler_with_return_value, None, None, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 1}], + [sync_before_request_handler_with_return_value, None, async_before_request_handler_with_parent, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 2}], + [sync_before_request_handler_with_return_value, async_before_request_handler_with_parent, async_before_request_handler_with_parent, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 3}], ], ) def test_before_request_handler_resolution(