Skip to content

Commit

Permalink
feat: allow before_request and after_request handlers to accept a par…
Browse files Browse the repository at this point in the history
…ent argument, to wrap or override the handler from an enclosing scope
  • Loading branch information
charles-dyfis-net committed Sep 21, 2024
1 parent a38c6c1 commit b0b45a6
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 11 deletions.
18 changes: 14 additions & 4 deletions litestar/handlers/http_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from enum import Enum
from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast
import functools
import inspect

from litestar._layers.utils import narrow_response_cookies, narrow_response_headers
from litestar.connection import Request
Expand Down Expand Up @@ -61,6 +63,14 @@

__all__ = ("HTTPRouteHandler", "route")

def _wrap_layered_hooks(hooks: list[AsyncAnyCallable]) -> AsyncAnyCallable | None:
"""Given a list of callables, starting from the end, """
if not hooks:
return None
if 'parent' in inspect.signature(hooks[-1]).parameters:
return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1]))
return hooks[-1]


class ResponseHandlerMap(TypedDict):
default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
Expand Down Expand Up @@ -260,9 +270,9 @@ def __init__(
)

self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
self.after_response = ensure_async_callable(after_response) if after_response else None
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
self.background = background
self.before_request = ensure_async_callable(before_request) if before_request else None
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
self.cache = cache
self.cache_control = cache_control
self.cache_key_builder = cache_key_builder
Expand Down Expand Up @@ -400,7 +410,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None:
"""
if self._resolved_before_request is Empty:
before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request]
self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None
self._resolved_before_request = _wrap_layered_hooks(before_request_handlers)
return cast("AsyncAnyCallable | None", self._resolved_before_request)

def resolve_after_response(self) -> AsyncAnyCallable | None:
Expand All @@ -418,7 +428,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None:
for layer in self.ownership_layers
if layer.after_response
]
self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None
self._resolved_after_response = _wrap_layered_hooks(after_response_handlers)

return cast("AsyncAnyCallable | None", self._resolved_after_response)

Expand Down
5 changes: 3 additions & 2 deletions litestar/handlers/http_handlers/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ def __init__(
:class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
Defaults to ``None``.
before_request: A sync or async function called immediately before calling the route handler. Receives
the :class:`.connection.Request` instance and any non-``None`` return value is used for the response,
bypassing the route handler.
the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the
outer scope's before_request handler if any exists). Any non-``None`` return value is used for the
response, bypassing the route handler.
cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number
of seconds (e.g. ``120``) to cache the response.
cache_control: A ``cache-control`` header of type
Expand Down
4 changes: 2 additions & 2 deletions litestar/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def __init__(
"""

self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
self.after_response = ensure_async_callable(after_response) if after_response else None
self.before_request = ensure_async_callable(before_request) if before_request else None
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
self.cache_control = cache_control
self.dto = dto
self.etag = etag
Expand Down
19 changes: 16 additions & 3 deletions litestar/types/callable_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar

if TYPE_CHECKING:
from typing_extensions import TypeAlias
Expand All @@ -23,12 +23,25 @@
AfterRequestHookHandler: TypeAlias = (
"Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]"
)
AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"

AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"

class AfterResponseHookHandlerWithParent(Protocol):
async def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None" = None) -> None: ...

AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParent"

AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]]
AnyCallable: TypeAlias = Callable[..., Any]
AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]"
BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]"
BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"

class BeforeRequestHookHandlerWithParent(Protocol):
async def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None" = None) -> Any: ...

BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParent"

CacheKeyBuilder: TypeAlias = "Callable[[Request], str]"
ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]"
ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]"
Expand Down
13 changes: 13 additions & 0 deletions tests/e2e/test_life_cycle_hooks/test_before_request.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from typing import Any, Dict, Optional
import logging

import pytest

from litestar import Controller, Request, Response, Router, get
from litestar.handlers.http_handlers.base import _wrap_layered_hooks
from litestar.datastructures import State
from litestar.testing import create_test_client
from litestar.types import AnyCallable, BeforeRequestHookHandler

logger = logging.getLogger(__name__)

async def async_before_request_handler_with_parent(request: Request[Any, Any, State], /, *, parent: Optional[BeforeRequestHookHandler] = None):
assert isinstance(request, Request)
retval = (None if parent is None else await parent(request)) or {}
retval.setdefault('amended_count', 0)
retval['amended_count'] += 1
return retval

def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]:
assert isinstance(request, Request)
Expand Down Expand Up @@ -88,6 +98,9 @@ def handler() -> Dict[str, str]:
{"hello": "world"},
],
[None, None, None, async_before_request_handler_without_return_value, {"hello": "world"}],
[sync_before_request_handler_with_return_value, None, None, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 1}],
[sync_before_request_handler_with_return_value, None, async_before_request_handler_with_parent, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 2}],
[sync_before_request_handler_with_return_value, async_before_request_handler_with_parent, async_before_request_handler_with_parent, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 3}],
],
)
def test_before_request_handler_resolution(
Expand Down

0 comments on commit b0b45a6

Please sign in to comment.