Skip to content

Commit

Permalink
feat: allow before_request and after_request handlers to accept a par…
Browse files Browse the repository at this point in the history
…ent argument, to wrap or override the handler from an enclosing scope
  • Loading branch information
charles-dyfis-net committed Sep 21, 2024
1 parent a38c6c1 commit adc478e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
14 changes: 12 additions & 2 deletions litestar/handlers/http_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from enum import Enum
from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast
import functools
import inspect

from litestar._layers.utils import narrow_response_cookies, narrow_response_headers
from litestar.connection import Request
Expand Down Expand Up @@ -61,6 +63,14 @@

__all__ = ("HTTPRouteHandler", "route")

def _wrap_layered_hooks(hooks: list[Callable]) -> AsyncAnyCallable | None:
"""Given a list of callables, starting from the end, """
if not hooks:
return None
if 'parent' in inspect.signature(hooks[-1]).parameters:
return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1]))

Check warning on line 71 in litestar/handlers/http_handlers/base.py

View check run for this annotation

Codecov / codecov/patch

litestar/handlers/http_handlers/base.py#L71

Added line #L71 was not covered by tests
return hooks[-1]


class ResponseHandlerMap(TypedDict):
default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
Expand Down Expand Up @@ -400,7 +410,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None:
"""
if self._resolved_before_request is Empty:
before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request]
self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None
self._resolved_before_request = _wrap_layered_hooks(before_request_handlers)
return cast("AsyncAnyCallable | None", self._resolved_before_request)

def resolve_after_response(self) -> AsyncAnyCallable | None:
Expand All @@ -418,7 +428,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None:
for layer in self.ownership_layers
if layer.after_response
]
self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None
self._resolved_after_response = _wrap_layered_hooks(after_response_handlers)

return cast("AsyncAnyCallable | None", self._resolved_after_response)

Expand Down
5 changes: 3 additions & 2 deletions litestar/handlers/http_handlers/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ def __init__(
:class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
Defaults to ``None``.
before_request: A sync or async function called immediately before calling the route handler. Receives
the :class:`.connection.Request` instance and any non-``None`` return value is used for the response,
bypassing the route handler.
the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the
outer scope's before_request handler if any exists). Any non-``None`` return value is used for the
response, bypassing the route handler.
cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number
of seconds (e.g. ``120``) to cache the response.
cache_control: A ``cache-control`` header of type
Expand Down
23 changes: 20 additions & 3 deletions litestar/types/callable_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar

if TYPE_CHECKING:
from typing_extensions import TypeAlias
Expand All @@ -23,12 +23,29 @@
AfterRequestHookHandler: TypeAlias = (
"Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]"
)
AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"

AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"

class AfterResponseHookHandlerWithParentAsync(Protocol):
async def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None") -> Any: ...
class AfterResponseHookHandlerWithParentSync(Protocol):
def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None") -> Any: ...

AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParentAsync | AfterResponseHookHandlerWithParentSync"

AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]]
AnyCallable: TypeAlias = Callable[..., Any]
AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]"
BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]"
BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"

class BeforeRequestHookHandlerWithParentAsync(Protocol):
async def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None") -> Any: ...
class BeforeRequestHookHandlerWithParentSync(Protocol):
def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None") -> Any: ...

BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParentSync | BeforeRequestHookHandlerWithParentAsync"

CacheKeyBuilder: TypeAlias = "Callable[[Request], str]"
ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]"
ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]"
Expand Down

0 comments on commit adc478e

Please sign in to comment.