diff --git a/.gitignore b/.gitignore index 5bf2443ccb..0c3c190c57 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ .scannerwork/ .unasyncd_cache/ .venv/ +.venv* .vscode/ __pycache__/ build/ diff --git a/docs/conf.py b/docs/conf.py index eafe847470..e9a00e4fb0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,6 +57,7 @@ "opentelemetry": ("https://opentelemetry-python.readthedocs.io/en/latest/", None), "advanced-alchemy": ("https://docs.advanced-alchemy.jolt.rs/latest/", None), "jinja2": ("https://jinja.palletsprojects.com/en/latest/", None), + "trio": ("https://trio.readthedocs.io/en/stable/", None), } napoleon_google_docstring = True @@ -226,6 +227,9 @@ ), re.compile(r"litestar\.dto.*"): re.compile(".*T|.*FieldDefinition|Empty"), re.compile(r"litestar\.template\.(config|TemplateConfig).*"): re.compile(".*EngineType"), + "litestar.concurrency.set_asyncio_executor": {"ThreadPoolExecutor"}, + "litestar.concurrency.get_asyncio_executor": {"ThreadPoolExecutor"}, + re.compile(r"litestar\.channels\.backends\.asyncpg.*"): {"asyncpg.connection.Connection"}, } # Do not warn about broken links to the following: diff --git a/docs/reference/channels/backends/asyncpg.rst b/docs/reference/channels/backends/asyncpg.rst new file mode 100644 index 0000000000..91d44ecdf1 --- /dev/null +++ b/docs/reference/channels/backends/asyncpg.rst @@ -0,0 +1,5 @@ +asyncpg +======= + +.. automodule:: litestar.channels.backends.asyncpg + :members: diff --git a/docs/reference/channels/backends/index.rst b/docs/reference/channels/backends/index.rst index 010ae7e509..02deff518a 100644 --- a/docs/reference/channels/backends/index.rst +++ b/docs/reference/channels/backends/index.rst @@ -6,3 +6,5 @@ backends base memory redis + psycopg + asyncpg diff --git a/docs/reference/channels/backends/psycopg.rst b/docs/reference/channels/backends/psycopg.rst new file mode 100644 index 0000000000..4a8163db60 --- /dev/null +++ b/docs/reference/channels/backends/psycopg.rst @@ -0,0 +1,5 @@ +psycopg +======= + +.. automodule:: litestar.channels.backends.psycopg + :members: diff --git a/docs/reference/concurrency.rst b/docs/reference/concurrency.rst new file mode 100644 index 0000000000..89bf990a58 --- /dev/null +++ b/docs/reference/concurrency.rst @@ -0,0 +1,5 @@ +cli +=== + +.. automodule:: litestar.concurrency + :members: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 637a1fe3d7..6fcd7bc88e 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -14,6 +14,7 @@ API reference connection contrib/index controller + concurrency data_extractors datastructures di diff --git a/docs/release-notes/changelog.rst b/docs/release-notes/changelog.rst index 7f2216288c..1226a44c57 100644 --- a/docs/release-notes/changelog.rst +++ b/docs/release-notes/changelog.rst @@ -3,6 +3,160 @@ 2.x Changelog ============= +.. changelog:: 2.5.0 + :date: 2024/01/06 + + .. change:: Fix serialization of custom types in exception responses + :type: bugfix + :issue: 2867 + :pr: 2941 + + Fix a bug that would lead to a :exc:`SerializationException` when custom types + were present in an exception response handled by the built-in exception + handlers. + + .. code-block:: python + + class Foo: + pass + + + @get() + def handler() -> None: + raise ValidationException(extra={"foo": Foo("bar")}) + + + app = Litestar(route_handlers=[handler], type_encoders={Foo: lambda foo: "foo"}) + + The cause was that, in examples like the one shown above, ``type_encoders`` + were not resolved properly from all layers by the exception handling middleware, + causing the serializer to throw an exception for an unknown type. + + .. change:: Fix SSE reverting to default ``event_type`` after 1st message + :type: bugfix + :pr: 2888 + :issue: 2877 + + The ``event_type`` set within an SSE returned from a handler would revert back + to a default after the first message sent: + + .. code-block:: python + + @get("/stream") + async def stream(self) -> ServerSentEvent: + async def gen() -> AsyncGenerator[str, None]: + c = 0 + while True: + yield f"
{c}
\n" + c += 1 + + return ServerSentEvent(gen(), event_type="my_event") + + In this example, the event type would only be ``my_event`` for the first + message, and fall back to a default afterwards. The implementation has been + fixed and will now continue sending the set event type for all messages. + + .. change:: Correctly handle single file upload validation when multiple files are specified + :type: bugfix + :pr: 2950 + :issue: 2939 + + Uploading a single file when the validation target allowed multiple would cause + a :exc:`ValidationException`: + + .. code-block:: python + + class FileUpload(Struct): + files: list[UploadFile] + + + @post(path="/") + async def upload_files_object( + data: Annotated[FileUpload, Body(media_type=RequestEncodingType.MULTI_PART)] + ) -> list[str]: + pass + + + This could would only allow for 2 or more files to be sent, and otherwise throw + an exception. + + .. change:: Fix trailing messages after unsubscribe in channels + :type: bugfix + :pr: 2894 + + Fix a bug that would allow some channels backend to receive messages from a + channel it just unsubscribed from, for a short period of time, due to how the + different brokers handle unsubscribes. + + .. code-block:: python + + await backend.subscribe(["foo", "bar"]) # subscribe to two channels + await backend.publish( + b"something", ["foo"] + ) # publish a message to a channel we're subscribed to + + # start the stream after publishing. Depending on the backend + # the previously published message might be in the stream + event_generator = backend.stream_events() + + # unsubscribe from the channel we previously published to + await backend.unsubscribe(["foo"]) + + # this should block, as we expect messages from channels + # we unsubscribed from to not appear in the stream anymore + print(anext(event_generator)) + + Backends affected by this were in-memory, Redis PubSub and asyncpg. The Redis + stream and psycopg backends were not affected. + + .. change:: Postgres channels backends + :type: feature + :pr: 2803 + + Two new channel backends were added to bring Postgres support: + + :class:`~litestar.channels.backends.asyncpg.AsyncPgChannelsBackend`, using the + `asyncpg `_ driver and + :class:`~litestar.channels.backends.psycopg.PsycoPgChannelsBackend` using the + `psycopg3 `_ async driver. + + .. seealso:: + :doc:`/usage/channels` + + + .. change:: Add ``--schema`` and ``--exclude`` option to ``litestar route`` CLI command + :type: feature + :pr: 2886 + + Two new options were added to the ``litestar route`` CLI command: + + - ``--schema``, to include the routes serving OpenAPI schema and docs + - ``--exclude`` to exclude routes matching a specified pattern + + .. seealso:: + :ref:`usage/cli:routes` + + .. change:: Improve performance of threaded synchronous execution + :type: misc + :pr: 2937 + + Performance of threaded synchronous code was improved by using the async + library's native threading helpers instead of anyio. On asyncio, + :meth:`asyncio.loop.run_in_executor` is now used and on trio + :func:`trio.to_thread.run_sync`. + + Beneficiaries of these performance improvements are: + + - Synchronous route handlers making use of ``sync_to_thread=True`` + - Synchronous dependency providers making use of ``sync_to_thread=True`` + - Synchronous SSE generators + - :class:`~litestar.stores.file.FileStore` + - Large file uploads where the ``max_spool_size`` is exceeded and the spooled + temporary file has been rolled to disk + - :class:`~litestar.response.file.File` and + :class:`~litestar.response.file.ASGIFileResponse` + + .. changelog:: 2.4.5 :date: 2023/12/23 diff --git a/docs/usage/channels.rst b/docs/usage/channels.rst index cbf0ef2721..0f6e70ee36 100644 --- a/docs/usage/channels.rst +++ b/docs/usage/channels.rst @@ -413,6 +413,17 @@ implemented are: when history is needed +:class:`AsyncPgChannelsBackend <.asyncpg.AsyncPgChannelsBackend>` + A postgres backend using the + `asyncpg `_ driver + + +:class:`PsycoPgChannelsBackend <.psycopg.PsycoPgChannelsBackend>` + A postgres backend using the `psycopg3 `_ + async driver + + + Integrating with websocket handlers ----------------------------------- diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst index 6bb943c7f8..0350f22270 100644 --- a/docs/usage/cli.rst +++ b/docs/usage/cli.rst @@ -185,6 +185,19 @@ The ``routes`` command displays a tree view of the routing table. litestar routes +Options +~~~~~~~ + ++-----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Flag | Description | ++=================+===========================================================================================================================================================+ +| ``--schema`` | Include default auto generated openAPI schema routes | ++-----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+ +| ``--exclude`` | Exclude endpoints from query with given regex patterns. Multiple excludes allowed. e.g., ``litestar routes --schema --exclude=routes/.* --exclude=[]`` | ++-----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------+ + + + .. image:: /images/cli/litestar_routes.png :alt: litestar info diff --git a/litestar/channels/backends/asyncpg.py b/litestar/channels/backends/asyncpg.py new file mode 100644 index 0000000000..77894c72db --- /dev/null +++ b/litestar/channels/backends/asyncpg.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from functools import partial +from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload + +import asyncpg + +from litestar.channels import ChannelsBackend +from litestar.exceptions import ImproperlyConfiguredException + + +class AsyncPgChannelsBackend(ChannelsBackend): + _listener_conn: asyncpg.Connection + + @overload + def __init__(self, dsn: str) -> None: + ... + + @overload + def __init__( + self, + *, + make_connection: Callable[[], Awaitable[asyncpg.Connection]], + ) -> None: + ... + + def __init__( + self, + dsn: str | None = None, + *, + make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None, + ) -> None: + if not (dsn or make_connection): + raise ImproperlyConfiguredException("Need to specify dsn or make_connection") + + self._subscribed_channels: set[str] = set() + self._exit_stack = AsyncExitStack() + self._connect = make_connection or partial(asyncpg.connect, dsn=dsn) + self._queue: asyncio.Queue[tuple[str, bytes]] | None = None + + async def on_startup(self) -> None: + self._queue = asyncio.Queue() + self._listener_conn = await self._connect() + + async def on_shutdown(self) -> None: + await self._listener_conn.close() + self._queue = None + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + dec_data = data.decode("utf-8") + + conn = await self._connect() + try: + for channel in channels: + await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data) + finally: + await conn.close() + + async def subscribe(self, channels: Iterable[str]) -> None: + for channel in set(channels) - self._subscribed_channels: + await self._listener_conn.add_listener(channel, self._listener) # type: ignore[arg-type] + self._subscribed_channels.add(channel) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + for channel in channels: + await self._listener_conn.remove_listener(channel, self._listener) # type: ignore[arg-type] + self._subscribed_channels = self._subscribed_channels - set(channels) + + async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + while True: + channel, message = await self._queue.get() + self._queue.task_done() + # an UNLISTEN may be in transit while we're getting here, so we double-check + # that we are actually supposed to deliver this message + if channel in self._subscribed_channels: + yield channel, message + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + raise NotImplementedError() + + def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None: + if not isinstance(payload, str): + raise RuntimeError("Invalid data received") + self._queue.put_nowait((channel, payload.encode("utf-8"))) # type: ignore[union-attr] diff --git a/litestar/channels/backends/memory.py b/litestar/channels/backends/memory.py index e01f0d779c..a96a66bca7 100644 --- a/litestar/channels/backends/memory.py +++ b/litestar/channels/backends/memory.py @@ -63,10 +63,19 @@ async def unsubscribe(self, channels: Iterable[str]) -> None: async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: """Return a generator, iterating over events of subscribed channels as they become available""" - while self._queue: - yield await self._queue.get() + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + while True: + channel, message = await self._queue.get() self._queue.task_done() + # if a message is published to a channel and the channel is then + # unsubscribed before retrieving that message from the stream, it can still + # end up here, so we double-check if we still are interested in this message + if channel in self._channels: + yield channel, message + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: """Return the event history of ``channel``, at most ``limit`` entries""" history = list(self._history[channel]) diff --git a/litestar/channels/backends/psycopg.py b/litestar/channels/backends/psycopg.py new file mode 100644 index 0000000000..14b53bcd1a --- /dev/null +++ b/litestar/channels/backends/psycopg.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import AsyncGenerator, Iterable + +import psycopg + +from .base import ChannelsBackend + + +def _safe_quote(ident: str) -> str: + return '"{}"'.format(ident.replace('"', '""')) # sourcery skip + + +class PsycoPgChannelsBackend(ChannelsBackend): + _listener_conn: psycopg.AsyncConnection + + def __init__(self, pg_dsn: str) -> None: + self._pg_dsn = pg_dsn + self._subscribed_channels: set[str] = set() + self._exit_stack = AsyncExitStack() + + async def on_startup(self) -> None: + self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True) + await self._exit_stack.enter_async_context(self._listener_conn) + + async def on_shutdown(self) -> None: + await self._exit_stack.aclose() + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + dec_data = data.decode("utf-8") + async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn: + for channel in channels: + await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data)) + + async def subscribe(self, channels: Iterable[str]) -> None: + for channel in set(channels) - self._subscribed_channels: + # can't use placeholders in LISTEN + await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore + + self._subscribed_channels.add(channel) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + for channel in channels: + # can't use placeholders in UNLISTEN + await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore + self._subscribed_channels = self._subscribed_channels - set(channels) + + async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + async for notify in self._listener_conn.notifies(): + yield notify.channel, notify.payload.encode("utf-8") + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + raise NotImplementedError() diff --git a/litestar/channels/backends/redis.py b/litestar/channels/backends/redis.py index c84c061c88..f03c9f2ffc 100644 --- a/litestar/channels/backends/redis.py +++ b/litestar/channels/backends/redis.py @@ -4,7 +4,7 @@ import sys if sys.version_info < (3, 9): - import importlib_resources + import importlib_resources # pyright: ignore else: import importlib.resources as importlib_resources from abc import ABC @@ -112,7 +112,10 @@ async def subscribe(self, channels: Iterable[str]) -> None: async def unsubscribe(self, channels: Iterable[str]) -> None: """Stop listening for events on ``channels``""" await self._pub_sub.unsubscribe(*channels) - if not self._pub_sub.subscribed: + # if we have no active subscriptions, or only subscriptions which are pending + # to be unsubscribed we consider the backend to be unsubscribed from all + # channels, so we reset the event + if not self._pub_sub.channels.keys() - self._pub_sub.pending_unsubscribe_channels: self._has_subscribed.clear() async def publish(self, data: bytes, channels: Iterable[str]) -> None: @@ -138,9 +141,14 @@ async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: if message is None: continue - channel = message["channel"].decode() - data = message["data"] - yield channel, data + channel: str = message["channel"].decode() + data: bytes = message["data"] + # redis handles the unsubscibes with a queue; Unsubscribing doesn't mean the + # unsubscribe will happen immediately after requesting it, so we could + # receive a message on a channel that, from a client's perspective, it's not + # subscribed to anymore + if channel.encode() in self._pub_sub.channels.keys() - self._pub_sub.pending_unsubscribe_channels: + yield channel, data async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: """Not implemented""" @@ -234,8 +242,6 @@ async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: # We wait for subscribed channels, because we can't pass an empty dict to # xread and block for subscribers stream_keys = [self._make_key(c) for c in await self._get_subscribed_channels()] - if not stream_keys: - continue data: list[tuple[bytes, list[tuple[bytes, dict[bytes, bytes]]]]] = await self._redis.xread( {key: stream_ids.get(key, 0) for key in stream_keys}, block=self._stream_sleep_no_subscriptions diff --git a/litestar/channels/plugin.py b/litestar/channels/plugin.py index ae7dcc78b3..59884454d4 100644 --- a/litestar/channels/plugin.py +++ b/litestar/channels/plugin.py @@ -311,10 +311,10 @@ async def _sub_worker(self) -> None: subscriber.put_nowait(payload) async def _on_startup(self) -> None: + await self._backend.on_startup() self._pub_queue = Queue() self._pub_task = create_task(self._pub_worker()) self._sub_task = create_task(self._sub_worker()) - await self._backend.on_startup() if self._channels: await self._backend.subscribe(list(self._channels)) @@ -336,11 +336,13 @@ async def _on_shutdown(self) -> None: self._sub_task.cancel() with suppress(CancelledError): await self._sub_task + self._sub_task = None if self._pub_task: self._pub_task.cancel() with suppress(CancelledError): await self._pub_task + self._sub_task = None await self._backend.on_shutdown() diff --git a/litestar/channels/subscriber.py b/litestar/channels/subscriber.py index 7e2cba56ac..b358bc4727 100644 --- a/litestar/channels/subscriber.py +++ b/litestar/channels/subscriber.py @@ -104,8 +104,9 @@ def _start_in_background(self, on_event: EventCallback) -> None: Args: on_event: Callback to invoke with the event data for every event """ - if self._task is None: - self._task = asyncio.create_task(self._worker(on_event)) + if self._task is not None: + raise RuntimeError("Subscriber is already running") + self._task = asyncio.create_task(self._worker(on_event)) @property def is_running(self) -> bool: diff --git a/litestar/cli/_utils.py b/litestar/cli/_utils.py index b7c129a832..f158b1180f 100644 --- a/litestar/cli/_utils.py +++ b/litestar/cli/_utils.py @@ -4,6 +4,7 @@ import importlib import inspect import os +import re import sys from dataclasses import dataclass from datetime import datetime, timedelta, timezone @@ -65,6 +66,8 @@ if TYPE_CHECKING: + from litestar.openapi import OpenAPIConfig + from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute from litestar.types import AnyCallable @@ -539,3 +542,34 @@ def _generate_self_signed_cert(certfile_path: Path, keyfile_path: Path, common_n encryption_algorithm=serialization.NoEncryption(), ) ) + + +def remove_routes_with_patterns( + routes: list[HTTPRoute | ASGIRoute | WebSocketRoute], patterns: tuple[str, ...] +) -> list[HTTPRoute | ASGIRoute | WebSocketRoute]: + regex_routes = [] + valid_patterns = [] + for pattern in patterns: + try: + check_pattern = re.compile(pattern) + valid_patterns.append(check_pattern) + except re.error as e: + console.print(f"Error: {e}. Invalid regex pattern supplied: '{pattern}'. Omitting from querying results.") + + for route in routes: + checked_pattern_route_matches = [] + for pattern_compile in valid_patterns: + matches = pattern_compile.match(route.path) + checked_pattern_route_matches.append(matches) + + if not any(checked_pattern_route_matches): + regex_routes.append(route) + + return regex_routes + + +def remove_default_schema_routes( + routes: list[HTTPRoute | ASGIRoute | WebSocketRoute], openapi_config: OpenAPIConfig +) -> list[HTTPRoute | ASGIRoute | WebSocketRoute]: + schema_path = openapi_config.openapi_controller.path + return remove_routes_with_patterns(routes, (schema_path,)) diff --git a/litestar/cli/commands/core.py b/litestar/cli/commands/core.py index c6fab9f9ab..3cd438ab19 100644 --- a/litestar/cli/commands/core.py +++ b/litestar/cli/commands/core.py @@ -10,12 +10,15 @@ from rich.tree import Tree +from litestar.app import DEFAULT_OPENAPI_CONFIG from litestar.cli._utils import ( RICH_CLICK_INSTALLED, UVICORN_INSTALLED, LitestarEnv, console, create_ssl_files, + remove_default_schema_routes, + remove_routes_with_patterns, show_app_info, validate_ssl_file_paths, ) @@ -260,12 +263,20 @@ def run_command( @command(name="routes") -def routes_command(app: Litestar) -> None: # pragma: no cover +@option("--schema", help="Include schema routes", is_flag=True, default=False) +@option("--exclude", help="routes to exclude via regex", type=str, is_flag=False, multiple=True) +def routes_command(app: Litestar, exclude: tuple[str, ...], schema: bool) -> None: # pragma: no cover """Display information about the application's routes.""" tree = Tree("", hide_root=True) - - for route in sorted(app.routes, key=lambda r: r.path): + sorted_routes = sorted(app.routes, key=lambda r: r.path) + if not schema: + openapi_config = app.openapi_config or DEFAULT_OPENAPI_CONFIG + sorted_routes = remove_default_schema_routes(sorted_routes, openapi_config) + if exclude is not None: + sorted_routes = remove_routes_with_patterns(sorted_routes, exclude) + + for route in sorted_routes: if isinstance(route, HTTPRoute): branch = tree.add(f"[green]{route.path}[/green] (HTTP)") for handler in route.route_handlers: diff --git a/litestar/concurrency.py b/litestar/concurrency.py new file mode 100644 index 0000000000..90eadbf724 --- /dev/null +++ b/litestar/concurrency.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import asyncio +import contextvars +from functools import partial +from typing import TYPE_CHECKING, Callable, TypeVar + +import sniffio +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from concurrent.futures import ThreadPoolExecutor + + import trio + + +T = TypeVar("T") +P = ParamSpec("P") + + +__all__ = ( + "sync_to_thread", + "set_asyncio_executor", + "get_asyncio_executor", + "set_trio_capacity_limiter", + "get_trio_capacity_limiter", +) + + +class _State: + EXECUTOR: ThreadPoolExecutor | None = None + LIMITER: trio.CapacityLimiter | None = None + + +async def _run_sync_asyncio(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + ctx = contextvars.copy_context() + bound_fn = partial(ctx.run, fn, *args, **kwargs) + return await asyncio.get_running_loop().run_in_executor(get_asyncio_executor(), bound_fn) # pyright: ignore + + +async def _run_sync_trio(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + import trio + + return await trio.to_thread.run_sync(partial(fn, *args, **kwargs), limiter=get_trio_capacity_limiter()) + + +async def sync_to_thread(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + """Run the synchronous callable ``fn`` asynchronously in a worker thread. + + When called from asyncio, uses :meth:`asyncio.loop.run_in_executor` to + run the callable. No executor is specified by default so the current loop's executor + is used. A specific executor can be set using + :func:`~litestar.concurrency.set_asyncio_executor`. This does not affect the loop's + default executor. + + When called from trio, uses :func:`trio.to_thread.run_sync` to run the callable. No + capacity limiter is specified by default, but one can be set using + :func:`~litestar.concurrency.set_trio_capacity_limiter`. This does not affect trio's + default capacity limiter. + """ + if (library := sniffio.current_async_library()) == "asyncio": + return await _run_sync_asyncio(fn, *args, **kwargs) + + if library == "trio": + return await _run_sync_trio(fn, *args, **kwargs) + + raise RuntimeError("Unsupported async library or not in async context") + + +def set_asyncio_executor(executor: ThreadPoolExecutor | None) -> None: + """Set the executor in which synchronous callables will be run within an asyncio + context + """ + try: + sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + raise RuntimeError("Cannot set executor from running loop") + + _State.EXECUTOR = executor + + +def get_asyncio_executor() -> ThreadPoolExecutor | None: + """Get the executor in which synchronous callables will be run within an asyncio + context + """ + return _State.EXECUTOR + + +def set_trio_capacity_limiter(limiter: trio.CapacityLimiter | None) -> None: + """Set the capacity limiter used when running synchronous callable within a trio + context + """ + try: + sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + raise RuntimeError("Cannot set limiter while in async context") + + _State.LIMITER = limiter + + +def get_trio_capacity_limiter() -> trio.CapacityLimiter | None: + """Get the capacity limiter used when running synchronous callable within a trio + context + """ + return _State.LIMITER diff --git a/litestar/contrib/mako.py b/litestar/contrib/mako.py index b875a1c2fd..859a814723 100644 --- a/litestar/contrib/mako.py +++ b/litestar/contrib/mako.py @@ -120,8 +120,7 @@ def register_template_callable( """ self._template_callables.append((key, template_callable)) - @staticmethod - def render_string(template_string: str, context: Mapping[str, Any]) -> str: # pyright: ignore + def render_string(self, template_string: str, context: Mapping[str, Any]) -> str: # pyright: ignore """Render a template from a string with the given context. Args: diff --git a/litestar/datastructures/upload_file.py b/litestar/datastructures/upload_file.py index 07b425514f..09ad2d32ab 100644 --- a/litestar/datastructures/upload_file.py +++ b/litestar/datastructures/upload_file.py @@ -2,8 +2,7 @@ from tempfile import SpooledTemporaryFile -from anyio.to_thread import run_sync - +from litestar.concurrency import sync_to_thread from litestar.constants import ONE_MEGABYTE __all__ = ("UploadFile",) @@ -59,7 +58,7 @@ async def write(self, data: bytes) -> int: None """ if self.rolled_to_disk: - return await run_sync(self.file.write, data) + return await sync_to_thread(self.file.write, data) return self.file.write(data) async def read(self, size: int = -1) -> bytes: @@ -72,7 +71,7 @@ async def read(self, size: int = -1) -> bytes: Byte string. """ if self.rolled_to_disk: - return await run_sync(self.file.read, size) + return await sync_to_thread(self.file.read, size) return self.file.read(size) async def seek(self, offset: int) -> int: @@ -85,7 +84,7 @@ async def seek(self, offset: int) -> int: None. """ if self.rolled_to_disk: - return await run_sync(self.file.seek, offset) + return await sync_to_thread(self.file.seek, offset) return self.file.seek(offset) async def close(self) -> None: @@ -95,7 +94,7 @@ async def close(self) -> None: None. """ if self.rolled_to_disk: - return await run_sync(self.file.close) + return await sync_to_thread(self.file.close) return self.file.close() def __repr__(self) -> str: diff --git a/litestar/file_system.py b/litestar/file_system.py index b86bf2cdd8..d7655485e1 100644 --- a/litestar/file_system.py +++ b/litestar/file_system.py @@ -4,8 +4,8 @@ from typing import TYPE_CHECKING, Any, AnyStr, cast from anyio import AsyncFile, Path, open_file -from anyio.to_thread import run_sync +from litestar.concurrency import sync_to_thread from litestar.exceptions import InternalServerException, NotAuthorizedException from litestar.types.file_types import FileSystemProtocol from litestar.utils.predicates import is_async_callable @@ -77,7 +77,7 @@ async def info(self, path: PathType) -> FileInfo: awaitable = ( self.file_system.info(str(path)) if is_async_callable(self.file_system.info) - else run_sync(self.file_system.info, str(path)) + else sync_to_thread(self.file_system.info, str(path)) ) return cast("FileInfo", await awaitable) except FileNotFoundError as e: @@ -113,7 +113,7 @@ async def open( buffering=buffering, ), ) - return AsyncFile(await run_sync(self.file_system.open, file, mode, buffering)) # type: ignore + return AsyncFile(await sync_to_thread(self.file_system.open, file, mode, buffering)) # type: ignore[arg-type] except PermissionError as e: raise NotAuthorizedException(f"failed to open {file} due to missing permissions") from e except OSError as e: diff --git a/litestar/response/sse.py b/litestar/response/sse.py index 4ff0331b15..e6bd826ac3 100644 --- a/litestar/response/sse.py +++ b/litestar/response/sse.py @@ -5,8 +5,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator -from anyio.to_thread import run_sync - +from litestar.concurrency import sync_to_thread from litestar.exceptions import ImproperlyConfiguredException from litestar.response.streaming import Stream from litestar.utils import AsyncIteratorWrapper @@ -85,7 +84,7 @@ def _call_next(self) -> bytes: async def _async_generator(self) -> AsyncGenerator[bytes, None]: while True: try: - yield await run_sync(self._call_next) + yield await sync_to_thread(self._call_next) except ValueError: async for value in self.content_async_iterator: d = self.ensure_bytes(value, DEFAULT_SEPARATOR) diff --git a/litestar/stores/file.py b/litestar/stores/file.py index a3ebba3241..25c52eb6b5 100644 --- a/litestar/stores/file.py +++ b/litestar/stores/file.py @@ -7,7 +7,8 @@ from typing import TYPE_CHECKING from anyio import Path -from anyio.to_thread import run_sync + +from litestar.concurrency import sync_to_thread from .base import NamespacedStore, StorageObject @@ -73,7 +74,7 @@ def _write_sync(self, target_file: Path, storage_obj: StorageObject) -> None: pass async def _write(self, target_file: Path, storage_obj: StorageObject) -> None: - await run_sync(self._write_sync, target_file, storage_obj) + await sync_to_thread(self._write_sync, target_file, storage_obj) async def set(self, key: str, value: str | bytes, expires_in: int | timedelta | None = None) -> None: """Set a value. @@ -140,7 +141,7 @@ async def delete_all(self) -> None: This deletes and recreates :attr:`FileStore.path` """ - await run_sync(shutil.rmtree, self.path) + await sync_to_thread(shutil.rmtree, self.path) await self.path.mkdir(exist_ok=True) async def delete_expired(self) -> None: diff --git a/litestar/utils/helpers.py b/litestar/utils/helpers.py index e7ae22159e..1c75eeb9c1 100644 --- a/litestar/utils/helpers.py +++ b/litestar/utils/helpers.py @@ -98,6 +98,6 @@ def get_exception_group() -> type[BaseException]: try: return cast("type[BaseException]", ExceptionGroup) # type:ignore[name-defined] except NameError: - from exceptiongroup import ExceptionGroup as _ExceptionGroup + from exceptiongroup import ExceptionGroup as _ExceptionGroup # pyright: ignore return cast("type[BaseException]", _ExceptionGroup) diff --git a/litestar/utils/sync.py b/litestar/utils/sync.py index 5f190d5182..02acabfc4d 100644 --- a/litestar/utils/sync.py +++ b/litestar/utils/sync.py @@ -1,6 +1,5 @@ from __future__ import annotations -from functools import partial from typing import ( AsyncGenerator, Awaitable, @@ -11,9 +10,9 @@ TypeVar, ) -from anyio.to_thread import run_sync from typing_extensions import ParamSpec +from litestar.concurrency import sync_to_thread from litestar.utils.predicates import is_async_callable __all__ = ("ensure_async_callable", "AsyncIteratorWrapper", "AsyncCallable", "is_async_callable") @@ -43,7 +42,7 @@ def __init__(self, fn: Callable[P, T]) -> None: # pyright: ignore self.func = fn def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]: # pyright: ignore - return run_sync(partial(self.func, **kwargs), *args) # pyright: ignore + return sync_to_thread(self.func, *args, **kwargs) # pyright: ignore class AsyncIteratorWrapper(Generic[T]): @@ -69,7 +68,7 @@ def _call_next(self) -> T: async def _async_generator(self) -> AsyncGenerator[T, None]: while True: try: - yield await run_sync(self._call_next) + yield await sync_to_thread(self._call_next) except ValueError: return diff --git a/pdm.lock b/pdm.lock index 6c5979aba8..8bc10c6181 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "standard", "jwt", "pydantic", "cli", "picologging", "dev-contrib", "piccolo", "prometheus", "dev", "mako", "test", "brotli", "cryptography", "linting", "attrs", "opentelemetry", "docs", "redis", "sqlalchemy", "full", "annotated-types", "jinja", "structlog", "minijinja"] strategy = ["cross_platform"] lock_version = "4.4.1" -content_hash = "sha256:8d1f58bfc4ff70c92de422ef99f3c005e9360a00147adf245ce9c5578cec162d" +content_hash = "sha256:63cac5a26843dd6138a7ba24d0ce45b2880331e4f6a054a95f0308c4fa8f9531" [[package]] name = "accessible-pygments" @@ -284,6 +284,20 @@ files = [ {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"}, ] +[[package]] +name = "backports-zoneinfo" +version = "0.2.1" +requires_python = ">=3.6" +summary = "Backport of the standard library zoneinfo module" +files = [ + {file = "backports.zoneinfo-0.2.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:8961c0f32cd0336fb8e8ead11a1f8cd99ec07145ec2931122faaac1c8f7fd987"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:e81b76cace8eda1fca50e345242ba977f9be6ae3945af8d46326d776b4cf78d1"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7b0a64cda4145548fed9efc10322770f929b944ce5cee6c0dfe0c87bf4c0c8c9"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-win32.whl", hash = "sha256:1b13e654a55cd45672cb54ed12148cd33628f672548f373963b0bff67b217328"}, + {file = "backports.zoneinfo-0.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4a0f800587060bf8880f954dbef70de6c11bbe59c673c3d818921f042f9954a6"}, + {file = "backports.zoneinfo-0.2.1.tar.gz", hash = "sha256:fadbfe37f74051d024037f223b8e001611eac868b5c5b06144ef4d8b799862f2"}, +] + [[package]] name = "beanie" version = "1.23.6" @@ -2043,6 +2057,182 @@ files = [ {file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"}, ] +[[package]] +name = "psycopg" +version = "3.1.16" +requires_python = ">=3.7" +summary = "PostgreSQL database adapter for Python" +dependencies = [ + "backports-zoneinfo>=0.2.0; python_version < \"3.9\"", + "typing-extensions>=4.1", + "tzdata; sys_platform == \"win32\"", +] +files = [ + {file = "psycopg-3.1.16-py3-none-any.whl", hash = "sha256:0bfe9741f4fb1c8115cadd8fe832fa91ac277e81e0652ff7fa1400f0ef0f59ba"}, + {file = "psycopg-3.1.16.tar.gz", hash = "sha256:a34d922fd7df3134595e71c3428ba6f1bd5f4968db74857fe95de12db2d6b763"}, +] + +[[package]] +name = "psycopg-binary" +version = "3.1.16" +requires_python = ">=3.7" +summary = "PostgreSQL database adapter for Python -- C optimisation distribution" +files = [ + {file = "psycopg_binary-3.1.16-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e08e333366f8583c7bee33ca6a27f84b76e05ee4e9f9f327a48e3ff81386261d"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a18dfcf7eb3db698eb7a38b4a0e82bf5b76a7bc0079068c5837df70b965570f8"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db99192d9f448829322c4f59a584994ce747b8d586ec65788b4c65f7166cfe43"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f6053fe95596e2f67ff2c9464ea23032c748695a3b79060ca01ef878b0ea0f2"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e6092ec21c08ed4ae4ff343c93a3bbb1d39c87dee181860ce40fa3b5c46f4ae"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f81e880d1bd935433efab1c2883a02031df84e739eadcb2c6a715e9c2f41c19"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:430f8843e381199cdc39ce9506a2cdbc27a569c99a0d80193844c787ce7de94d"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:92bda36f0570a5f9a3d6aeb897bad219f1f23fc4e1d0e7780935798771efb536"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b256d500ec0121ad7875bc3539c43c82dc004535d55256a13c49df2d43f07ad8"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:699737cecf675e1eb70b892b1995456db4016eff7189a3ad9325dca5b6715cc3"}, + {file = "psycopg_binary-3.1.16-cp310-cp310-win_amd64.whl", hash = "sha256:5e0885bcd7d9a0c0043be83d6a214069356c640d42496de798d901d0a16a34e7"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4ee8be32eb8b813ef37c5f5968fe03fdddc9a6f0129190f97f6491c798a1ef57"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8f8fb9677fb7873daf9797207e72e9275f61e769a308c4ea8f55dfd3153ebae7"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a611d7256493ee5bb73a070c9c60206af415be6aee01243c186fc03f1eb1a48"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d267cc92f0f0a9ea6c8ef058e95c85e58133d06c06f4ed48d63fc256aef166ab"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e996b38ffeffbaa06d236bbeab5168d33eea95941cf74de1daa0b008333861b1"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8429017cd7a3ef4699bee4ff8125a5e30b26882b817a178608d73e69fb727ab9"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a7d3b2ea267e7676b3693799fadf941c672f5727fae4947efa1f0cc6e25b672c"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d8290cfd475fadf935da0900dc91b845fe92f792e6d53039c0df82f9049a84ad"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:72539a0c6b9a2a9be2acca993df17f4baaa0ed00f1d76b65733725286e3e3304"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1078370a93eaef1dc5aed540055d50cbe37e9154342f3a3d73fd768a6199344d"}, + {file = "psycopg_binary-3.1.16-cp311-cp311-win_amd64.whl", hash = "sha256:adca24d273fe81ecab2312309db547b345155ec50d15676e2df82b8c5409eb06"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e1c416a7c2a699c3e5ba031357682ebca92bd58f399e553173ab5d67cc71cbc5"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e951a8cc7cf919fdc817a28d57160e7286011a4a45dcad3be21f3e4feba8be1a"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eaa02fe8aa9ef8c8743919fdbc92c04b0ee8c43f3d65e53f24d355776c52fb3"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e23375c14c22ce8fd26d057ac4ab827de79aafced173c68a4c0b03520ea02c70"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84472e5c83e805d4c491f331061cbae3ea4e62f80a480fc4b32200be72262ffd"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b0f824565d1dc325c74c076efd5ba842b86219f8bc1b8048c8816621a8b268c"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6b856d44531475488e773ac78d2a7a91c0909a1e8bdbd20d3ebdbdce1868c9a0"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:198c4f16f806f7d2ad0c4a5b774652e17861b55249efb4e344049b1fcf9a24af"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:b23d4b86acba2d745763ee0801821af1c42b127d8df75b903b7e7ca7c5f6400c"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2cfd857f1085c59da592090f2fa0751da30b67dcafea2ac52c4b404678406aae"}, + {file = "psycopg_binary-3.1.16-cp312-cp312-win_amd64.whl", hash = "sha256:46c9cca48d459d8df71fda4eef7d94a189b8333f4bc3cf1d170c1796fcbbc8cd"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2b22e2dad291a79d7a31b304866fd125038ef7fe378aba9698de0e1804a863c9"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d9e1768c46c595a8177cb709c99626c3cefbd12c2e46eb54323efd8ac4a7fc2d"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8eaabc8dd2d364e1b43d3a25188356191a45abb687b77016544f6847b3fcd73a"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cda744c43b09342b1a8b5aace13d3284c1f5ddbfcefa2d385f703337503a060"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cdaf56adc9cc56df7a05e8f097a776939ba49d5e6afc907ba7b404d8bd21c89"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7232116fc5d4e0274114f152bdb9df089895d4c70f7c03268cab0a4c48a28d04"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6f03239d7c18666f5d6ca82ea972235de4d4d3604287098af6cdc256b76a0ca5"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:edd1b630652bdfff84662b46d11878fbab8ab2966003c1876fcde56650e99e3f"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:481e9dafca1ed9532552e097105e6664ee7f14686270ed0ee0b1d6c78c2cdb11"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d43aa3aa55b5fa964ffa78cf6abdbd51ff33a759f290e9159a9f974ffa3178fa"}, + {file = "psycopg_binary-3.1.16-cp38-cp38-win_amd64.whl", hash = "sha256:51e66b282d8689bc33d81bde3a1e14d0c88a39200c2d9436b028b394d24f1f99"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfae154f3c88e67f3ed592765ad56531b6076acfe80796e28cccc05727c1cf5b"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f9f4bc3d366951359a68833c8031cc83faf5084b3bc80dd2d24f0add593d4418"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a37d682d7ff57cc2573b1011740ef1566749fc94ae6ac1456405510592735c0a"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0be876e3a8ee359f6a985b662c6b02a094a50b37adf1bd756a655004bddf167a"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f79192b0edd60ef24acb0af5b83319cbb65d4187576757b690646b290de8307"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcc5996b1db4e7fb948ea47b610456df317625d92474c779a20f92ca8cbcec92"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3f2ceb04f8137462f9312a324bea5402de0a4f0503cd5442f4264911e4b6265b"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:47517d2da63bb10c80c2cf35c80a936db79636534849524fd57940b5f0bbd7bd"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:2a6bd83d0b934aa03897e93acb6897972ccc3827ae61c903589bc92ed423f75d"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:08fb94928e785571ac90d3ab9e09f2721e0d895c2504ecfb8de91c5ea807b267"}, + {file = "psycopg_binary-3.1.16-cp39-cp39-win_amd64.whl", hash = "sha256:cf13807b61315130a59ea8d0950bda2ac875bae9fadc0b1a9aca9b4ef6d62c7b"}, +] + +[[package]] +name = "psycopg-pool" +version = "3.2.0" +requires_python = ">=3.8" +summary = "Connection Pool for Psycopg" +dependencies = [ + "typing-extensions>=3.10", +] +files = [ + {file = "psycopg-pool-3.2.0.tar.gz", hash = "sha256:2e857bb6c120d012dba240e30e5dff839d2d69daf3e962127ce6b8e40594170e"}, + {file = "psycopg_pool-3.2.0-py3-none-any.whl", hash = "sha256:73371d4e795d9363c7b496cbb2dfce94ee8fbf2dcdc384d0a937d1d9d8bdd08d"}, +] + +[[package]] +name = "psycopg2-binary" +version = "2.9.9" +requires_python = ">=3.7" +summary = "psycopg2 - Python-PostgreSQL Database Adapter" +files = [ + {file = "psycopg2-binary-2.9.9.tar.gz", hash = "sha256:7f01846810177d829c7692f1f5ada8096762d9172af1b1a28d4ab5b77c923c1c"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c2470da5418b76232f02a2fcd2229537bb2d5a7096674ce61859c3229f2eb202"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c6af2a6d4b7ee9615cbb162b0738f6e1fd1f5c3eda7e5da17861eacf4c717ea7"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75723c3c0fbbf34350b46a3199eb50638ab22a0228f93fb472ef4d9becc2382b"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83791a65b51ad6ee6cf0845634859d69a038ea9b03d7b26e703f94c7e93dbcf9"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ef4854e82c09e84cc63084a9e4ccd6d9b154f1dbdd283efb92ecd0b5e2b8c84"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed1184ab8f113e8d660ce49a56390ca181f2981066acc27cf637d5c1e10ce46e"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d2997c458c690ec2bc6b0b7ecbafd02b029b7b4283078d3b32a852a7ce3ddd98"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b58b4710c7f4161b5e9dcbe73bb7c62d65670a87df7bcce9e1faaad43e715245"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0c009475ee389757e6e34611d75f6e4f05f0cf5ebb76c6037508318e1a1e0d7e"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8dbf6d1bc73f1d04ec1734bae3b4fb0ee3cb2a493d35ede9badbeb901fb40f6f"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-win32.whl", hash = "sha256:3f78fd71c4f43a13d342be74ebbc0666fe1f555b8837eb113cb7416856c79682"}, + {file = "psycopg2_binary-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:876801744b0dee379e4e3c38b76fc89f88834bb15bf92ee07d94acd06ec890a0"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee825e70b1a209475622f7f7b776785bd68f34af6e7a46e2e42f27b659b5bc26"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1ea665f8ce695bcc37a90ee52de7a7980be5161375d42a0b6c6abedbf0d81f0f"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:143072318f793f53819048fdfe30c321890af0c3ec7cb1dfc9cc87aa88241de2"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c332c8d69fb64979ebf76613c66b985414927a40f8defa16cf1bc028b7b0a7b0"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7fc5a5acafb7d6ccca13bfa8c90f8c51f13d8fb87d95656d3950f0158d3ce53"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977646e05232579d2e7b9c59e21dbe5261f403a88417f6a6512e70d3f8a046be"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b6356793b84728d9d50ead16ab43c187673831e9d4019013f1402c41b1db9b27"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bc7bb56d04601d443f24094e9e31ae6deec9ccb23581f75343feebaf30423359"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:77853062a2c45be16fd6b8d6de2a99278ee1d985a7bd8b103e97e41c034006d2"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:78151aa3ec21dccd5cdef6c74c3e73386dcdfaf19bced944169697d7ac7482fc"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, + {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e6f98446430fdf41bd36d4faa6cb409f5140c1c2cf58ce0bbdaf16af7d3f119"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c77e3d1862452565875eb31bdb45ac62502feabbd53429fdc39a1cc341d681ba"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:60989127da422b74a04345096c10d416c2b41bd7bf2a380eb541059e4e999980"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:246b123cc54bb5361588acc54218c8c9fb73068bf227a4a531d8ed56fa3ca7d6"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34eccd14566f8fe14b2b95bb13b11572f7c7d5c36da61caf414d23b91fcc5d94"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18d0ef97766055fec15b5de2c06dd8e7654705ce3e5e5eed3b6651a1d2a9a152"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d3f82c171b4ccd83bbaf35aa05e44e690113bd4f3b7b6cc54d2219b132f3ae55"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ead20f7913a9c1e894aebe47cccf9dc834e1618b7aa96155d2091a626e59c972"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ca49a8119c6cbd77375ae303b0cfd8c11f011abbbd64601167ecca18a87e7cdd"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:323ba25b92454adb36fa425dc5cf6f8f19f78948cbad2e7bc6cdf7b0d7982e59"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:1236ed0952fbd919c100bc839eaa4a39ebc397ed1c08a97fc45fee2a595aa1b3"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:729177eaf0aefca0994ce4cffe96ad3c75e377c7b6f4efa59ebf003b6d398716"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-win32.whl", hash = "sha256:804d99b24ad523a1fe18cc707bf741670332f7c7412e9d49cb5eab67e886b9b5"}, + {file = "psycopg2_binary-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:a6cdcc3ede532f4a4b96000b6362099591ab4a3e913d70bcbac2b56c872446f7"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:72dffbd8b4194858d0941062a9766f8297e8868e1dd07a7b36212aaa90f49472"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:30dcc86377618a4c8f3b72418df92e77be4254d8f89f14b8e8f57d6d43603c0f"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31a34c508c003a4347d389a9e6fcc2307cc2150eb516462a7a17512130de109e"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15208be1c50b99203fe88d15695f22a5bed95ab3f84354c494bcb1d08557df67"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1873aade94b74715be2246321c8650cabf5a0d098a95bab81145ffffa4c13876"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a58c98a7e9c021f357348867f537017057c2ed7f77337fd914d0bedb35dace7"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4686818798f9194d03c9129a4d9a702d9e113a89cb03bffe08c6cf799e053291"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ebdc36bea43063116f0486869652cb2ed7032dbc59fbcb4445c4862b5c1ecf7f"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:ca08decd2697fdea0aea364b370b1249d47336aec935f87b8bbfd7da5b2ee9c1"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac05fb791acf5e1a3e39402641827780fe44d27e72567a000412c648a85ba860"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-win32.whl", hash = "sha256:9dba73be7305b399924709b91682299794887cbbd88e38226ed9f6712eabee90"}, + {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"}, +] + +[[package]] +name = "psycopg" +version = "3.1.16" +extras = ["binary", "pool"] +requires_python = ">=3.7" +summary = "PostgreSQL database adapter for Python" +dependencies = [ + "psycopg-binary==3.1.16; implementation_name != \"pypy\"", + "psycopg-pool", + "psycopg==3.1.16", +] +files = [ + {file = "psycopg-3.1.16-py3-none-any.whl", hash = "sha256:0bfe9741f4fb1c8115cadd8fe832fa91ac277e81e0652ff7fa1400f0ef0f59ba"}, + {file = "psycopg-3.1.16.tar.gz", hash = "sha256:a34d922fd7df3134595e71c3428ba6f1bd5f4968db74857fe95de12db2d6b763"}, +] + [[package]] name = "pyasn1" version = "0.5.1" @@ -3343,6 +3533,16 @@ files = [ {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, ] +[[package]] +name = "tzdata" +version = "2023.4" +requires_python = ">=2" +summary = "Provider of IANA time zone data" +files = [ + {file = "tzdata-2023.4-py2.py3-none-any.whl", hash = "sha256:aa3ace4329eeacda5b7beb7ea08ece826c28d761cda36e747cfbf97996d39bf3"}, + {file = "tzdata-2023.4.tar.gz", hash = "sha256:dd54c94f294765522c77399649b4fefd95522479a664a0cec87f41bebc6148c9"}, +] + [[package]] name = "urllib3" version = "2.1.0" diff --git a/pyproject.toml b/pyproject.toml index 1421cafefe..e6aabdc1c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ maintainers = [ name = "litestar" readme = "README.md" requires-python = ">=3.8,<4.0" -version = "2.4.5" +version = "2.5.0" [project.urls] Blog = "https://blog.litestar.dev" @@ -107,6 +107,9 @@ dev = [ "trio", "aiosqlite", "exceptiongroup; python_version < \"3.11\"", + "asyncpg>=0.29.0", + "psycopg[pool,binary]>=3.1.10", + "psycopg2-binary", ] dev-contrib = ["opentelemetry-sdk", "httpx-sse"] docs = [ diff --git a/tests/docker_service_fixtures.py b/tests/docker_service_fixtures.py index de71aace45..6efca36979 100644 --- a/tests/docker_service_fixtures.py +++ b/tests/docker_service_fixtures.py @@ -139,3 +139,8 @@ async def postgres_responsive(host: str) -> bool: return (await conn.fetchrow("SELECT 1"))[0] == 1 # type: ignore finally: await conn.close() + + +@pytest.fixture() +async def postgres_service(docker_services: DockerServiceRegistry) -> None: + await docker_services.start("postgres", check=postgres_responsive) diff --git a/tests/unit/test_channels/conftest.py b/tests/unit/test_channels/conftest.py index c95799143d..fdbaecbf28 100644 --- a/tests/unit/test_channels/conftest.py +++ b/tests/unit/test_channels/conftest.py @@ -3,7 +3,9 @@ import pytest from redis.asyncio import Redis as AsyncRedis +from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend from litestar.channels.backends.memory import MemoryChannelsBackend +from litestar.channels.backends.psycopg import PsycoPgChannelsBackend from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend @@ -26,7 +28,12 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item @pytest.fixture() def redis_stream_backend(redis_client: AsyncRedis) -> RedisChannelsStreamBackend: - return RedisChannelsStreamBackend(history=10, redis=redis_client, cap_streams_approximate=False) + return RedisChannelsStreamBackend(redis=redis_client, cap_streams_approximate=False, history=1) + + +@pytest.fixture() +def redis_stream_backend_with_history(redis_client: AsyncRedis) -> RedisChannelsStreamBackend: + return RedisChannelsStreamBackend(redis=redis_client, cap_streams_approximate=False, history=10) @pytest.fixture() @@ -36,4 +43,19 @@ def redis_pub_sub_backend(redis_client: AsyncRedis) -> RedisChannelsPubSubBacken @pytest.fixture() def memory_backend() -> MemoryChannelsBackend: + return MemoryChannelsBackend() + + +@pytest.fixture() +def memory_backend_with_history() -> MemoryChannelsBackend: return MemoryChannelsBackend(history=10) + + +@pytest.fixture() +def postgres_asyncpg_backend(postgres_service: None, docker_ip: str) -> AsyncPgChannelsBackend: + return AsyncPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423") + + +@pytest.fixture() +def postgres_psycopg_backend(postgres_service: None, docker_ip: str) -> PsycoPgChannelsBackend: + return PsycoPgChannelsBackend(f"postgres://postgres:super-secret@{docker_ip}:5423") diff --git a/tests/unit/test_channels/test_backends.py b/tests/unit/test_channels/test_backends.py index 17eb0baf8f..41f87dd396 100644 --- a/tests/unit/test_channels/test_backends.py +++ b/tests/unit/test_channels/test_backends.py @@ -3,14 +3,18 @@ import asyncio from datetime import timedelta from typing import AsyncGenerator, cast +from unittest.mock import AsyncMock, MagicMock import pytest from _pytest.fixtures import FixtureRequest from redis.asyncio.client import Redis from litestar.channels import ChannelsBackend +from litestar.channels.backends.asyncpg import AsyncPgChannelsBackend from litestar.channels.backends.memory import MemoryChannelsBackend -from litestar.channels.backends.redis import RedisChannelsPubSubBackend, RedisChannelsStreamBackend +from litestar.channels.backends.psycopg import PsycoPgChannelsBackend +from litestar.channels.backends.redis import RedisChannelsStreamBackend +from litestar.exceptions import ImproperlyConfiguredException from litestar.utils.compat import async_next @@ -18,6 +22,8 @@ params=[ pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")), pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")), + pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")), + pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")), pytest.param("memory_backend", id="memory"), ] ) @@ -25,6 +31,18 @@ def channels_backend_instance(request: FixtureRequest) -> ChannelsBackend: return cast(ChannelsBackend, request.getfixturevalue(request.param)) +@pytest.fixture( + params=[ + pytest.param( + "redis_stream_backend_with_history", id="redis:stream+history", marks=pytest.mark.xdist_group("redis") + ), + pytest.param("memory_backend_with_history", id="memory+history"), + ] +) +def channels_backend_instance_with_history(request: FixtureRequest) -> ChannelsBackend: + return cast(ChannelsBackend, request.getfixturevalue(request.param)) + + @pytest.fixture() async def channels_backend(channels_backend_instance: ChannelsBackend) -> AsyncGenerator[ChannelsBackend, None]: await channels_backend_instance.on_startup() @@ -32,6 +50,15 @@ async def channels_backend(channels_backend_instance: ChannelsBackend) -> AsyncG await channels_backend_instance.on_shutdown() +@pytest.fixture() +async def channels_backend_with_history( + channels_backend_instance_with_history: ChannelsBackend, +) -> AsyncGenerator[ChannelsBackend, None]: + await channels_backend_instance_with_history.on_startup() + yield channels_backend_instance_with_history + await channels_backend_instance_with_history.on_shutdown() + + @pytest.mark.parametrize("channels", [{"foo"}, {"foo", "bar"}]) async def test_pub_sub(channels_backend: ChannelsBackend, channels: set[str]) -> None: await channels_backend.subscribe(channels) @@ -44,6 +71,17 @@ async def test_pub_sub(channels_backend: ChannelsBackend, channels: set[str]) -> assert received == {(c, b"something") for c in channels} +async def test_pub_sub_unsubscribe(channels_backend: ChannelsBackend) -> None: + await channels_backend.subscribe(["foo", "bar"]) + await channels_backend.publish(b"something", ["foo"]) + + event_generator = channels_backend.stream_events() + await channels_backend.unsubscribe(["foo"]) + await channels_backend.publish(b"something", ["bar"]) + + assert await asyncio.wait_for(async_next(event_generator), timeout=0.01) == ("bar", b"something") + + async def test_pub_sub_no_subscriptions(channels_backend: ChannelsBackend) -> None: await channels_backend.publish(b"something", ["foo"]) @@ -54,7 +92,7 @@ async def test_pub_sub_no_subscriptions(channels_backend: ChannelsBackend) -> No @pytest.mark.flaky(reruns=5) # this should not really happen but just in case, we retry async def test_pub_sub_no_subscriptions_by_unsubscribes(channels_backend: ChannelsBackend) -> None: - await channels_backend.subscribe(["foo"]) + await channels_backend.subscribe(["foo", "bar"]) await channels_backend.publish(b"something", ["foo"]) event_generator = channels_backend.stream_events() @@ -80,30 +118,24 @@ async def test_unsubscribe_without_subscription(channels_backend: ChannelsBacken @pytest.mark.parametrize("history_limit,expected_history_length", [(None, 10), (1, 1), (5, 5), (10, 10)]) async def test_get_history( - channels_backend: ChannelsBackend, history_limit: int | None, expected_history_length: int + channels_backend_with_history: ChannelsBackend, history_limit: int | None, expected_history_length: int ) -> None: - if isinstance(channels_backend, RedisChannelsPubSubBackend): - pytest.skip("Redis pub/sub backend does not support history") - messages = [str(i).encode() for i in range(100)] for message in messages: - await channels_backend.publish(message, {"something"}) + await channels_backend_with_history.publish(message, {"something"}) - history = await channels_backend.get_history("something", history_limit) + history = await channels_backend_with_history.get_history("something", history_limit) expected_messages = messages[-expected_history_length:] assert len(history) == expected_history_length assert history == expected_messages -async def test_discards_history_entries(channels_backend: ChannelsBackend) -> None: - if isinstance(channels_backend, RedisChannelsPubSubBackend): - pytest.skip("Redis pub/sub backend does not support history") - +async def test_discards_history_entries(channels_backend_with_history: ChannelsBackend) -> None: for _ in range(20): - await channels_backend.publish(b"foo", {"bar"}) + await channels_backend_with_history.publish(b"foo", {"bar"}) - assert len(await channels_backend.get_history("bar")) == 10 + assert len(await channels_backend_with_history.get_history("bar")) == 10 @pytest.mark.xdist_group("redis") @@ -133,3 +165,56 @@ async def test_memory_publish_not_initialized_raises() -> None: with pytest.raises(RuntimeError): await backend.publish(b"foo", ["something"]) + + +@pytest.mark.xdist_group("postgres") +async def test_asyncpg_get_history(postgres_asyncpg_backend: AsyncPgChannelsBackend) -> None: + with pytest.raises(NotImplementedError): + await postgres_asyncpg_backend.get_history("something") + + +@pytest.mark.xdist_group("postgres") +async def test_psycopg_get_history(postgres_psycopg_backend: PsycoPgChannelsBackend) -> None: + with pytest.raises(NotImplementedError): + await postgres_psycopg_backend.get_history("something") + + +async def test_asyncpg_make_connection() -> None: + make_connection = AsyncMock() + + backend = AsyncPgChannelsBackend(make_connection=make_connection) + await backend.on_startup() + + make_connection.assert_awaited_once() + + +async def test_asyncpg_no_make_conn_or_dsn_passed_raises() -> None: + with pytest.raises(ImproperlyConfiguredException): + AsyncPgChannelsBackend() # type: ignore[call-overload] + + +def test_asyncpg_listener_raises_on_non_string_payload() -> None: + backend = AsyncPgChannelsBackend(make_connection=AsyncMock()) + with pytest.raises(RuntimeError): + backend._listener(connection=MagicMock(), pid=1, payload=b"abc", channel="foo") + + +async def test_asyncpg_backend_publish_before_startup_raises() -> None: + backend = AsyncPgChannelsBackend(make_connection=AsyncMock()) + + with pytest.raises(RuntimeError): + await backend.publish(b"foo", ["bar"]) + + +async def test_asyncpg_backend_stream_before_startup_raises() -> None: + backend = AsyncPgChannelsBackend(make_connection=AsyncMock()) + + with pytest.raises(RuntimeError): + await asyncio.wait_for(async_next(backend.stream_events()), timeout=0.01) + + +async def test_memory_backend_stream_before_startup_raises() -> None: + backend = MemoryChannelsBackend() + + with pytest.raises(RuntimeError): + await asyncio.wait_for(async_next(backend.stream_events()), timeout=0.01) diff --git a/tests/unit/test_channels/test_plugin.py b/tests/unit/test_channels/test_plugin.py index c6743e6102..7759c47555 100644 --- a/tests/unit/test_channels/test_plugin.py +++ b/tests/unit/test_channels/test_plugin.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import time from secrets import token_hex from typing import cast from unittest.mock import AsyncMock, MagicMock @@ -24,6 +25,8 @@ params=[ pytest.param("redis_pub_sub_backend", id="redis:pubsub", marks=pytest.mark.xdist_group("redis")), pytest.param("redis_stream_backend", id="redis:stream", marks=pytest.mark.xdist_group("redis")), + pytest.param("postgres_asyncpg_backend", id="postgres:asyncpg", marks=pytest.mark.xdist_group("postgres")), + pytest.param("postgres_psycopg_backend", id="postgres:psycopg", marks=pytest.mark.xdist_group("postgres")), pytest.param("memory_backend", id="memory"), ] ) @@ -31,6 +34,18 @@ def channels_backend(request: FixtureRequest) -> ChannelsBackend: return cast(ChannelsBackend, request.getfixturevalue(request.param)) +@pytest.fixture( + params=[ + pytest.param( + "redis_stream_backend_with_history", id="redis:stream+history", marks=pytest.mark.xdist_group("redis") + ), + pytest.param("memory_backend_with_history", id="memory+history"), + ] +) +def channels_backend_with_history(request: FixtureRequest) -> ChannelsBackend: + return cast(ChannelsBackend, request.getfixturevalue(request.param)) + + def test_channels_no_channels_arbitrary_not_allowed_raises(memory_backend: MemoryChannelsBackend) -> None: with pytest.raises(ImproperlyConfiguredException): ChannelsPlugin(backend=memory_backend) @@ -75,10 +90,11 @@ async def test_pub_sub_wait_published(channels_backend: ChannelsBackend) -> None @pytest.mark.flaky(reruns=10) -async def test_pub_sub_non_blocking(channels_backend: ChannelsBackend) -> None: +@pytest.mark.parametrize("channel", ["something", ["something"]]) +async def test_pub_sub_non_blocking(channels_backend: ChannelsBackend, channel: str | list[str]) -> None: async with ChannelsPlugin(backend=channels_backend, channels=["something"]) as plugin: - subscriber = await plugin.subscribe("something") - plugin.publish(b"foo", "something") + subscriber = await plugin.subscribe(channel) + plugin.publish(b"foo", channel) await asyncio.sleep(0.1) # give the worker time to process things @@ -119,7 +135,7 @@ def test_create_ws_route_handlers( @pytest.mark.flaky(reruns=5) -def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None: +async def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsBackend) -> None: """The websocket handlers await `WebSocket.receive()` to detect disconnection and stop the subscription. This test ensures that the subscription is only stopped in the case of receiving a `websocket.disconnect` message. @@ -140,7 +156,7 @@ def test_ws_route_handlers_receive_arbitrary_message(channels_backend: ChannelsB @pytest.mark.flaky(reruns=5) -async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None: +def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_backend: ChannelsBackend) -> None: channels_plugin = ChannelsPlugin( backend=channels_backend, arbitrary_channels_allowed=True, @@ -155,6 +171,8 @@ async def test_create_ws_route_handlers_arbitrary_channels_allowed(channels_back channels_plugin.publish("something", "foo") assert ws.receive_text(timeout=2) == "something" + time.sleep(0.1) + with client.websocket_connect("/ws/bar") as ws: channels_plugin.publish("something else", "bar") assert ws.receive_text(timeout=2) == "something else" @@ -217,13 +235,15 @@ async def test_start_subscription( @pytest.mark.parametrize("history", [1, 2]) @pytest.mark.parametrize("channels", [["foo"], ["foo", "bar"]]) async def test_subscribe_with_history( - async_mock: AsyncMock, memory_backend: MemoryChannelsBackend, channels: list[str], history: int + async_mock: AsyncMock, memory_backend_with_history: MemoryChannelsBackend, channels: list[str], history: int ) -> None: - async with ChannelsPlugin(backend=memory_backend, channels=channels) as plugin: + async with ChannelsPlugin(backend=memory_backend_with_history, channels=channels) as plugin: expected_messages = set() for channel in channels: - messages = await _populate_channels_backend(message_count=4, backend=memory_backend, channel=channel) + messages = await _populate_channels_backend( + message_count=4, backend=memory_backend_with_history, channel=channel + ) expected_messages.update(messages[-history:]) subscriber = await plugin.subscribe(channels, history=history) @@ -235,13 +255,15 @@ async def test_subscribe_with_history( @pytest.mark.parametrize("history", [1, 2]) @pytest.mark.parametrize("channels", [["foo"], ["foo", "bar"]]) async def test_start_subscription_with_history( - async_mock: AsyncMock, memory_backend: MemoryChannelsBackend, channels: list[str], history: int + async_mock: AsyncMock, memory_backend_with_history: MemoryChannelsBackend, channels: list[str], history: int ) -> None: - async with ChannelsPlugin(backend=memory_backend, channels=channels) as plugin: + async with ChannelsPlugin(backend=memory_backend_with_history, channels=channels) as plugin: expected_messages = set() for channel in channels: - messages = await _populate_channels_backend(message_count=4, backend=memory_backend, channel=channel) + messages = await _populate_channels_backend( + message_count=4, backend=memory_backend_with_history, channel=channel + ) expected_messages.update(messages[-history:]) async with plugin.start_subscription(channels, history=history) as subscriber: @@ -322,16 +344,15 @@ async def _populate_channels_backend(*, message_count: int, channel: str, backen ], ) async def test_handler_sends_history( - memory_backend: MemoryChannelsBackend, + memory_backend_with_history: MemoryChannelsBackend, message_count: int, handler_send_history: int, expected_history_count: int, mocker: MockerFixture, ) -> None: mock_socket_send = mocker.patch("litestar.connection.websocket.WebSocket.send_data") - memory_backend._max_history_length = 10 plugin = ChannelsPlugin( - backend=memory_backend, + backend=memory_backend_with_history, arbitrary_channels_allowed=True, ws_handler_send_history=handler_send_history, create_ws_route_handlers=True, @@ -339,8 +360,10 @@ async def test_handler_sends_history( app = Litestar([], plugins=[plugin]) with TestClient(app) as client: - await memory_backend.subscribe(["foo"]) - messages = await _populate_channels_backend(message_count=message_count, channel="foo", backend=memory_backend) + await memory_backend_with_history.subscribe(["foo"]) + messages = await _populate_channels_backend( + message_count=message_count, channel="foo", backend=memory_backend_with_history + ) with client.websocket_connect("/foo"): pass @@ -353,11 +376,11 @@ async def test_handler_sends_history( @pytest.mark.parametrize("channels,expected_entry_count", [("foo", 1), (["foo", "bar"], 2)]) async def test_set_subscriber_history( - channels: str | list[str], memory_backend: MemoryChannelsBackend, expected_entry_count: int + channels: str | list[str], memory_backend_with_history: MemoryChannelsBackend, expected_entry_count: int ) -> None: - async with ChannelsPlugin(backend=memory_backend, arbitrary_channels_allowed=True) as plugin: + async with ChannelsPlugin(backend=memory_backend_with_history, arbitrary_channels_allowed=True) as plugin: subscriber = await plugin.subscribe(channels) - await memory_backend.publish(b"something", channels if isinstance(channels, list) else [channels]) + await memory_backend_with_history.publish(b"something", channels if isinstance(channels, list) else [channels]) await plugin.put_subscriber_history(subscriber, channels) @@ -388,3 +411,13 @@ async def test_backlog( assert async_mock.call_count == 2 assert [call.args[0] for call in async_mock.call_args_list] == expected_messages + + +async def test_shutdown_idempotent(memory_backend: MemoryChannelsBackend) -> None: + # calling shutdown repeatedly or before startup shouldn't cause any issues + plugin = ChannelsPlugin(backend=memory_backend, arbitrary_channels_allowed=True) + await plugin._on_shutdown() + await plugin._on_startup() + + await plugin._on_shutdown() + await plugin._on_shutdown() diff --git a/tests/unit/test_channels/test_subscriber.py b/tests/unit/test_channels/test_subscriber.py index c9f3fbe188..2be3ebcb5d 100644 --- a/tests/unit/test_channels/test_subscriber.py +++ b/tests/unit/test_channels/test_subscriber.py @@ -64,6 +64,19 @@ async def test_stop(join: bool) -> None: assert subscriber._task is None +async def test_stop_with_task_done() -> None: + subscriber = Subscriber(AsyncMock()) + async with subscriber.run_in_background(AsyncMock()): + assert subscriber._task + assert subscriber.is_running + + subscriber.put_nowait(None) + + await subscriber.stop(join=True) + + assert subscriber._task is None + + @pytest.mark.parametrize("join", [False, True]) async def test_stop_no_task(join: bool) -> None: subscriber = Subscriber(AsyncMock()) @@ -96,3 +109,12 @@ async def test_backlog(backlog_strategy: BacklogStrategy) -> None: enqueued_items = await get_from_stream(subscriber, 2) assert expected_messages == enqueued_items + + +async def tests_run_in_background_run_in_background_called_while_running_raises() -> None: + subscriber = Subscriber(AsyncMock()) + + async with subscriber.run_in_background(AsyncMock()): + with pytest.raises(RuntimeError): + async with subscriber.run_in_background(AsyncMock()): + pass diff --git a/tests/unit/test_cli/__init__.py b/tests/unit/test_cli/__init__.py index 88482daf08..e1cf691155 100644 --- a/tests/unit/test_cli/__init__.py +++ b/tests/unit/test_cli/__init__.py @@ -76,3 +76,37 @@ def create_app() -> Litestar: return Litestar(route_handlers=[], plugins=[StartupPrintPlugin()]) """ +APP_FILE_CONTENT_ROUTES_EXAMPLE = """ +from litestar import Litestar, get +from litestar.openapi import OpenAPIConfig, OpenAPIController +from typing import Dict + +class CustomOpenAPIController(OpenAPIController): + path = "/api-docs" + + +@get("/") +def hello_world() -> Dict[str, str]: + return {"hello": "world"} + + +@get("/foo") +def foo() -> str: + return "bar" + + +@get("/schema/all/foo/bar/schema/") +def long_api() -> Dict[str, str]: + return {"test": "api"} + + + +app = Litestar( + openapi_config=OpenAPIConfig( + title="test_app", + version="0", + openapi_controller=CustomOpenAPIController), + route_handlers=[hello_world, foo, long_api] +) + +""" diff --git a/tests/unit/test_cli/test_core_commands.py b/tests/unit/test_cli/test_core_commands.py index 47ad066ba7..5c1deafa73 100644 --- a/tests/unit/test_cli/test_core_commands.py +++ b/tests/unit/test_cli/test_core_commands.py @@ -1,7 +1,8 @@ import os +import re import sys from pathlib import Path -from typing import Callable, Generator, List, Optional +from typing import Callable, Generator, List, Optional, Tuple from unittest.mock import MagicMock import pytest @@ -10,11 +11,13 @@ from pytest_mock import MockerFixture from litestar import __version__ as litestar_version +from litestar.cli._utils import remove_default_schema_routes, remove_routes_with_patterns from litestar.cli.main import litestar_group as cli_command from litestar.exceptions import LitestarWarning from . import ( APP_FACTORY_FILE_CONTENT_SERVER_LIFESPAN_PLUGIN, + APP_FILE_CONTENT_ROUTES_EXAMPLE, CREATE_APP_FILE_CONTENT, GENERIC_APP_FACTORY_FILE_CONTENT, GENERIC_APP_FACTORY_FILE_CONTENT_STRING_ANNOTATION, @@ -321,3 +324,99 @@ def test_run_command_with_server_lifespan_plugin( ssl_certfile=None, ssl_keyfile=None, ) + + +@pytest.mark.parametrize( + "app_content, schema_enabled, exclude_pattern_list", + [ + (APP_FILE_CONTENT_ROUTES_EXAMPLE, False, ()), + (APP_FILE_CONTENT_ROUTES_EXAMPLE, False, ("/foo", "/destroy/.*", "/java", "/haskell")), + (APP_FILE_CONTENT_ROUTES_EXAMPLE, True, ()), + (APP_FILE_CONTENT_ROUTES_EXAMPLE, True, ("/foo", "/destroy/.*", "/java", "/haskell")), + ], +) +@pytest.mark.xdist_group("cli_autodiscovery") +def test_routes_command_options( + runner: CliRunner, + app_content: str, + schema_enabled: bool, + exclude_pattern_list: Tuple[str, ...], + create_app_file: CreateAppFileFixture, +) -> None: + create_app_file("app.py", content=app_content) + + command = "routes" + if schema_enabled: + command += " --schema " + if exclude_pattern_list: + for pattern in exclude_pattern_list: + command += f" --exclude={pattern}" + + result = runner.invoke(cli_command, command) + assert result.exception is None + assert result.exit_code == 0 + + result_routes = re.findall(r"\/(?:\/|[a-z]|\.|-|[0-9])* \(HTTP\)", result.output) + for route in result_routes: + route_words = route.split(" ") + root_dir = route_words[0] + if not schema_enabled: + assert root_dir != "/api-docs" + + assert root_dir not in exclude_pattern_list + result_routes_len = len(result_routes) + if schema_enabled and exclude_pattern_list: + assert result_routes_len == 11 + elif schema_enabled and not exclude_pattern_list: + assert result_routes_len == 12 + elif not schema_enabled and exclude_pattern_list: + assert result_routes_len == 2 + elif not schema_enabled and not exclude_pattern_list: + assert result_routes_len == 3 + + +def test_remove_default_schema_routes() -> None: + routes = [ + "/", + "/schema", + "/schema/elements", + "/schema/oauth2-redirect.html", + "/schema/openapi.json", + "/schema/openapi.yaml", + "/schema/openapi.yml", + "/schema/rapidoc", + "/schema/redoc", + "/schema/swagger", + "/destroy/all/foo/bar/schema", + "/foo", + ] + http_routes = [] + for route in routes: + http_route = MagicMock() + http_route.path = route + http_routes.append(http_route) + + api_config = MagicMock() + api_config.openapi_controller.path = "/schema" + + results = remove_default_schema_routes(http_routes, api_config) # type: ignore + assert len(results) == 3 + for result in results: + words = re.split(r"(^\/[a-z]+)", result.path) + assert "/schema" not in words + + +def test_remove_routes_with_patterns() -> None: + routes = ["/", "/destroy/all/foo/bar/schema", "/foo"] + http_routes = [] + for route in routes: + http_route = MagicMock() + http_route.path = route + http_routes.append(http_route) + + patterns = ("/destroy", "/pizza", "[]") + results = remove_routes_with_patterns(http_routes, patterns) # type: ignore + paths = [route.path for route in results] + assert len(paths) == 2 + for route in ["/", "/foo"]: + assert route in paths diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py new file mode 100644 index 0000000000..040051ee2e --- /dev/null +++ b/tests/unit/test_concurrency.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import AsyncMock + +import trio +from pytest_mock import MockerFixture + +from litestar.concurrency import ( + get_asyncio_executor, + get_trio_capacity_limiter, + set_asyncio_executor, + set_trio_capacity_limiter, + sync_to_thread, +) + + +def func() -> int: + return 1 + + +def test_sync_to_thread_asyncio() -> None: + loop = asyncio.new_event_loop() + assert loop.run_until_complete(sync_to_thread(func)) == 1 + + +def test_sync_to_thread_trio() -> None: + assert trio.run(sync_to_thread, func) == 1 + + +def test_get_set_asyncio_executor() -> None: + assert get_asyncio_executor() is None + executor = ThreadPoolExecutor() + set_asyncio_executor(executor) + assert get_asyncio_executor() is executor + + +def test_get_set_trio_capacity_limiter() -> None: + limiter = trio.CapacityLimiter(10) + assert get_trio_capacity_limiter() is None + set_trio_capacity_limiter(limiter) + assert get_trio_capacity_limiter() is limiter + + +def test_asyncio_uses_executor(mocker: MockerFixture) -> None: + executor = ThreadPoolExecutor() + + mocker.patch("litestar.concurrency.get_asyncio_executor", return_value=executor) + mock_run_in_executor = AsyncMock() + mocker.patch("litestar.concurrency.asyncio.get_running_loop").return_value.run_in_executor = mock_run_in_executor + + loop = asyncio.new_event_loop() + loop.run_until_complete(sync_to_thread(func)) + + assert mock_run_in_executor.call_args_list[0].args[0] is executor + + +def test_trio_uses_limiter(mocker: MockerFixture) -> None: + limiter = trio.CapacityLimiter(10) + mocker.patch("litestar.concurrency.get_trio_capacity_limiter", return_value=limiter) + mock_run_sync = mocker.patch("trio.to_thread.run_sync", new_callable=AsyncMock) + + trio.run(sync_to_thread, func) + + assert mock_run_sync.call_args_list[0].kwargs["limiter"] is limiter diff --git a/tests/unit/test_middleware/test_middleware_handling.py b/tests/unit/test_middleware/test_middleware_handling.py index 98d6edc154..3f19ba837e 100644 --- a/tests/unit/test_middleware/test_middleware_handling.py +++ b/tests/unit/test_middleware/test_middleware_handling.py @@ -75,7 +75,7 @@ def handler() -> None: cur = client.app.asgi_router.root_route_map_node.children["/"].asgi_handlers["GET"][0] while hasattr(cur, "app"): unpacked_middleware.append(cur) - cur = cast("ASGIApp", cur.app) + cur = cast("ASGIApp", cur.app) # pyright: ignore unpacked_middleware.append(cur) assert len(unpacked_middleware) == 4 diff --git a/tests/unit/test_openapi/test_security_schemes.py b/tests/unit/test_openapi/test_security_schemes.py index 6bdf62db57..f543494371 100644 --- a/tests/unit/test_openapi/test_security_schemes.py +++ b/tests/unit/test_openapi/test_security_schemes.py @@ -1,10 +1,10 @@ -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any import pytest from litestar import Controller, Litestar, Router, get from litestar.openapi.config import OpenAPIConfig -from litestar.openapi.spec import Components, SecurityRequirement +from litestar.openapi.spec import Components from litestar.openapi.spec.security_scheme import SecurityScheme if TYPE_CHECKING: @@ -99,7 +99,7 @@ def test_schema_with_route_security_overridden(protected_route: "HTTPRouteHandle def test_layered_security_declaration() -> None: class MyController(Controller): path = "/controller" - security: List[SecurityRequirement] = [{"controllerToken": []}] # pyright: ignore + security = [{"controllerToken": []}] # pyright: ignore @get("", security=[{"handlerToken": []}]) def my_handler(self) -> None: