From 729000b752b9a8166eee3544a30fb8a8ad100987 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Kali=C5=84ski?= <47140412+KaQuMiQ@users.noreply.github.com> Date: Tue, 30 Apr 2024 12:38:46 +0200 Subject: [PATCH] Add tool direct result --- src/draive/__init__.py | 2 - src/draive/helpers/mimic.py | 81 ++++++++++------ src/draive/metrics/function.py | 34 ++++--- src/draive/mistral/chat_response.py | 16 ++-- src/draive/mistral/chat_tools.py | 137 +++++++++++++++++++--------- src/draive/openai/chat_response.py | 16 ++-- src/draive/openai/chat_stream.py | 22 +++-- src/draive/openai/chat_tools.py | 129 ++++++++++++++++++-------- src/draive/scope/access.py | 23 ++++- src/draive/tools/tool.py | 70 ++++++++++++-- src/draive/tools/toolbox.py | 11 ++- src/draive/utils/trace.py | 6 +- tests/test_cache.py | 58 +++++++----- 13 files changed, 425 insertions(+), 180 deletions(-) diff --git a/src/draive/__init__.py b/src/draive/__init__.py index b9bb042..ac7efb2 100644 --- a/src/draive/__init__.py +++ b/src/draive/__init__.py @@ -130,8 +130,6 @@ "AsyncStream", "AsyncStreamTask", "agent", - "agent", - "Agent", "Agent", "AgentException", "AgentFlow", diff --git a/src/draive/helpers/mimic.py b/src/draive/helpers/mimic.py index f977c00..14a5e0a 100644 --- a/src/draive/helpers/mimic.py +++ b/src/draive/helpers/mimic.py @@ -1,44 +1,67 @@ from collections.abc import Callable -from typing import Any, cast +from typing import Any, cast, overload __all__ = [ "mimic_function", ] +@overload def mimic_function[**Args, Result]( function: Callable[Args, Result], /, within: Callable[..., Any], -) -> Callable[Args, Result]: - # mimic function attributes if able - for attribute in [ - "__module__", - "__name__", - "__qualname__", - "__annotations__", - "__defaults__", - "__kwdefaults__", - "__doc__", - ]: - try: - setattr( - within, - attribute, - getattr( - function, +) -> Callable[Args, Result]: ... + + +@overload +def mimic_function[**Args, Result]( + function: Callable[Args, Result], + /, +) -> Callable[[Callable[..., Any]], Callable[Args, Result]]: ... + + +def mimic_function[**Args, Result]( + function: Callable[Args, Result], + /, + within: Callable[..., Result] | None = None, +) -> Callable[[Callable[..., Result]], Callable[Args, Result]] | Callable[Args, Result]: + def mimic( + target: Callable[..., Result], + ) -> Callable[Args, Result]: + # mimic function attributes if able + for attribute in [ + "__module__", + "__name__", + "__qualname__", + "__annotations__", + "__defaults__", + "__kwdefaults__", + "__doc__", + ]: + try: + setattr( + target, attribute, - ), - ) + getattr( + function, + attribute, + ), + ) + except AttributeError: + pass + try: + target.__dict__.update(function.__dict__) except AttributeError: pass - try: - within.__dict__.update(function.__dict__) - except AttributeError: - pass - - return cast( - Callable[Args, Result], - within, - ) + + return cast( + Callable[Args, Result], + target, + ) + + if target := within: + return mimic(target) + else: + return mimic diff --git a/src/draive/metrics/function.py b/src/draive/metrics/function.py index 56cd55d..16d82c6 100644 --- a/src/draive/metrics/function.py +++ b/src/draive/metrics/function.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import Any, Self from draive.helpers import MISSING, Missing @@ -63,19 +64,26 @@ def __add__(self, other: Self) -> Self: # the code below does not keep proper exception semantics of BaseException/Exception # however we are using it only for logging purposes at the moment # because of that merging exceptions in groups is simplified under BaseExceptionGroup + exceptions: Sequence[BaseException] + exception_messages: Sequence[str] if isinstance(self.exception, BaseExceptionGroup): - return self.__class__( - name="ExceptionGroup", - exception=BaseExceptionGroup( - "Multiple errors", - (*self.exception.exceptions, other.exception), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] - ), - ) + exceptions = [] + exception_messages = [] + for exception in (*self.exception.exceptions, other.exception): # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + exception_messages.append(exception) # pyright: ignore[reportArgumentType] + exceptions.append(f"{exception.__qualname__}:{exception}") # pyright: ignore[reportArgumentType, reportUnknownMemberType] + else: - return self.__class__( - name="ExceptionGroup", - exception=BaseExceptionGroup( - "Multiple errors", - (self.exception, other.exception), - ), + exceptions = (self.exception, other.exception) + exception_messages = ( + f"{self.exception.__qualname__}:{self.exception}", + f"{other.exception.__qualname__}:{other.exception}", ) + + return self.__class__( + name="ExceptionGroup", + exception=BaseExceptionGroup( + f"Multiple errors: [{','.join(exception_messages)}]", + exceptions, + ), + ) diff --git a/src/draive/mistral/chat_response.py b/src/draive/mistral/chat_response.py index e9622ab..b0a9eaa 100644 --- a/src/draive/mistral/chat_response.py +++ b/src/draive/mistral/chat_response.py @@ -66,14 +66,18 @@ async def _chat_response( completion_message: ChatMessageResponse = completion.choices[0].message if (tool_calls := completion_message.tool_calls) and (tools := tools): - messages.extend( - await _execute_chat_tool_calls( - tool_calls=tool_calls, - tools=tools, - ) - ) ctx.record(ResultTrace.of(tool_calls)) + tools_result: list[ChatMessage] | str = await _execute_chat_tool_calls( + tool_calls=tool_calls, + tools=tools, + ) + + if isinstance(tools_result, str): + return tools_result + else: + messages.extend(tools_result) + elif message := completion_message.content: ctx.record(ResultTrace.of(message)) match message: diff --git a/src/draive/mistral/chat_tools.py b/src/draive/mistral/chat_tools.py index 3b085a5..5dcd044 100644 --- a/src/draive/mistral/chat_tools.py +++ b/src/draive/mistral/chat_tools.py @@ -1,7 +1,7 @@ import json from asyncio import gather from collections.abc import Awaitable -from typing import Any +from typing import Any, Literal, overload from draive.mistral.models import ChatMessage, ToolCallResponse from draive.tools import Toolbox @@ -15,65 +15,118 @@ async def _execute_chat_tool_calls( *, tool_calls: list[ToolCallResponse], tools: Toolbox, -) -> list[ChatMessage]: +) -> list[ChatMessage] | str: + direct_result: Awaitable[str] | None = None tool_call_results: list[Awaitable[ChatMessage]] = [] for call in tool_calls: - tool_call_results.append( - _execute_chat_tool_call( + # use only the first "direct result tool" requested, can't return more than one anyways + # despite of that all tools will be called to ensure that all desired actions were executed + if direct_result is None and tools.requires_direct_result(tool_name=call.function.name): + direct_result = _execute_chat_tool_call( call_id=call.id, name=call.function.name, arguments=call.function.arguments, tools=tools, + message_result=False, ) - ) - return [ - ChatMessage( - role="assistant", - content="", - tool_calls=[ - { - "id": call.id, - "type": "function", - "function": { - "name": call.function.name, - "arguments": call.function.arguments - if isinstance(call.function.arguments, str) - else json.dumps(call.function.arguments), - }, - } - for call in tool_calls - ], - ), - *await gather( + else: + tool_call_results.append( + _execute_chat_tool_call( + call_id=call.id, + name=call.function.name, + arguments=call.function.arguments, + tools=tools, + message_result=True, + ) + ) + if direct_result is not None: + results: tuple[str, ...] = await gather( + direct_result, *tool_call_results, return_exceptions=False, - ), - ] + ) + return results[0] # return only the requested direct result + else: + return [ + ChatMessage( + role="assistant", + content="", + tool_calls=[ + { + "id": call.id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": call.function.arguments + if isinstance(call.function.arguments, str) + else json.dumps(call.function.arguments), + }, + } + for call in tool_calls + ], + ), + *await gather( + *tool_call_results, + return_exceptions=False, + ), + ] +@overload async def _execute_chat_tool_call( *, call_id: str, name: str, arguments: dict[str, Any] | str, tools: Toolbox, -) -> ChatMessage: + message_result: Literal[True], +) -> ChatMessage: ... + + +@overload +async def _execute_chat_tool_call( + *, + call_id: str, + name: str, + arguments: dict[str, Any] | str, + tools: Toolbox, + message_result: Literal[False], +) -> str: ... + + +async def _execute_chat_tool_call( + *, + call_id: str, + name: str, + arguments: dict[str, Any] | str, + tools: Toolbox, + message_result: bool, +) -> ChatMessage | str: try: # make sure that tool error won't blow up whole chain - result = await tools.call_tool( - name, - call_id=call_id, - arguments=arguments, - ) - return ChatMessage( - role="tool", - name=name, - content=str(result), + result: str = str( + await tools.call_tool( + name, + call_id=call_id, + arguments=arguments, + ) ) + if message_result: + return ChatMessage( + role="tool", + name=name, + content=str(result), + ) + else: + return result # error should be already logged by ScopeContext - except BaseException: - return ChatMessage( - role="tool", - name=name, - content="Error", - ) + except BaseException as exc: + if message_result: + return ChatMessage( + role="tool", + name=name, + content="Error", + ) + + else: # TODO: think about allowing the error chat message + raise exc diff --git a/src/draive/openai/chat_response.py b/src/draive/openai/chat_response.py index 8e484f0..d8ebda2 100644 --- a/src/draive/openai/chat_response.py +++ b/src/draive/openai/chat_response.py @@ -69,14 +69,18 @@ async def _chat_response( completion_message: ChatCompletionMessage = completion.choices[0].message if (tool_calls := completion_message.tool_calls) and (tools := tools): - messages.extend( - await _execute_chat_tool_calls( - tool_calls=tool_calls, - tools=tools, - ) - ) ctx.record(ResultTrace.of(tool_calls)) + tools_result: list[ChatCompletionMessageParam] | str = await _execute_chat_tool_calls( + tool_calls=tool_calls, + tools=tools, + ) + + if isinstance(tools_result, str): + return tools_result + else: + messages.extend(tools_result) + elif message := completion_message.content: ctx.record(ResultTrace.of(message)) return message diff --git a/src/draive/openai/chat_stream.py b/src/draive/openai/chat_stream.py index fd7262e..3de9449 100644 --- a/src/draive/openai/chat_stream.py +++ b/src/draive/openai/chat_stream.py @@ -26,7 +26,7 @@ ] -async def _chat_stream( # noqa: PLR0913 +async def _chat_stream( # noqa: PLR0913, C901 *, client: OpenAIClient, config: OpenAIChatConfig, @@ -82,13 +82,21 @@ async def _chat_stream( # noqa: PLR0913 tool_calls=completion_head.tool_calls, completion_stream=completion_stream, ) - messages.extend( - await _execute_chat_tool_calls( - tool_calls=tool_calls, - tools=tools, - ) - ) ctx.record(ResultTrace.of(tool_calls)) + + tools_result: ( + list[ChatCompletionMessageParam] | str + ) = await _execute_chat_tool_calls( + tool_calls=tool_calls, + tools=tools, + ) + + if isinstance(tools_result, str): + send_update(tools_result) + return tools_result + else: + messages.extend(tools_result) + break # after processing tool calls continue with recursion in outer context elif completion_head.content is not None: diff --git a/src/draive/openai/chat_tools.py b/src/draive/openai/chat_tools.py index 0ce446f..22691d8 100644 --- a/src/draive/openai/chat_tools.py +++ b/src/draive/openai/chat_tools.py @@ -1,6 +1,6 @@ from asyncio import gather from collections.abc import Awaitable -from typing import cast +from typing import Literal, cast, overload from openai import AsyncStream from openai.types.chat import ( @@ -25,39 +25,82 @@ async def _execute_chat_tool_calls( *, tool_calls: list[ChatCompletionMessageToolCall], tools: Toolbox, -) -> list[ChatCompletionMessageParam]: +) -> list[ChatCompletionMessageParam] | str: + direct_result: Awaitable[str] | None = None tool_call_params: list[ChatCompletionMessageToolCallParam] = [] tool_call_results: list[Awaitable[ChatCompletionMessageParam]] = [] for call in tool_calls: - tool_call_results.append( - _execute_chat_tool_call( + # use only the first "direct result tool" requested, can't return more than one anyways + # despite of that all tools will be called to ensure that all desired actions were executed + if direct_result is None and tools.requires_direct_result(tool_name=call.function.name): + direct_result = _execute_chat_tool_call( call_id=call.id, name=call.function.name, arguments=call.function.arguments, tools=tools, + message_result=False, ) - ) - tool_call_params.append( - { - "id": call.id, - "type": "function", - "function": { - "name": call.function.name, - "arguments": call.function.arguments, + else: + tool_call_results.append( + _execute_chat_tool_call( + call_id=call.id, + name=call.function.name, + arguments=call.function.arguments, + tools=tools, + message_result=True, + ), + ) + tool_call_params.append( + { + "id": call.id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": call.function.arguments, + }, }, - } - ) + ) - return [ - { - "role": "assistant", - "tool_calls": tool_call_params, - }, - *await gather( + if direct_result is not None: + results: tuple[str, ...] = await gather( + direct_result, *tool_call_results, return_exceptions=False, - ), - ] + ) + return results[0] # return only the requested direct result + else: + return [ + { + "role": "assistant", + "tool_calls": tool_call_params, + }, + *await gather( + *tool_call_results, + return_exceptions=False, + ), + ] + + +@overload +async def _execute_chat_tool_call( + *, + call_id: str, + name: str, + arguments: str, + tools: Toolbox, + message_result: Literal[True], +) -> ChatCompletionMessageParam: ... + + +@overload +async def _execute_chat_tool_call( + *, + call_id: str, + name: str, + arguments: str, + tools: Toolbox, + message_result: Literal[False], +) -> str: ... async def _execute_chat_tool_call( @@ -66,26 +109,36 @@ async def _execute_chat_tool_call( name: str, arguments: str, tools: Toolbox, -) -> ChatCompletionMessageParam: + message_result: bool, +) -> ChatCompletionMessageParam | str: try: # make sure that tool error won't blow up whole chain - result = await tools.call_tool( - name, - call_id=call_id, - arguments=arguments, + result: str = str( + await tools.call_tool( + name, + call_id=call_id, + arguments=arguments, + ) ) - return { - "role": "tool", - "tool_call_id": call_id, - "content": str(result), - } + if message_result: + return { + "role": "tool", + "tool_call_id": call_id, + "content": str(result), + } + else: + return result # error should be already logged by ScopeContext - except Exception: - return { - "role": "tool", - "tool_call_id": call_id, - "content": "Error", - } + except Exception as exc: + if message_result: + return { + "role": "tool", + "tool_call_id": call_id, + "content": "Error", + } + + else: # TODO: think about allowing the error chat message + raise exc async def _flush_chat_tool_calls( # noqa: C901, PLR0912 diff --git a/src/draive/scope/access.py b/src/draive/scope/access.py index 3819900..4bd61a3 100644 --- a/src/draive/scope/access.py +++ b/src/draive/scope/access.py @@ -261,6 +261,7 @@ def wrapper( function: Callable[Args, Coroutine[None, None, Result]], /, ) -> Callable[Args, Coroutine[None, None, Result]]: + @mimic_function(function) async def wrapped(*args: Args.args, **kwargs: Args.kwargs) -> Result: async with ctx.new( label, @@ -272,7 +273,27 @@ async def wrapped(*args: Args.args, **kwargs: Args.kwargs) -> Result: ): return await function(*args, **kwargs) - return mimic_function(function, within=wrapped) + return wrapped + + return wrapper + + @staticmethod + def update[**Args, Result]( + *state: ParametrizedData, + ) -> Callable[ + [Callable[Args, Coroutine[None, None, Result]]], + Callable[Args, Coroutine[None, None, Result]], + ]: + def wrapper( + function: Callable[Args, Coroutine[None, None, Result]], + /, + ) -> Callable[Args, Coroutine[None, None, Result]]: + @mimic_function(function) + async def wrapped(*args: Args.args, **kwargs: Args.kwargs) -> Result: + with ctx.updated(*state): + return await function(*args, **kwargs) + + return wrapped return wrapper diff --git a/src/draive/tools/tool.py b/src/draive/tools/tool.py index d16efc3..a1b9e7a 100644 --- a/src/draive/tools/tool.py +++ b/src/draive/tools/tool.py @@ -27,7 +27,7 @@ def __call__(self) -> bool: ... @final class Tool[**Args, Result](ParametrizedTool[Args, Coroutine[None, None, Result]]): - def __init__( + def __init__( # noqa: PLR0913 self, /, name: str, @@ -35,12 +35,14 @@ def __init__( function: Function[Args, Coroutine[None, None, Result]], description: str | None = None, availability: ToolAvailability | None = None, + require_direct_result: bool = False, ) -> None: super().__init__( name=name, function=function, description=description, ) + self._require_direct_result: bool = require_direct_result self._availability: ToolAvailability = availability or ( lambda: True # available by default ) @@ -51,14 +53,18 @@ def __init__( def available(self) -> bool: return self._availability() + @property + def requires_direct_result(self) -> bool: + return self._require_direct_result + async def __call__( self, - tool_call_id: str | None = None, + call_id: str | None = None, *args: Args.args, **kwargs: Args.kwargs, ) -> Result: call_context: ToolCallContext = ToolCallContext( - call_id=tool_call_id or uuid4().hex, + call_id=call_id or uuid4().hex, tool=self.name, ) send_update: Callable[[ToolCallUpdate], None] = ctx.state( @@ -119,7 +125,23 @@ async def __call__( def tool[**Args, Result]( function: Function[Args, Coroutine[None, None, Result]], /, -) -> Tool[Args, Result]: ... +) -> Tool[Args, Result]: + """ + Convert a function to a tool using default parameters and no description. + + In order to adjust the arguments behavior and specification use an instance of Argument + as a default value of any given argument with desired configuration + for each argument individually. + + Parameters + ---------- + function: Function[Args, Coroutine[None, None, Result]] + a function to be wrapped as a Tool. + Returns + ------- + Tool[Args, Result] + a Tool representation of the provided function. + """ @overload @@ -128,7 +150,39 @@ def tool[**Args, Result]( name: str | None = None, description: str | None = None, availability: ToolAvailability | None = None, -) -> Callable[[Function[Args, Coroutine[None, None, Result]]], Tool[Args, Result]]: ... + direct_result: bool = False, +) -> Callable[[Function[Args, Coroutine[None, None, Result]]], Tool[Args, Result]]: + """ + Convert a function to a tool using provided parameters. + + In order to adjust the arguments behavior and specification use an instance of Argument + as a default value of any given argument with desired configuration + for each argument individually. + + Parameters + ---------- + name: str + name to be used in a tool specification. + Default is the name of the wrapped function. + description: int + description to be used in a tool specification. Allows to present the tool behavior to the + external system. + Default is empty. + availability: ToolAvailability + function used to verify availability of the tool in given context. It can be used to check + permissions or occurrence of a specific state to allow its usage. + Default is always available. + direct_result: bool + controls if tool result should break the ongoing processing and be the direct result of it. + Note that during concurrent execution of multiple tools the call/result order defines + direct result and exact behavior is not defined. + Default is False. + + Returns + ------- + Callable[[Function[Args, Coroutine[None, None, Result]]], Tool[Args, Result]] + function allowing to convert other function to a Tool using provided configuration. + """ def tool[**Args, Result]( @@ -137,14 +191,11 @@ def tool[**Args, Result]( name: str | None = None, description: str | None = None, availability: ToolAvailability | None = None, + direct_result: bool = False, ) -> ( Callable[[Function[Args, Coroutine[None, None, Result]]], Tool[Args, Result]] | Tool[Args, Result] ): - """ - Convert a function to a tool. Tool arguments support only limited types. - """ - def wrap( function: Function[Args, Coroutine[None, None, Result]], ) -> Tool[Args, Result]: @@ -153,6 +204,7 @@ def wrap( description=description, function=function, availability=availability, + require_direct_result=direct_result, ) if function := function: diff --git a/src/draive/tools/toolbox.py b/src/draive/tools/toolbox.py index 0ac9f49..6fe08c1 100644 --- a/src/draive/tools/toolbox.py +++ b/src/draive/tools/toolbox.py @@ -42,6 +42,15 @@ def suggested_tool(self) -> ToolSpecification | None: def available_tools(self) -> list[ToolSpecification]: return [tool.specification for tool in self._tools.values() if tool.available] + def requires_direct_result( + self, + tool_name: str, + ) -> bool: + if tool := self._tools.get(tool_name): + return tool.requires_direct_result + else: + return False + async def call_tool( self, name: str, @@ -49,7 +58,7 @@ async def call_tool( call_id: str, arguments: dict[str, Any] | str | bytes | None, ) -> Any: - if tool := self._tools[name]: + if tool := self._tools.get(name): return await tool( tool_call_id=call_id, **loads(arguments) if isinstance(arguments, str | bytes) else arguments or {}, diff --git a/src/draive/utils/trace.py b/src/draive/utils/trace.py index 4c65a17..304141b 100644 --- a/src/draive/utils/trace.py +++ b/src/draive/utils/trace.py @@ -30,6 +30,7 @@ def _traced_sync[**Args, Result]( ) -> Callable[Args, Result]: label: str = function.__name__ + @mimic_function(function) def wrapped( *args: Args.args, **kwargs: Args.kwargs, @@ -42,7 +43,7 @@ def wrapped( ctx.record(ResultTrace.of(result)) return result - return mimic_function(function, within=wrapped) + return wrapped def _traced_async[**Args, Result]( @@ -51,6 +52,7 @@ def _traced_async[**Args, Result]( ) -> Callable[Args, Coroutine[Any, Any, Result]]: label: str = function.__name__ + @mimic_function(function) async def wrapped( *args: Args.args, **kwargs: Args.kwargs, @@ -63,4 +65,4 @@ async def wrapped( ctx.record(ResultTrace.of(result)) return result - return mimic_function(function, within=wrapped) + return wrapped diff --git a/tests/test_cache.py b/tests/test_cache.py index 687039b..37299ef 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,47 +1,55 @@ from asyncio import CancelledError, Task, sleep -from random import randint +from collections.abc import Callable, Generator from time import sleep as sync_sleep from draive import cache -from pytest import mark, raises +from pytest import fixture, mark, raises class FakeException(Exception): pass -def test_returns_cached_value_with_same_argument(): +@fixture +def fake_random() -> Callable[[], Generator[int, None, None]]: + def random_next() -> Generator[int, None, None]: + yield from range(0, 65536) + + return random_next + + +def test_returns_cached_value_with_same_argument(fake_random: Callable[[], int]): @cache def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = randomized("expected") assert randomized("expected") == expected -def test_returns_fresh_value_with_different_argument(): +def test_returns_fresh_value_with_different_argument(fake_random: Callable[[], int]): @cache def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = randomized("expected") assert randomized("checked") != expected -def test_returns_fresh_value_with_limit_exceed(): +def test_returns_fresh_value_with_limit_exceed(fake_random: Callable[[], int]): @cache(limit=1) def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = randomized("expected") randomized("different") assert randomized("expected") != expected -def test_returns_same_value_with_repeating_argument(): +def test_returns_same_value_with_repeating_argument(fake_random: Callable[[], int]): @cache(limit=2) def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = randomized("expected") randomized("different") @@ -61,10 +69,10 @@ def randomized(_: str, /) -> int: randomized("expected") -def test_returns_fresh_value_with_expiration_time_exceed(): +def test_returns_fresh_value_with_expiration_time_exceed(fake_random: Callable[[], int]): @cache(expiration=0.02) def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = randomized("expected") sync_sleep(0.02) @@ -72,30 +80,30 @@ def randomized(_: str, /) -> int: @mark.asyncio -async def test_async_returns_cached_value_with_same_argument(): +async def test_async_returns_cached_value_with_same_argument(fake_random: Callable[[], int]): @cache async def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = await randomized("expected") assert await randomized("expected") == expected @mark.asyncio -async def test_async_returns_fresh_value_with_different_argument(): +async def test_async_returns_fresh_value_with_different_argument(fake_random: Callable[[], int]): @cache async def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = await randomized("expected") assert await randomized("checked") != expected @mark.asyncio -async def test_async_returns_fresh_value_with_limit_exceed(): +async def test_async_returns_fresh_value_with_limit_exceed(fake_random: Callable[[], int]): @cache(limit=1) async def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = await randomized("expected") await randomized("different") @@ -103,10 +111,10 @@ async def randomized(_: str, /) -> int: @mark.asyncio -async def test_async_returns_same_value_with_repeating_argument(): +async def test_async_returns_same_value_with_repeating_argument(fake_random: Callable[[], int]): @cache(limit=2) async def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = await randomized("expected") await randomized("different") @@ -118,10 +126,12 @@ async def randomized(_: str, /) -> int: @mark.asyncio -async def test_async_returns_fresh_value_with_expiration_time_exceed(): +async def test_async_returns_fresh_value_with_expiration_time_exceed( + fake_random: Callable[[], int], +): @cache(expiration=0.02) async def randomized(_: str, /) -> int: - return randint(-65536, 65535) + return fake_random() expected: int = await randomized("expected") await sleep(0.02) @@ -162,11 +172,11 @@ async def randomized(_: str, /) -> int: @mark.asyncio -async def test_async_expiration_creates_new_task(): +async def test_async_expiration_creates_new_task(fake_random: Callable[[], int]): @cache(expiration=0.01) async def randomized(_: str, /) -> int: await sleep(0.02) - return randint(-65536, 65535) + return fake_random() assert await randomized("expected") != await randomized("expected")