From 7abde741ee2d3db202589522e73f5f15392de149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Kali=C5=84ski?= Date: Mon, 6 May 2024 12:23:03 +0200 Subject: [PATCH] Allow extending api with kwargs --- constraints | 14 ++-- src/draive/embedding/call.py | 7 +- src/draive/embedding/embedder.py | 3 +- src/draive/generation/image/call.py | 4 ++ src/draive/generation/image/generator.py | 6 +- src/draive/generation/model/call.py | 8 ++- src/draive/generation/model/generator.py | 6 +- src/draive/generation/model/lmm.py | 12 ++-- src/draive/generation/text/call.py | 3 + src/draive/generation/text/generator.py | 6 +- src/draive/generation/text/lmm.py | 4 ++ src/draive/lmm/call.py | 13 +++- src/draive/lmm/completion.py | 17 +++-- src/draive/mistral/chat_response.py | 63 +++++++++++------ src/draive/mistral/client.py | 20 ++---- src/draive/mistral/embedding.py | 4 +- src/draive/mistral/lmm.py | 18 +++-- src/draive/openai/chat_response.py | 64 ++++++++++++----- src/draive/openai/chat_stream.py | 60 ++++++++++------ src/draive/openai/client.py | 87 +++++++++++++----------- src/draive/openai/embedding.py | 4 +- src/draive/openai/images.py | 7 +- src/draive/openai/lmm.py | 13 ++-- src/draive/parameters/data.py | 7 +- src/draive/tools/toolbox.py | 23 +++++-- src/draive/types/model.py | 1 + tests/test_cache.py | 4 +- 27 files changed, 312 insertions(+), 166 deletions(-) diff --git a/constraints b/constraints index 1b8e2f0..2001f0f 100644 --- a/constraints +++ b/constraints @@ -14,7 +14,7 @@ certifi==2024.2.2 # requests charset-normalizer==3.3.2 # via requests -coverage==7.5.0 +coverage==7.5.1 # via pytest-cov distro==1.9.0 # via openai @@ -38,7 +38,7 @@ mdurl==0.1.2 nodeenv==1.8.0 # via pyright numpy==1.26.4 -openai==1.23.6 +openai==1.25.2 packaging==24.0 # via pytest pbr==6.0.0 @@ -49,9 +49,9 @@ pydantic==2.7.1 # via openai pydantic-core==2.18.2 # via pydantic -pygments==2.17.2 +pygments==2.18.0 # via rich -pyright==1.1.360 +pyright==1.1.361 pytest==7.4.4 # via # pytest-asyncio @@ -60,13 +60,13 @@ pytest-asyncio==0.23.6 pytest-cov==4.1.0 pyyaml==6.0.1 # via bandit -regex==2024.4.16 +regex==2024.4.28 # via tiktoken requests==2.31.0 # via tiktoken rich==13.7.1 # via bandit -ruff==0.4.2 +ruff==0.4.3 setuptools==69.5.1 # via nodeenv sniffio==1.3.1 @@ -77,7 +77,7 @@ sniffio==1.3.1 stevedore==5.2.0 # via bandit tiktoken==0.6.0 -tqdm==4.66.2 +tqdm==4.66.4 # via openai typing-extensions==4.11.0 # via diff --git a/src/draive/embedding/call.py b/src/draive/embedding/call.py index 8825c22..8d2f425 100644 --- a/src/draive/embedding/call.py +++ b/src/draive/embedding/call.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from typing import Any from draive.embedding.embedded import Embedded from draive.embedding.state import Embedding @@ -11,5 +12,9 @@ async def embed_text( values: Iterable[str], + **extra: Any, ) -> list[Embedded[str]]: - return await ctx.state(Embedding).embed_text(values=values) + return await ctx.state(Embedding).embed_text( + values=values, + **extra, + ) diff --git a/src/draive/embedding/embedder.py b/src/draive/embedding/embedder.py index 8f41189..67e5cf3 100644 --- a/src/draive/embedding/embedder.py +++ b/src/draive/embedding/embedder.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from draive.embedding.embedded import Embedded @@ -13,4 +13,5 @@ class Embedder[Value](Protocol): async def __call__( self, values: Iterable[Value], + **extra: Any, ) -> list[Embedded[Value]]: ... 
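Note on the embedding change above: since the Embedder protocol now takes **extra, third-party embedders need the widened signature as well. Below is a minimal conformance sketch that is not part of the patch; the dimensions keyword mentioned in the comment is only an illustrative example of what a caller might forward through embed_text, and the import paths follow the file locations shown in the diff.

from collections.abc import Iterable
from typing import Any

from draive.embedding.embedded import Embedded
from draive.embedding.embedder import Embedder


async def null_embedder(
    values: Iterable[str],
    **extra: Any,  # provider-specific hints forwarded from embed_text(), e.g. dimensions=256
) -> list[Embedded[str]]:
    # Toy body: a real embedder would call a model here and wrap the vectors in Embedded.
    return []


embedder: Embedder[str] = null_embedder  # structural conformance with the widened protocol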
diff --git a/src/draive/generation/image/call.py b/src/draive/generation/image/call.py index 62bc539..7e7614c 100644 --- a/src/draive/generation/image/call.py +++ b/src/draive/generation/image/call.py @@ -1,3 +1,5 @@ +from typing import Any + from draive.generation.image.state import ImageGeneration from draive.scope import ctx from draive.types import ImageContent @@ -10,7 +12,9 @@ async def generate_image( *, instruction: str, + **extra: Any, ) -> ImageContent: return await ctx.state(ImageGeneration).generate( instruction=instruction, + **extra, ) diff --git a/src/draive/generation/image/generator.py b/src/draive/generation/image/generator.py index 6f723f7..0610238 100644 --- a/src/draive/generation/image/generator.py +++ b/src/draive/generation/image/generator.py @@ -1,4 +1,4 @@ -from typing import Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from draive.types import ImageContent @@ -13,5 +13,5 @@ async def __call__( self, *, instruction: str, - ) -> ImageContent: - ... + **extra: Any, + ) -> ImageContent: ... diff --git a/src/draive/generation/model/call.py b/src/draive/generation/model/call.py index 2e3d5ec..8a1890f 100644 --- a/src/draive/generation/model/call.py +++ b/src/draive/generation/model/call.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from typing import Any from draive.generation.model.state import ModelGeneration from draive.scope import ctx @@ -11,18 +12,21 @@ async def generate_model[Generated: Model]( - model: type[Generated], + generated: type[Generated], + /, *, instruction: str, input: MultimodalContent, # noqa: A002 tools: Toolbox | None = None, examples: Iterable[tuple[MultimodalContent, Generated]] | None = None, + **extra: Any, ) -> Generated: model_generation: ModelGeneration = ctx.state(ModelGeneration) return await model_generation.generate( - model, + generated, instruction=instruction, input=input, tools=tools or model_generation.tools, examples=examples, + **extra, ) diff --git a/src/draive/generation/model/generator.py b/src/draive/generation/model/generator.py index 0aab0d6..c94881c 100644 --- a/src/draive/generation/model/generator.py +++ b/src/draive/generation/model/generator.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from draive.tools import Toolbox from draive.types import Model, MultimodalContent @@ -13,10 +13,12 @@ class ModelGenerator(Protocol): async def __call__[Generated: Model]( # noqa: PLR0913 self, - model: type[Generated], + generated: type[Generated], + /, *, instruction: str, input: MultimodalContent, # noqa: A002 tools: Toolbox | None = None, examples: Iterable[tuple[MultimodalContent, Generated]] | None = None, + **extra: Any, ) -> Generated: ... 
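Call-site sketch for the generate_model change above (not from the patch): the generated model type is now the first positional-only argument, and unknown keyword arguments ride along in **extra down to the underlying LMM call. The top-level draive re-exports, the field-by-annotation Model declaration, passing a plain string as the multimodal input, the ProductInfo type, and the temperature example are all assumptions here, and the surrounding context/state setup is omitted.

from draive import Model, generate_model  # assumed re-exports


class ProductInfo(Model):  # field declaration by annotations is assumed here
    name: str
    price: float


async def extract(description: str) -> ProductInfo:
    return await generate_model(
        ProductInfo,             # positional-only now ("generated"), previously model=...
        instruction="Extract the product details as structured data.",
        input=description,
        temperature=0.0,         # example extra forwarded to the configured LMM
    )

Making the type positional-only also sidesteps the awkward model= keyword that shadowed the LLM sense of "model".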
diff --git a/src/draive/generation/model/lmm.py b/src/draive/generation/model/lmm.py index 9cc37d7..905bc04 100644 --- a/src/draive/generation/model/lmm.py +++ b/src/draive/generation/model/lmm.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from typing import Any from draive.lmm import LMMCompletionMessage, lmm_completion from draive.tools import Toolbox @@ -10,18 +11,20 @@ async def lmm_generate_model[Generated: Model]( - model: type[Generated], + generated: type[Generated], + /, *, instruction: str, input: MultimodalContent, # noqa: A002 tools: Toolbox | None = None, examples: Iterable[tuple[MultimodalContent, Generated]] | None = None, + **extra: Any, ) -> Generated: system_message: LMMCompletionMessage = LMMCompletionMessage( role="system", content=INSTRUCTION.format( instruction=instruction, - format=model.specification(), + format=generated.specification(), ), ) input_message: LMMCompletionMessage = LMMCompletionMessage( @@ -61,10 +64,11 @@ async def lmm_generate_model[Generated: Model]( context=context, tools=tools, output="json", + stream=False, + **extra, ) - generated: Generated = model.from_json(completion.content_string) - return generated + return generated.from_json(completion.content_string) INSTRUCTION: str = """\ diff --git a/src/draive/generation/text/call.py b/src/draive/generation/text/call.py index 8d14072..75ab8c1 100644 --- a/src/draive/generation/text/call.py +++ b/src/draive/generation/text/call.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from typing import Any from draive.generation.text.state import TextGeneration from draive.scope import ctx @@ -16,6 +17,7 @@ async def generate_text( input: MultimodalContent, # noqa: A002 tools: Toolbox | None = None, examples: Iterable[tuple[MultimodalContent, str]] | None = None, + **extra: Any, ) -> str: text_generation: TextGeneration = ctx.state(TextGeneration) return await text_generation.generate( @@ -23,4 +25,5 @@ async def generate_text( input=input, tools=tools or text_generation.tools, examples=examples, + **extra, ) diff --git a/src/draive/generation/text/generator.py b/src/draive/generation/text/generator.py index c28abb6..91a9bec 100644 --- a/src/draive/generation/text/generator.py +++ b/src/draive/generation/text/generator.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from draive.tools import Toolbox from draive.types import MultimodalContent @@ -18,5 +18,5 @@ async def __call__( input: MultimodalContent, # noqa: A002 tools: Toolbox | None = None, examples: Iterable[tuple[MultimodalContent, str]] | None = None, - ) -> str: - ... + **extra: Any, + ) -> str: ... 
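The simplest passthrough case, corresponding to the text generation diffs above (again a sketch, not part of the patch): extra keyword arguments given to generate_text travel through lmm_completion into the provider configuration via updated(**extra), so something like max_tokens can be tweaked per call without touching shared state. The top-level import is assumed, and an active draive context configured elsewhere is taken for granted.

from draive import generate_text  # assumed re-export; defined in draive.generation.text.call


async def summarize(text: str) -> str:
    return await generate_text(
        instruction="Summarize in one sentence.",
        input=text,
        max_tokens=128,  # example extra; with the OpenAI LMM this lands in OpenAIChatConfig.updated(**extra)
    )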
diff --git a/src/draive/generation/text/lmm.py b/src/draive/generation/text/lmm.py index 496f99e..00b1c56 100644 --- a/src/draive/generation/text/lmm.py +++ b/src/draive/generation/text/lmm.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from typing import Any from draive.lmm import LMMCompletionMessage, lmm_completion from draive.tools import Toolbox @@ -15,6 +16,7 @@ async def lmm_generate_text( input: MultimodalContent, # noqa: A002 tools: Toolbox | None = None, examples: Iterable[tuple[MultimodalContent, str]] | None = None, + **extra: Any, ) -> str: system_message: LMMCompletionMessage = LMMCompletionMessage( role="system", @@ -57,6 +59,8 @@ async def lmm_generate_text( context=context, tools=tools, output="text", + stream=False, + **extra, ) generated: str = completion.content_string diff --git a/src/draive/lmm/call.py b/src/draive/lmm/call.py index 4ddc57d..85dd8ad 100644 --- a/src/draive/lmm/call.py +++ b/src/draive/lmm/call.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Literal, overload +from typing import Any, Literal, overload from draive.lmm.completion import LMMCompletionStream from draive.lmm.message import ( @@ -21,6 +21,7 @@ async def lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, stream: Literal[True], + **extra: Any, ) -> LMMCompletionStream: ... @@ -30,6 +31,7 @@ async def lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, stream: Callable[[LMMCompletionStreamingUpdate], None], + **extra: Any, ) -> LMMCompletionMessage: ... @@ -39,6 +41,8 @@ async def lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, output: Literal["text", "json"] = "text", + stream: Literal[False] = False, + **extra: Any, ) -> LMMCompletionMessage: ... @@ -48,6 +52,7 @@ async def lmm_completion( tools: Toolbox | None = None, output: Literal["text", "json"] = "text", stream: Callable[[LMMCompletionStreamingUpdate], None] | bool = False, + **extra: Any, ) -> LMMCompletionStream | LMMCompletionMessage: match stream: case False: @@ -55,16 +60,22 @@ async def lmm_completion( context=context, tools=tools, output=output, + stream=False, + **extra, ) case True: return await ctx.state(LMM).completion( context=context, tools=tools, + output=output, stream=True, + **extra, ) case progress: return await ctx.state(LMM).completion( context=context, tools=tools, + output=output, stream=progress, + **extra, ) diff --git a/src/draive/lmm/completion.py b/src/draive/lmm/completion.py index de7cb8a..a91c205 100644 --- a/src/draive/lmm/completion.py +++ b/src/draive/lmm/completion.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Literal, Protocol, Self, overload, runtime_checkable +from typing import Any, Literal, Protocol, Self, overload, runtime_checkable from draive.lmm.message import ( LMMCompletionMessage, @@ -26,8 +26,10 @@ async def __call__( self, *, context: list[LMMCompletionMessage], - tools: Toolbox | None = None, + tools: Toolbox | None, + output: Literal["text", "json"], stream: Literal[True], + **extra: Any, ) -> LMMCompletionStream: ... @overload @@ -35,8 +37,10 @@ async def __call__( self, *, context: list[LMMCompletionMessage], - tools: Toolbox | None = None, + tools: Toolbox | None, + output: Literal["text", "json"], stream: Callable[[LMMCompletionStreamingUpdate], None], + **extra: Any, ) -> LMMCompletionMessage: ... 
@overload @@ -44,8 +48,10 @@ async def __call__( self, *, context: list[LMMCompletionMessage], - tools: Toolbox | None = None, - output: Literal["text", "json"] = "text", + tools: Toolbox | None, + output: Literal["text", "json"], + stream: Literal[False], + **extra: Any, ) -> LMMCompletionMessage: ... async def __call__( @@ -55,4 +61,5 @@ async def __call__( tools: Toolbox | None = None, output: Literal["text", "json"] = "text", stream: Callable[[LMMCompletionStreamingUpdate], None] | bool = False, + **extra: Any, ) -> LMMCompletionStream | LMMCompletionMessage: ... diff --git a/src/draive/mistral/chat_response.py b/src/draive/mistral/chat_response.py index b0a9eaa..a5345ad 100644 --- a/src/draive/mistral/chat_response.py +++ b/src/draive/mistral/chat_response.py @@ -16,7 +16,7 @@ ] -async def _chat_response( +async def _chat_response( # noqa: C901 *, client: MistralClient, config: MistralChatConfig, @@ -24,32 +24,50 @@ async def _chat_response( tools: Toolbox, recursion_level: int = 0, ) -> str: - if recursion_level > config.recursion_limit: - raise MistralException("Reached limit of recursive calls of %d", config.recursion_limit) - with ctx.nested( "chat_response", metrics=[ArgumentsTrace.of(messages=messages.copy())], ): - suggest_tools: bool - available_tools: list[dict[str, object]] - if recursion_level == 0 and (suggested := tools.suggested_tool): - # suggest/require tool call only initially - suggest_tools = True - available_tools = cast(list[dict[str, object]], [suggested]) - else: - suggest_tools = False - available_tools = cast( - list[dict[str, object]], - tools.available_tools if tools else [], + completion: ChatCompletionResponse + + if recursion_level == config.recursion_limit: + ctx.log_warning("Reaching limit of recursive Mistral calls, ignoring tools...") + completion = await client.chat_completion( + config=config, + messages=messages, ) - completion: ChatCompletionResponse = await client.chat_completion( - config=config, - messages=messages, - tools=available_tools, - suggest_tools=suggest_tools, - ) + elif recursion_level != 0: # suggest/require tool call only initially + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[dict[str, object]], + tools.available_tools, + ), + ) + + elif suggested_tool := tools.suggested_tool: + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[dict[str, object]], + [suggested_tool], + ), + suggest_tools=True, + ) + + else: + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[dict[str, object]], + tools.available_tools, + ), + suggest_tools=tools.suggest_tools, + ) if usage := completion.usage: ctx.record( @@ -93,6 +111,9 @@ async def _chat_response( raise MistralException("Invalid Mistral completion", completion) # recursion outside of context + if recursion_level >= config.recursion_limit: + raise MistralException("Reached limit of recursive calls of %d", config.recursion_limit) + return await _chat_response( client=client, config=config, diff --git a/src/draive/mistral/client.py b/src/draive/mistral/client.py index 31a56bc..c971a4c 100644 --- a/src/draive/mistral/client.py +++ b/src/draive/mistral/client.py @@ -53,18 +53,8 @@ async def chat_completion( *, config: MistralChatConfig, messages: list[ChatMessage], - tools: list[dict[str, object]], - stream: Literal[True], - ) -> AsyncIterable[ChatCompletionStreamResponse]: ... 
- - @overload - async def chat_completion( - self, - *, - config: MistralChatConfig, - messages: list[ChatMessage], - tools: list[dict[str, object]], - suggest_tools: bool, + tools: list[dict[str, object]] | None = None, + suggest_tools: bool = False, stream: Literal[True], ) -> AsyncIterable[ChatCompletionStreamResponse]: ... @@ -74,7 +64,7 @@ async def chat_completion( *, config: MistralChatConfig, messages: list[ChatMessage], - tools: list[dict[str, object]], + tools: list[dict[str, object]] | None = None, suggest_tools: bool = False, ) -> ChatCompletionResponse: ... @@ -83,7 +73,7 @@ async def chat_completion( # noqa: PLR0913 *, config: MistralChatConfig, messages: list[ChatMessage], - tools: list[dict[str, object]], + tools: list[dict[str, object]] | None = None, suggest_tools: bool = False, stream: bool = False, ) -> AsyncIterable[ChatCompletionStreamResponse] | ChatCompletionResponse: @@ -131,7 +121,7 @@ async def _create_chat_completion( # noqa: PLR0913 max_tokens: int | None, response_format: dict[str, str] | None, messages: list[ChatMessage], - tools: list[dict[str, object]], + tools: list[dict[str, object]] | None, tool_choice: str | None, ) -> ChatCompletionResponse: request_body: dict[str, Any] = { diff --git a/src/draive/mistral/embedding.py b/src/draive/mistral/embedding.py index 2d17475..352b250 100644 --- a/src/draive/mistral/embedding.py +++ b/src/draive/mistral/embedding.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from typing import Any from draive.embedding import Embedded from draive.mistral.client import MistralClient @@ -12,8 +13,9 @@ async def mistral_embed_text( values: Iterable[str], + **extra: Any, ) -> list[Embedded[str]]: - config: MistralEmbeddingConfig = ctx.state(MistralEmbeddingConfig) + config: MistralEmbeddingConfig = ctx.state(MistralEmbeddingConfig).updated(**extra) with ctx.nested("text_embedding", metrics=[config]): results: list[list[float]] = await ctx.dependency(MistralClient).embedding( config=config, diff --git a/src/draive/mistral/lmm.py b/src/draive/mistral/lmm.py index 0d79277..6c250c3 100644 --- a/src/draive/mistral/lmm.py +++ b/src/draive/mistral/lmm.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Literal, overload +from typing import Any, Literal, overload from draive.lmm import LMMCompletionMessage, LMMCompletionStream, LMMCompletionStreamingUpdate from draive.mistral.chat_response import _chat_response # pyright: ignore[reportPrivateUsage] @@ -23,6 +23,7 @@ async def mistral_lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, stream: Literal[True], + **extra: Any, ) -> LMMCompletionStream: ... @@ -32,6 +33,7 @@ async def mistral_lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, stream: Callable[[LMMCompletionStreamingUpdate], None], + **extra: Any, ) -> LMMCompletionMessage: ... @@ -41,6 +43,8 @@ async def mistral_lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, output: Literal["text", "json"] = "text", + stream: Literal[False] = False, + **extra: Any, ) -> LMMCompletionMessage: ... 
@@ -50,23 +54,23 @@ async def mistral_lmm_completion( tools: Toolbox | None = None, output: Literal["text", "json"] = "text", stream: Callable[[LMMCompletionStreamingUpdate], None] | bool = False, + **extra: Any, ) -> LMMCompletionStream | LMMCompletionMessage: client: MistralClient = ctx.dependency(MistralClient) - config: MistralChatConfig + config: MistralChatConfig = ctx.state(MistralChatConfig).updated(**extra) match output: case "text": - config = ctx.state(MistralChatConfig).updated(response_format={"type": "text"}) + config = config.updated(response_format={"type": "text"}) case "json": if tools is None: - config = ctx.state(MistralChatConfig).updated( - response_format={"type": "json_object"} - ) + config = config.updated(response_format={"type": "json_object"}) else: ctx.log_warning( "Attempting to use Mistral in JSON mode with tools which is not supported." " Using text mode instead..." ) - config = ctx.state(MistralChatConfig).updated(response_format={"type": "text"}) + config = config.updated(response_format={"type": "text"}) + messages: list[ChatMessage] = [_convert_message(message=message) for message in context] match stream: diff --git a/src/draive/openai/chat_response.py b/src/draive/openai/chat_response.py index d8ebda2..2ffdd8c 100644 --- a/src/draive/openai/chat_response.py +++ b/src/draive/openai/chat_response.py @@ -30,29 +30,54 @@ async def _chat_response( tools: Toolbox, recursion_level: int = 0, ) -> str: - if recursion_level > config.recursion_limit: - raise OpenAIException("Reached limit of recursive calls of %d", config.recursion_limit) - with ctx.nested( "chat_response", metrics=[ArgumentsTrace.of(messages=messages.copy())], ): - completion: ChatCompletion = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[ChatCompletionToolParam], - tools.available_tools if tools else [], - ), - suggested_tool={ - "type": "function", - "function": { - "name": tools.suggested_tool_name, + completion: ChatCompletion + if recursion_level == config.recursion_limit: + ctx.log_warning("Reaching limit of recursive OpenAI calls, ignoring tools...") + completion = await client.chat_completion( + config=config, + messages=messages, + ) + + elif recursion_level != 0: # suggest/require tool call only initially + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools.available_tools, + ), + ) + + elif suggested_tool_name := tools.suggested_tool_name: + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools.available_tools, + ), + tools_suggestion={ + "type": "function", + "function": { + "name": suggested_tool_name, + }, }, - } # suggest/require tool call only initially - if recursion_level == 0 and tools.suggested_tool_name - else None, - ) + ) + + else: + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools.available_tools, + ), + tools_suggestion=tools.suggest_tools, + ) if usage := completion.usage: ctx.record( @@ -89,6 +114,9 @@ async def _chat_response( raise OpenAIException("Invalid OpenAI completion", completion) # recursion outside of context + if recursion_level >= config.recursion_limit: + raise OpenAIException("Reached limit of recursive calls of %d", config.recursion_limit) + return await _chat_response( client=client, config=config, diff --git a/src/draive/openai/chat_stream.py 
b/src/draive/openai/chat_stream.py index 3de9449..51ffb14 100644 --- a/src/draive/openai/chat_stream.py +++ b/src/draive/openai/chat_stream.py @@ -6,6 +6,7 @@ ChatCompletionChunk, ChatCompletionMessageParam, ChatCompletionMessageToolCall, + ChatCompletionNamedToolChoiceParam, ChatCompletionToolParam, ) from openai.types.chat.chat_completion_chunk import ChoiceDelta @@ -35,30 +36,46 @@ async def _chat_stream( # noqa: PLR0913, C901 send_update: Callable[[ToolCallUpdate | str], None], recursion_level: int = 0, ) -> str: - if recursion_level > config.recursion_limit: - raise OpenAIException("Reached limit of recursive calls of %d", config.recursion_limit) - with ctx.nested( "chat_stream", metrics=[ArgumentsTrace.of(messages=messages.copy())], ): - completion_stream: OpenAIAsyncStream[ChatCompletionChunk] = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[ChatCompletionToolParam], - tools.available_tools if tools else [], - ), - suggested_tool={ - "type": "function", - "function": { - "name": tools.suggested_tool_name, - }, - } # suggest/require tool call only initially - if recursion_level == 0 and tools.suggested_tool_name - else None, - stream=True, - ) + completion_stream: OpenAIAsyncStream[ChatCompletionChunk] + if recursion_level == config.recursion_limit: + ctx.log_warning("Reaching limit of recursive OpenAI calls, ignoring tools...") + completion_stream = await client.chat_completion( + config=config, + messages=messages, + tools=None, + stream=True, + ) + + else: + tools_suggestion: ChatCompletionNamedToolChoiceParam | bool + if recursion_level != 0: # suggest/require tool call only initially + tools_suggestion = False + + elif suggested_tool_name := tools.suggested_tool_name: + tools_suggestion = { + "type": "function", + "function": { + "name": suggested_tool_name, + }, + } + + else: + tools_suggestion = tools.suggest_tools + + completion_stream = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools.available_tools if tools else [], + ), + tools_suggestion=tools_suggestion, + stream=True, + ) while True: # load chunks to decide what to do next head: ChatCompletionChunk @@ -119,6 +136,9 @@ async def _chat_stream( # noqa: PLR0913, C901 continue # iterate over the stream until can decide what to do or reach the end # recursion outside of context + if recursion_level >= config.recursion_limit: + raise OpenAIException("Reached limit of recursive calls of %d", config.recursion_limit) + return await _chat_stream( client=client, config=config, diff --git a/src/draive/openai/client.py b/src/draive/openai/client.py index 64af9ec..195d81b 100644 --- a/src/draive/openai/client.py +++ b/src/draive/openai/client.py @@ -12,6 +12,7 @@ ChatCompletionChunk, ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, + ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, ) from openai.types.create_embedding_response import CreateEmbeddingResponse @@ -77,7 +78,8 @@ async def chat_completion( *, config: OpenAIChatConfig, messages: list[ChatCompletionMessageParam], - tools: list[ChatCompletionToolParam], + tools: list[ChatCompletionToolParam] | None = None, + tools_suggestion: ChatCompletionNamedToolChoiceParam | bool = False, stream: Literal[True], ) -> AsyncStream[ChatCompletionChunk]: ... 
@@ -87,19 +89,8 @@ async def chat_completion( *, config: OpenAIChatConfig, messages: list[ChatCompletionMessageParam], - tools: list[ChatCompletionToolParam], - suggested_tool: ChatCompletionNamedToolChoiceParam | None, - stream: Literal[True], - ) -> AsyncStream[ChatCompletionChunk]: ... - - @overload - async def chat_completion( - self, - *, - config: OpenAIChatConfig, - messages: list[ChatCompletionMessageParam], - tools: list[ChatCompletionToolParam], - suggested_tool: ChatCompletionNamedToolChoiceParam | None = None, + tools: list[ChatCompletionToolParam] | None = None, + tools_suggestion: ChatCompletionNamedToolChoiceParam | bool = False, ) -> ChatCompletion: ... async def chat_completion( # noqa: PLR0913 @@ -107,24 +98,41 @@ async def chat_completion( # noqa: PLR0913 *, config: OpenAIChatConfig, messages: list[ChatCompletionMessageParam], - tools: list[ChatCompletionToolParam], - suggested_tool: ChatCompletionNamedToolChoiceParam | None = None, + tools: list[ChatCompletionToolParam] | None = None, + tools_suggestion: ChatCompletionNamedToolChoiceParam | bool = False, stream: bool = False, ) -> AsyncStream[ChatCompletionChunk] | ChatCompletion: + tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven + match tools_suggestion: + case False: + tool_choice = "auto" if tools else NOT_GIVEN + + case True: + assert tools, "Can't require tools use without tools" # nosec: B101 + tool_choice = "required" + + case tool: + assert tools, "Can't require tools use without tools" # nosec: B101 + tool_choice = tool + return await self._client.chat.completions.create( messages=messages, model=config.model, - frequency_penalty=_value_when_given(config.frequency_penalty), - max_tokens=_value_when_given(config.max_tokens), + frequency_penalty=config.frequency_penalty + if config.frequency_penalty is not None + else NOT_GIVEN, + max_tokens=config.max_tokens if config.max_tokens is not None else NOT_GIVEN, n=1, - response_format=_value_when_given(config.response_format), - seed=_value_when_given(config.seed), + response_format=config.response_format + if config.response_format is not None + else NOT_GIVEN, + seed=config.seed if config.seed is not None else NOT_GIVEN, stream=stream, temperature=config.temperature, tools=tools or NOT_GIVEN, - tool_choice=(suggested_tool or "auto") if tools else NOT_GIVEN, - top_p=_value_when_given(config.top_p), - timeout=_value_when_given(config.timeout), + tool_choice=tool_choice, + top_p=config.top_p if config.top_p is not None else NOT_GIVEN, + timeout=config.timeout if config.timeout is not None else NOT_GIVEN, ) async def embedding( @@ -140,9 +148,13 @@ async def embedding( self._create_text_embedding( texts=list(inputs_list[index : index + config.batch_size]), model=config.model, - dimensions=config.dimensions, - encoding_format=config.encoding_format, - timeout=config.timeout, + dimensions=config.dimensions + if config.dimensions is not None + else NOT_GIVEN, + encoding_format=config.encoding_format + if config.encoding_format is not None + else NOT_GIVEN, + timeout=config.timeout if config.timeout is not None else NOT_GIVEN, ) for index in range(0, len(inputs_list), config.batch_size) ] @@ -154,17 +166,17 @@ async def _create_text_embedding( # noqa: PLR0913 self, texts: list[str], model: str, - dimensions: int | None, - encoding_format: Literal["float", "base64"] | None, - timeout: float | None, + dimensions: int | NotGiven, + encoding_format: Literal["float", "base64"] | NotGiven, + timeout: float | NotGiven, ) -> list[list[float]]: try: response: 
CreateEmbeddingResponse = await self._client.embeddings.create( input=texts, model=model, - dimensions=_value_when_given(dimensions), - encoding_format=_value_when_given(encoding_format), - timeout=_value_when_given(timeout), + dimensions=dimensions, + encoding_format=encoding_format, + timeout=timeout, ) return [element.embedding for element in response.data] @@ -200,17 +212,10 @@ async def generate_image( quality=config.quality, size=config.size, style=config.style, - timeout=_value_when_given(config.timeout), + timeout=config.timeout if config.timeout is not None else NOT_GIVEN, response_format=config.response_format, ) return response.data[0] - async def dispose(self): + async def dispose(self) -> None: await self._client.close() - - -def _value_when_given[Value]( - value: Value | None, - /, -) -> Value | NotGiven: - return value if value is not None else NOT_GIVEN diff --git a/src/draive/openai/embedding.py b/src/draive/openai/embedding.py index 81a8311..13f3712 100644 --- a/src/draive/openai/embedding.py +++ b/src/draive/openai/embedding.py @@ -1,4 +1,5 @@ from collections.abc import Iterable +from typing import Any from draive.embedding import Embedded from draive.openai.client import OpenAIClient @@ -12,8 +13,9 @@ async def openai_embed_text( values: Iterable[str], + **extra: Any, ) -> list[Embedded[str]]: - config: OpenAIEmbeddingConfig = ctx.state(OpenAIEmbeddingConfig) + config: OpenAIEmbeddingConfig = ctx.state(OpenAIEmbeddingConfig).updated(**extra) with ctx.nested("text_embedding", metrics=[config]): results: list[list[float]] = await ctx.dependency(OpenAIClient).embedding( config=config, diff --git a/src/draive/openai/images.py b/src/draive/openai/images.py index d1fd849..f048b82 100644 --- a/src/draive/openai/images.py +++ b/src/draive/openai/images.py @@ -1,3 +1,5 @@ +from typing import Any + from openai.types.image import Image from draive.openai.client import OpenAIClient @@ -14,9 +16,10 @@ async def openai_generate_image( *, instruction: str, + **extra: Any, ) -> ImageContent: client: OpenAIClient = ctx.dependency(OpenAIClient) - config: OpenAIImageGenerationConfig = ctx.state(OpenAIImageGenerationConfig) + config: OpenAIImageGenerationConfig = ctx.state(OpenAIImageGenerationConfig).updated(**extra) with ctx.nested("openai_generate_image", metrics=[config]): image: Image = await client.generate_image( config=config, @@ -24,7 +27,9 @@ async def openai_generate_image( ) if url := image.url: return ImageURLContent(image_url=url) + elif b64data := image.b64_json: return ImageBase64Content(image_base64=b64data) + else: raise OpenAIException("Invalid OpenAI response - missing image content") diff --git a/src/draive/openai/lmm.py b/src/draive/openai/lmm.py index 2444347..9253ec9 100644 --- a/src/draive/openai/lmm.py +++ b/src/draive/openai/lmm.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Literal, overload +from typing import Any, Literal, overload from openai.types.chat import ChatCompletionContentPartParam, ChatCompletionMessageParam @@ -24,6 +24,7 @@ async def openai_lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, stream: Literal[True], + **extra: Any, ) -> LMMCompletionStream: ... @@ -33,6 +34,7 @@ async def openai_lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, stream: Callable[[LMMCompletionStreamingUpdate], None], + **extra: Any, ) -> LMMCompletionMessage: ... 
@@ -42,6 +44,8 @@ async def openai_lmm_completion( context: list[LMMCompletionMessage], tools: Toolbox | None = None, output: Literal["text", "json"] = "text", + stream: Literal[False] = False, + **extra: Any, ) -> LMMCompletionMessage: ... @@ -51,14 +55,15 @@ async def openai_lmm_completion( tools: Toolbox | None = None, output: Literal["text", "json"] = "text", stream: Callable[[LMMCompletionStreamingUpdate], None] | bool = False, + **extra: Any, ) -> LMMCompletionStream | LMMCompletionMessage: client: OpenAIClient = ctx.dependency(OpenAIClient) - config: OpenAIChatConfig + config: OpenAIChatConfig = ctx.state(OpenAIChatConfig).updated(**extra) match output: case "text": - config = ctx.state(OpenAIChatConfig).updated(response_format={"type": "text"}) + config = config.updated(response_format={"type": "text"}) case "json": - config = ctx.state(OpenAIChatConfig).updated(response_format={"type": "json_object"}) + config = config.updated(response_format={"type": "json_object"}) messages: list[ChatCompletionMessageParam] = [ _convert_message(config=config, message=message) for message in context ] diff --git a/src/draive/parameters/data.py b/src/draive/parameters/data.py index 6426ba9..bdebae9 100644 --- a/src/draive/parameters/data.py +++ b/src/draive/parameters/data.py @@ -238,6 +238,7 @@ def validator( def from_dict( cls, value: dict[str, Any], + /, ) -> Self: return cls.validated(**value) @@ -256,4 +257,8 @@ def updated( /, **parameters: Any, ) -> Self: - return self.__class__.validated(**{**vars(self), **parameters}) + if parameters: + return self.__class__.validated(**{**vars(self), **parameters}) + + else: + return self diff --git a/src/draive/tools/toolbox.py b/src/draive/tools/toolbox.py index 6fe08c1..54ec566 100644 --- a/src/draive/tools/toolbox.py +++ b/src/draive/tools/toolbox.py @@ -1,6 +1,7 @@ from json import loads -from typing import Any, final +from typing import Any, Literal, final +from draive.helpers import freeze from draive.parameters import ToolSpecification from draive.tools import Tool from draive.tools.errors import ToolException @@ -17,12 +18,24 @@ class Toolbox: def __init__( self, *tools: AnyTool, - suggested: AnyTool | None = None, + suggest: AnyTool | Literal[True] | None = None, ) -> None: - self._suggested_tool: AnyTool | None = suggested self._tools: dict[str, AnyTool] = {tool.name: tool for tool in tools} - if suggested := suggested: - self._tools[suggested.name] = suggested + self.suggest_tools: bool + self._suggested_tool: AnyTool | None + match suggest: + case None: + self.suggest_tools = False + self._suggested_tool = None + case True: + self.suggest_tools = True if self._tools else False + self._suggested_tool = None + case tool: + self.suggest_tools = True + self._suggested_tool = tool + self._tools[tool.name] = tool + + freeze(self) @property def suggested_tool_name(self) -> str | None: diff --git a/src/draive/types/model.py b/src/draive/types/model.py index 717ad3d..f4f9c02 100644 --- a/src/draive/types/model.py +++ b/src/draive/types/model.py @@ -33,6 +33,7 @@ def specification(cls) -> str: def from_json( cls, value: str | bytes, + /, decoder: type[json.JSONDecoder] = json.JSONDecoder, ) -> Self: try: diff --git a/tests/test_cache.py b/tests/test_cache.py index 37299ef..e04d34b 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -70,7 +70,7 @@ def randomized(_: str, /) -> int: def test_returns_fresh_value_with_expiration_time_exceed(fake_random: Callable[[], int]): - @cache(expiration=0.02) + @cache(expiration=0.01) def randomized(_: str, 
/) -> int:
         return fake_random()
@@ -129,7 +129,7 @@ async def randomized(_: str, /) -> int:
 async def test_async_returns_fresh_value_with_expiration_time_exceed(
     fake_random: Callable[[], int],
 ):
-    @cache(expiration=0.02)
+    @cache(expiration=0.01)
     async def randomized(_: str, /) -> int:
         return fake_random()
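The provider-side half of the change is the updated(**parameters) pattern: each backend copies its config state with whatever arrived in **extra, and per the parameters/data.py tweak it returns the same instance when nothing was passed. A self-contained toy restatement of that merge logic follows; it is independent of draive's actual parameter validation, and ToyChatConfig is a made-up stand-in, not a draive type.

from dataclasses import dataclass
from typing import Any, Self


@dataclass(frozen=True)
class ToyChatConfig:
    model: str = "toy-model"
    temperature: float = 1.0
    max_tokens: int | None = None

    def updated(self, **parameters: Any) -> Self:
        # Mirrors the patched behavior: merge overrides over current values,
        # and skip the copy entirely when there is nothing to update.
        if parameters:
            return self.__class__(**{**vars(self), **parameters})
        else:
            return self


base = ToyChatConfig()
per_call = base.updated(temperature=0.2, max_tokens=256)  # what .updated(**extra) does per request
assert per_call.temperature == 0.2 and base.temperature == 1.0
assert base.updated() is base  # no kwargs -> same instance, no validation round-trip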