From 4ea4d721712ca32b79263064aeb3e9314e447c55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Kali=C5=84ski?= Date: Fri, 17 May 2024 18:35:50 +0200 Subject: [PATCH] Rework LMM interface --- constraints | 17 +- examples/BasicConversation.ipynb | 21 +- examples/BasicModelGeneration.ipynb | 6 +- pyproject.toml | 2 +- src/draive/__init__.py | 92 ++-- src/draive/agents/flow.py | 6 +- src/draive/agents/state.py | 15 +- src/draive/conversation/__init__.py | 11 +- src/draive/conversation/call.py | 84 ++-- src/draive/conversation/completion.py | 68 +-- src/draive/conversation/lmm.py | 341 +++++++++---- src/draive/conversation/message.py | 16 - src/draive/conversation/model.py | 67 +++ src/draive/conversation/state.py | 6 +- src/draive/generation/image/call.py | 6 +- src/draive/generation/image/generator.py | 5 +- src/draive/generation/model/call.py | 20 +- src/draive/generation/model/generator.py | 15 +- src/draive/generation/model/lmm.py | 150 +++--- src/draive/generation/model/state.py | 2 - src/draive/generation/text/call.py | 15 +- src/draive/generation/text/generator.py | 10 +- src/draive/generation/text/lmm.py | 102 ++-- src/draive/generation/text/state.py | 2 - src/draive/helpers/__init__.py | 2 + src/draive/helpers/mimic.py | 1 - src/draive/helpers/stream.py | 82 +++ src/draive/lmm/__init__.py | 27 +- src/draive/lmm/call.py | 90 ++-- src/draive/lmm/completion.py | 65 --- src/draive/{tools => lmm}/errors.py | 0 src/draive/lmm/invocation.py | 65 +++ src/draive/lmm/message.py | 30 -- src/draive/lmm/state.py | 59 ++- src/draive/{tools => lmm}/tool.py | 202 ++++---- src/draive/lmm/toolbox.py | 157 ++++++ src/draive/metrics/function.py | 4 +- src/draive/metrics/trace.py | 5 +- src/draive/mistral/__init__.py | 4 +- src/draive/mistral/chat_response.py | 123 ----- src/draive/mistral/chat_stream.py | 33 -- src/draive/mistral/chat_tools.py | 132 ----- src/draive/mistral/client.py | 1 + src/draive/mistral/config.py | 1 - src/draive/mistral/lmm.py | 409 ++++++++------- src/draive/mistral/models.py | 26 +- src/draive/openai/__init__.py | 4 +- src/draive/openai/chat_response.py | 126 ----- src/draive/openai/chat_stream.py | 161 ------ src/draive/openai/chat_tools.py | 209 -------- src/draive/openai/config.py | 1 - src/draive/openai/lmm.py | 609 +++++++++++++++-------- src/draive/scope/access.py | 148 +++++- src/draive/similarity/__init__.py | 10 +- src/draive/similarity/mmr.py | 26 +- src/draive/similarity/score.py | 24 + src/draive/similarity/search.py | 38 ++ src/draive/similarity/similarity.py | 32 -- src/draive/tools/__init__.py | 19 - src/draive/tools/state.py | 33 -- src/draive/tools/toolbox.py | 81 --- src/draive/tools/update.py | 18 - src/draive/types/__init__.py | 54 +- src/draive/types/audio.py | 8 +- src/draive/types/images.py | 8 +- src/draive/types/instruction.py | 86 ++++ src/draive/types/lmm.py | 110 ++++ src/draive/types/memory.py | 10 +- src/draive/types/multimodal.py | 220 ++++---- src/draive/types/tool_status.py | 19 + src/draive/types/video.py | 8 +- src/draive/utils/__init__.py | 6 +- src/draive/utils/early_exit.py | 89 ---- src/draive/utils/stream.py | 83 +-- tests/test_model.py | 116 +++-- tests/test_tool_call.py | 22 +- 76 files changed, 2534 insertions(+), 2441 deletions(-) delete mode 100644 src/draive/conversation/message.py create mode 100644 src/draive/conversation/model.py create mode 100644 src/draive/helpers/stream.py delete mode 100644 src/draive/lmm/completion.py rename src/draive/{tools => lmm}/errors.py (100%) create mode 100644 src/draive/lmm/invocation.py delete 
mode 100644 src/draive/lmm/message.py rename src/draive/{tools => lmm}/tool.py (56%) create mode 100644 src/draive/lmm/toolbox.py delete mode 100644 src/draive/mistral/chat_response.py delete mode 100644 src/draive/mistral/chat_stream.py delete mode 100644 src/draive/mistral/chat_tools.py delete mode 100644 src/draive/openai/chat_response.py delete mode 100644 src/draive/openai/chat_stream.py delete mode 100644 src/draive/openai/chat_tools.py create mode 100644 src/draive/similarity/score.py create mode 100644 src/draive/similarity/search.py delete mode 100644 src/draive/similarity/similarity.py delete mode 100644 src/draive/tools/__init__.py delete mode 100644 src/draive/tools/state.py delete mode 100644 src/draive/tools/toolbox.py delete mode 100644 src/draive/tools/update.py create mode 100644 src/draive/types/instruction.py create mode 100644 src/draive/types/lmm.py create mode 100644 src/draive/types/tool_status.py delete mode 100644 src/draive/utils/early_exit.py diff --git a/constraints b/constraints index a35ee9b..cb53be8 100644 --- a/constraints +++ b/constraints @@ -7,6 +7,7 @@ anyio==4.3.0 # httpx # openai bandit==1.7.8 + # via draive (pyproject.toml) certifi==2024.2.2 # via # httpcore @@ -23,7 +24,9 @@ h11==0.14.0 httpcore==1.0.5 # via httpx httpx==0.27.0 - # via openai + # via + # draive (pyproject.toml) + # openai idna==3.7 # via # anyio @@ -38,7 +41,9 @@ mdurl==0.1.2 nodeenv==1.8.0 # via pyright numpy==1.26.4 + # via draive (pyproject.toml) openai==1.30.1 + # via draive (pyproject.toml) packaging==24.0 # via pytest pbr==6.0.0 @@ -46,18 +51,24 @@ pbr==6.0.0 pluggy==1.5.0 # via pytest pydantic==2.7.1 - # via openai + # via + # draive (pyproject.toml) + # openai pydantic-core==2.18.2 # via pydantic pygments==2.18.0 # via rich pyright==1.1.364 + # via draive (pyproject.toml) pytest==7.4.4 # via + # draive (pyproject.toml) # pytest-asyncio # pytest-cov pytest-asyncio==0.23.7 + # via draive (pyproject.toml) pytest-cov==4.1.0 + # via draive (pyproject.toml) pyyaml==6.0.1 # via bandit regex==2024.5.15 @@ -67,6 +78,7 @@ requests==2.32.2 rich==13.7.1 # via bandit ruff==0.4.5 + # via draive (pyproject.toml) setuptools==70.0.0 # via nodeenv sniffio==1.3.1 @@ -77,6 +89,7 @@ sniffio==1.3.1 stevedore==5.2.0 # via bandit tiktoken==0.7.0 + # via draive (pyproject.toml) tqdm==4.66.4 # via openai typing-extensions==4.11.0 diff --git a/examples/BasicConversation.ipynb b/examples/BasicConversation.ipynb index ee777cf..9bfc262 100644 --- a/examples/BasicConversation.ipynb +++ b/examples/BasicConversation.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -47,10 +47,15 @@ "output_type": "stream", "text": [ "{\n", - " \"role\": \"assistant\",\n", - " \"content\": \"The current UTC time and date is Wednesday, 17 April 2024, 13:13:27.\",\n", + " \"identifier\": \"691610408b1c4a4ab1aecdf78da128cb\",\n", + " \"role\": \"model\",\n", " \"author\": null,\n", - " \"created\": \"2024-04-17T13:13:28.396044+00:00\"\n", + " \"created\": \"2024-05-23T07:46:42.001814+00:00\",\n", + " \"content\": {\n", + " \"elements\": [\n", + " \"The current UTC time and date is Thursday, 23 May 2024, 07:46:41.\"\n", + " ]\n", + " }\n", "}\n" ] } @@ -64,14 +69,14 @@ " Toolbox,\n", 
" conversation_completion,\n", " ctx,\n", - " openai_lmm_completion,\n", + " openai_lmm_invocation,\n", ")\n", "\n", "# initialize dependencies and configuration\n", "async with ctx.new(\n", " dependencies=[OpenAIClient], # use OpenAI client\n", " state=[\n", - " LMM(completion=openai_lmm_completion), # define used LMM\n", + " LMM(invocation=openai_lmm_invocation), # define used LMM\n", " OpenAIChatConfig(model=\"gpt-3.5-turbo-0125\"), # configure OpenAI model\n", " ],\n", "):\n", diff --git a/examples/BasicModelGeneration.ipynb b/examples/BasicModelGeneration.ipynb index 05066f6..102ba29 100644 --- a/examples/BasicModelGeneration.ipynb +++ b/examples/BasicModelGeneration.ipynb @@ -58,14 +58,14 @@ " OpenAIClient,\n", " ctx,\n", " generate_model,\n", - " openai_lmm_completion,\n", + " openai_lmm_invocation,\n", ")\n", "\n", "# initialize dependencies and configuration\n", "async with ctx.new(\n", " dependencies=[OpenAIClient], # use OpenAI client\n", " state=[\n", - " LMM(completion=openai_lmm_completion), # define used LMM\n", + " LMM(invocation=openai_lmm_invocation), # define used LMM\n", " OpenAIChatConfig(model=\"gpt-3.5-turbo-0125\"), # configure OpenAI model\n", " ],\n", "):\n", @@ -98,7 +98,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 7e732fc..32a373f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "draive" -version = "0.12.1" +version = "0.13.0" readme = "README.md" maintainers = [ {name = "Kacper KaliƄski", email = "kacper.kalinski@miquido.com"} diff --git a/src/draive/__init__.py b/src/draive/__init__.py index dbf524d..833fc90 100644 --- a/src/draive/__init__.py +++ b/src/draive/__init__.py @@ -10,8 +10,9 @@ from draive.conversation import ( Conversation, ConversationCompletion, - ConversationCompletionStream, ConversationMessage, + ConversationMessageChunk, + ConversationResponseStream, conversation_completion, lmm_conversation_completion, ) @@ -29,6 +30,7 @@ ) from draive.helpers import ( MISSING, + AsyncStream, Missing, freeze, getenv_bool, @@ -44,11 +46,13 @@ ) from draive.lmm import ( LMM, - LMMCompletion, - LMMCompletionStream, - LMMMessage, - LMMStreamingUpdate, - lmm_completion, + Tool, + Toolbox, + ToolCallContext, + ToolException, + ToolStatusStream, + lmm_invocation, + tool, ) from draive.metrics import ( Metric, @@ -64,7 +68,7 @@ MistralEmbeddingConfig, MistralException, mistral_embed_text, - mistral_lmm_completion, + mistral_lmm_invocation, ) from draive.openai import ( OpenAIChatConfig, @@ -74,7 +78,7 @@ OpenAIImageGenerationConfig, openai_embed_text, openai_generate_image, - openai_lmm_completion, + openai_lmm_invocation, openai_tokenize_text, ) from draive.parameters import Argument, Field, ParameterPath @@ -84,47 +88,44 @@ ScopeState, ctx, ) -from draive.similarity import mmr_similarity, similarity +from draive.similarity import mmr_similarity_search, similarity_score, similarity_search from draive.splitters import split_text from draive.tokenization import TextTokenizer, Tokenization, count_text_tokens, tokenize_text -from draive.tools import ( - Tool, - Toolbox, - ToolCallContext, - ToolCallStatus, - ToolCallUpdate, - ToolException, - ToolsUpdatesContext, - tool, -) from draive.types import ( AudioBase64Content, AudioContent, + AudioDataContent, AudioURLContent, ImageBase64Content, ImageContent, + ImageDataContent, ImageURLContent, + 
Instruction, + LMMCompletion, + LMMCompletionChunk, + LMMContextElement, + LMMInput, + LMMInstruction, + LMMOutputStream, + LMMOutputStreamChunk, + LMMToolRequest, + LMMToolResponse, Memory, Model, MultimodalContent, ReadOnlyMemory, State, + ToolCallStatus, VideoBase64Content, VideoContent, + VideoDataContent, VideoURLContent, - has_media, - is_multimodal_content, - merge_multimodal_content, - multimodal_content_string, ) from draive.utils import ( - AsyncStream, AsyncStreamTask, - allowing_early_exit, auto_retry, cache, traced, - with_early_exit, ) __all__ = [ @@ -134,7 +135,6 @@ "AgentFlow", "AgentScratchpad", "AgentState", - "allowing_early_exit", "Argument", "AsyncStream", "AsyncStream", @@ -142,6 +142,7 @@ "AsyncStreamTask", "AudioBase64Content", "AudioContent", + "AudioDataContent", "AudioURLContent", "auto_retry", "BaseAgent", @@ -151,7 +152,8 @@ "Conversation", "Conversation", "ConversationCompletion", - "ConversationCompletionStream", + "ConversationMessageChunk", + "ConversationResponseStream", "ConversationMessage", "count_text_tokens", "ctx", @@ -168,24 +170,28 @@ "getenv_float", "getenv_int", "getenv_str", - "has_media", "ImageBase64Content", "ImageContent", + "ImageDataContent", "ImageGeneration", "ImageGenerator", "ImageURLContent", + "Instruction", "is_missing", - "is_multimodal_content", - "lmm_completion", + "lmm_invocation", "lmm_conversation_completion", "LMM", "LMMCompletion", - "LMMMessage", - "LMMCompletionStream", - "LMMStreamingUpdate", + "LMMCompletionChunk", + "LMMContextElement", + "LMMInput", + "LMMInstruction", + "LMMOutputStream", + "LMMOutputStreamChunk", + "LMMToolRequest", + "LMMToolResponse", "load_env", "Memory", - "merge_multimodal_content", "Metric", "metrics_log_reporter", "MetricsTrace", @@ -194,21 +200,20 @@ "Missing", "MISSING", "mistral_embed_text", - "mistral_lmm_completion", + "mistral_lmm_invocation", "MistralChatConfig", "MistralClient", "MistralEmbeddingConfig", "MistralException", - "mmr_similarity", + "mmr_similarity_search", "Model", "ModelGeneration", "ModelGenerator", - "multimodal_content_string", "MultimodalContent", "not_missing", "openai_embed_text", "openai_generate_image", - "openai_lmm_completion", + "openai_lmm_invocation", "openai_tokenize_text", "OpenAIChatConfig", "OpenAIClient", @@ -221,7 +226,8 @@ "ScopeDependency", "ScopeState", "setup_logging", - "similarity", + "similarity_score", + "similarity_search", "split_sequence", "split_text", "State", @@ -237,14 +243,14 @@ "Toolbox", "ToolCallContext", "ToolCallStatus", - "ToolCallUpdate", + "ToolCallStatus", "ToolException", "ToolException", - "ToolsUpdatesContext", + "ToolStatusStream", "traced", "VideoBase64Content", "VideoContent", + "VideoDataContent", "VideoURLContent", "when_missing", - "with_early_exit", ] diff --git a/src/draive/agents/flow.py b/src/draive/agents/flow.py index 8c4a75c..14a792d 100644 --- a/src/draive/agents/flow.py +++ b/src/draive/agents/flow.py @@ -7,7 +7,7 @@ from draive.helpers import freeze from draive.parameters import ParametrizedData from draive.scope import ctx -from draive.types import MultimodalContent, merge_multimodal_content +from draive.types import MultimodalContent __all__ = [ "AgentFlow", @@ -42,7 +42,7 @@ async def __call__( with ctx.updated(current_scratchpad): match agent: case [*agents]: - merged_note: MultimodalContent = merge_multimodal_content( + merged_note: MultimodalContent = MultimodalContent.of( *[ scratchpad_note for scratchpad_note in await gather( @@ -61,4 +61,4 @@ async def __call__( current_scratchpad = 
current_scratchpad.extended(scratchpad_note) scratchpad_notes.append(scratchpad_note) - return merge_multimodal_content(*scratchpad_notes) + return MultimodalContent.of(*scratchpad_notes) diff --git a/src/draive/agents/state.py b/src/draive/agents/state.py index d690175..1d2b0d2 100644 --- a/src/draive/agents/state.py +++ b/src/draive/agents/state.py @@ -4,7 +4,7 @@ from draive.parameters import ParametrizedData from draive.scope import ctx -from draive.types import MultimodalContent, MultimodalContentItem, State +from draive.types import MultimodalContent, State __all__ = [ "AgentState", @@ -24,13 +24,12 @@ def prepare( ) -> Self: match content: case None: - return cls(content=()) - case [*items]: - return cls(content=tuple(items)) + return cls(content=MultimodalContent.of()) + case item: - return cls(content=(item,)) + return cls(content=item) - content: tuple[MultimodalContentItem, ...] = () + content: MultimodalContent = MultimodalContent.of() def extended( self, @@ -39,10 +38,8 @@ def extended( match content: case None: return self - case [*items]: - return self.__class__(content=(*self.content, *items)) case item: - return self.__class__(content=(*self.content, item)) + return self.__class__(content=MultimodalContent.of(*self.content, item)) class AgentState[State: ParametrizedData]: diff --git a/src/draive/conversation/__init__.py b/src/draive/conversation/__init__.py index 6bf87a6..2258a46 100644 --- a/src/draive/conversation/__init__.py +++ b/src/draive/conversation/__init__.py @@ -1,14 +1,19 @@ from draive.conversation.call import conversation_completion -from draive.conversation.completion import ConversationCompletion, ConversationCompletionStream +from draive.conversation.completion import ConversationCompletion from draive.conversation.lmm import lmm_conversation_completion -from draive.conversation.message import ConversationMessage +from draive.conversation.model import ( + ConversationMessage, + ConversationMessageChunk, + ConversationResponseStream, +) from draive.conversation.state import Conversation __all__ = [ "conversation_completion", "Conversation", "ConversationCompletion", - "ConversationCompletionStream", + "ConversationMessageChunk", + "ConversationResponseStream", "ConversationMessage", "lmm_conversation_completion", ] diff --git a/src/draive/conversation/call.py b/src/draive/conversation/call.py index 5b155c8..e609b49 100644 --- a/src/draive/conversation/call.py +++ b/src/draive/conversation/call.py @@ -1,12 +1,11 @@ -from collections.abc import Callable +from collections.abc import Sequence from typing import Literal, overload -from draive.conversation.completion import ConversationCompletionStream -from draive.conversation.message import ConversationMessage, ConversationStreamingUpdate +from draive.conversation.model import ConversationMessage, ConversationResponseStream from draive.conversation.state import Conversation +from draive.lmm import AnyTool, Toolbox from draive.scope import ctx -from draive.tools import Toolbox -from draive.types import Memory, MultimodalContent +from draive.types import Instruction, Memory, MultimodalContent, MultimodalContentElement __all__ = [ "conversation_completion", @@ -16,66 +15,49 @@ @overload async def conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + 
memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, stream: Literal[True], -) -> ConversationCompletionStream: ... +) -> ConversationResponseStream: ... @overload async def conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, - stream: Callable[[ConversationStreamingUpdate], None], + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: Literal[False] = False, ) -> ConversationMessage: ... @overload async def conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, -) -> ConversationMessage: ... + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: bool, +) -> ConversationResponseStream | ConversationMessage: ... async def conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, - stream: Callable[[ConversationStreamingUpdate], None] | bool = False, -) -> ConversationCompletionStream | ConversationMessage: + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: bool = False, +) -> ConversationResponseStream | ConversationMessage: conversation: Conversation = ctx.state(Conversation) - - match stream: - case False: - return await conversation.completion( - instruction=instruction, - input=input, - memory=memory or conversation.memory, - tools=tools or conversation.tools, - ) - case True: - return await conversation.completion( - instruction=instruction, - input=input, - memory=memory or conversation.memory, - tools=tools or conversation.tools, - stream=True, - ) - case progress: - return await conversation.completion( - instruction=instruction, - input=input, - memory=memory or conversation.memory, - tools=tools or conversation.tools, - stream=progress, - ) + return await conversation.completion( + instruction=instruction, + input=input, + memory=memory or conversation.memory, + tools=tools, + stream=stream, + ) diff --git a/src/draive/conversation/completion.py b/src/draive/conversation/completion.py index 74ebc65..0953b1b 100644 --- a/src/draive/conversation/completion.py +++ b/src/draive/conversation/completion.py @@ -1,63 +1,65 @@ -from collections.abc import Callable -from typing import Literal, Protocol, overload, runtime_checkable +from collections.abc import Sequence +from typing import Any, Literal, Protocol, overload, runtime_checkable -from draive.conversation.message import ( - ConversationMessage, - ConversationStreamingUpdate, +from draive.conversation.model import ConversationMessage, ConversationResponseStream +from draive.lmm import AnyTool, 
Toolbox +from draive.types import ( + Instruction, + Memory, + MultimodalContent, + MultimodalContentElement, ) -from draive.lmm import LMMCompletionStream -from draive.tools import Toolbox -from draive.types import Memory, MultimodalContent __all__ = [ - "ConversationCompletionStream", "ConversationCompletion", ] -ConversationCompletionStream = LMMCompletionStream - - @runtime_checkable class ConversationCompletion(Protocol): @overload async def __call__( self, *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, stream: Literal[True], - ) -> ConversationCompletionStream: ... + **extra: Any, + ) -> ConversationResponseStream: ... @overload async def __call__( self, *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, - stream: Callable[[ConversationStreamingUpdate], None], + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: Literal[False] = False, + **extra: Any, ) -> ConversationMessage: ... @overload async def __call__( self, *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, - ) -> ConversationMessage: ... + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: bool, + **extra: Any, + ) -> ConversationResponseStream | ConversationMessage: ... async def __call__( # noqa: PLR0913 self, *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, - stream: Callable[[ConversationStreamingUpdate], None] | bool = False, - ) -> ConversationCompletionStream | ConversationMessage: ... + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: bool = False, + **extra: Any, + ) -> ConversationResponseStream | ConversationMessage: ... 
diff --git a/src/draive/conversation/lmm.py b/src/draive/conversation/lmm.py index 064ab12..24f85c1 100644 --- a/src/draive/conversation/lmm.py +++ b/src/draive/conversation/lmm.py @@ -1,16 +1,34 @@ -from collections.abc import Callable +from collections.abc import AsyncGenerator, Sequence from datetime import UTC, datetime -from typing import Literal, overload +from typing import Any, Literal, overload +from uuid import uuid4 -from draive.conversation.completion import ConversationCompletionStream -from draive.conversation.message import ( +from draive.conversation.model import ( ConversationMessage, - ConversationStreamingUpdate, + ConversationMessageChunk, + ConversationResponseStream, ) -from draive.lmm import LMMMessage, lmm_completion -from draive.tools import Toolbox -from draive.types import Memory, MultimodalContent -from draive.utils import AsyncStreamTask +from draive.lmm import ( + AnyTool, + Toolbox, + lmm_invocation, +) +from draive.scope import ctx +from draive.types import ( + Instruction, + LMMCompletion, + LMMCompletionChunk, + LMMContextElement, + LMMInput, + LMMInstruction, + LMMToolRequests, + LMMToolResponse, + Memory, + MultimodalContent, + MultimodalContentElement, + ReadOnlyMemory, +) +from draive.types.tool_status import ToolCallStatus __all__: list[str] = [ "lmm_conversation_completion", @@ -20,137 +38,242 @@ @overload async def lmm_conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, stream: Literal[True], -) -> ConversationCompletionStream: ... + **extra: Any, +) -> ConversationResponseStream: ... @overload async def lmm_conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, - stream: Callable[[ConversationStreamingUpdate], None], + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: Literal[False] = False, + **extra: Any, ) -> ConversationMessage: ... @overload async def lmm_conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, -) -> ConversationMessage: ... + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: bool, + **extra: Any, +) -> ConversationResponseStream | ConversationMessage: ... 
async def lmm_conversation_completion( *, - instruction: str, - input: ConversationMessage | MultimodalContent, # noqa: A002 - memory: Memory[ConversationMessage] | None = None, - tools: Toolbox | None = None, - stream: Callable[[ConversationStreamingUpdate], None] | bool = False, -) -> ConversationCompletionStream | ConversationMessage: - system_message: LMMMessage = LMMMessage( - role="system", - content=instruction, - ) - user_message: ConversationMessage - if isinstance(input, ConversationMessage): - user_message = input + instruction: Instruction | str, + input: ConversationMessage | MultimodalContent | MultimodalContentElement, # noqa: A002 + memory: Memory[ConversationMessage] | Sequence[ConversationMessage] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + stream: bool = False, + **extra: Any, +) -> ConversationResponseStream | ConversationMessage: + with ctx.nested( + "lmm_conversation_completion", + ): + toolbox: Toolbox + match tools: + case None: + toolbox = Toolbox() - else: - user_message = ConversationMessage( - created=datetime.now(UTC), - role="user", - content=input, - ) + case Toolbox() as tools: + toolbox = tools - context: list[LMMMessage] + case [*tools]: + toolbox = Toolbox(*tools) - if memory: - context = [ - system_message, - *await memory.recall(), - user_message, + context: list[LMMContextElement] = [ + LMMInstruction.of(instruction), ] - else: - context = [ - system_message, - user_message, - ] + conversation_memory: Memory[ConversationMessage] + match memory: + case None: + conversation_memory = ReadOnlyMemory() - match stream: - case True: + case Memory() as memory: + context.extend( + message.as_lmm_context_element() for message in await memory.recall() + ) + conversation_memory = memory - async def stream_task( - update: Callable[[ConversationStreamingUpdate], None], - ) -> None: - nonlocal memory - completion: LMMMessage = await lmm_completion( - context=context, - tools=tools, - stream=update, + case [*memory_messages]: + context.extend(message.as_lmm_context_element() for message in memory_messages) + conversation_memory = ReadOnlyMemory(elements=memory_messages) + + request_message: ConversationMessage + match input: + case ConversationMessage() as message: + context.append(LMMInput.of(message.content)) + request_message = message + + case content: + context.append(LMMInput.of(content)) + request_message = ConversationMessage( + role="user", + created=datetime.now(UTC), + content=MultimodalContent.of(content), ) + + if stream: + return ctx.stream( + generator=_lmm_conversation_completion_stream( + request_message=request_message, + conversation_memory=conversation_memory, + context=context, + toolbox=toolbox, + **extra, + ), + ) + else: + return await _lmm_conversation_completion( + request_message=request_message, + conversation_memory=conversation_memory, + context=context, + toolbox=toolbox, + **extra, + ) + + +async def _lmm_conversation_completion( + request_message: ConversationMessage, + conversation_memory: Memory[ConversationMessage], + context: list[LMMContextElement], + toolbox: Toolbox, + **extra: Any, +) -> ConversationMessage: + for recursion_level in toolbox.call_range: + match await lmm_invocation( + context=context, + tools=toolbox.available_tools(recursion_level=recursion_level), + require_tool=toolbox.tool_suggestion(recursion_level=recursion_level), + output="text", + stream=False, + **extra, + ): + case LMMCompletion() as completion: response_message: ConversationMessage = ConversationMessage( + role="model", 
created=datetime.now(UTC), - role=completion.role, content=completion.content, ) - if memory := memory: - await memory.remember( - [ - user_message, - response_message, - ], - ) + await conversation_memory.remember( + request_message, + response_message, + ) + return response_message - return AsyncStreamTask(job=stream_task) + case LMMToolRequests() as tool_requests: + context.append(tool_requests) + responses: list[LMMToolResponse] = await toolbox.respond(tool_requests) - case False: - completion: LMMMessage = await lmm_completion( - context=context, - tools=tools, - ) - response_message: ConversationMessage = ConversationMessage( - created=datetime.now(UTC), - role=completion.role, - content=completion.content, - ) - if memory := memory: - await memory.remember( - [ - user_message, + if direct_content := [ + response.content for response in responses if response.direct + ]: + response_message: ConversationMessage = ConversationMessage( + role="model", + created=datetime.now(UTC), + content=MultimodalContent.of(*direct_content), + ) + await conversation_memory.remember( + request_message, response_message, - ], - ) + ) + return response_message - return response_message + else: + context.extend(responses) - case update: - completion: LMMMessage = await lmm_completion( - context=context, - tools=tools, - stream=update, - ) - response_message: ConversationMessage = ConversationMessage( - created=datetime.now(UTC), - role=completion.role, - content=completion.content, - ) - if memory := memory: - await memory.remember( - [ - user_message, - response_message, - ], - ) + # fail if we have not provided a result until this point + raise RuntimeError("Failed to produce conversation completion") + + +async def _lmm_conversation_completion_stream( + request_message: ConversationMessage, + conversation_memory: Memory[ConversationMessage], + context: list[LMMContextElement], + toolbox: Toolbox, + **extra: Any, +) -> AsyncGenerator[ConversationMessageChunk | ToolCallStatus, None]: + response_identifier: str = uuid4().hex + response_content: MultimodalContent = MultimodalContent.of() # empty - return response_message + for recursion_level in toolbox.call_range: + async for part in await lmm_invocation( + context=context, + tools=toolbox.available_tools(recursion_level=recursion_level), + require_tool=toolbox.tool_suggestion(recursion_level=recursion_level), + output="text", + stream=True, + **extra, + ): + match part: + case LMMCompletionChunk() as chunk: + response_content = response_content.extending(chunk.content) + + yield ConversationMessageChunk( + identifier=response_identifier, + content=chunk.content, + ) + # keep yielding parts + + case LMMToolRequests() as tool_requests: + assert ( # nosec: B101 + not response_content + ), "Tools and completion message should not be used at the same time" + + responses: list[LMMToolResponse] = [] + async for update in toolbox.stream(tool_requests): + match update: + case LMMToolResponse() as response: + responses.append(response) + + case ToolCallStatus() as status: + yield status + + assert len(responses) == len( # nosec: B101 + tool_requests.requests + ), "Tool responses count should match requests count" + + if direct_content := [ + response.content for response in responses if response.direct + ]: + response_content = MultimodalContent.of(*direct_content) + yield ConversationMessageChunk( + identifier=response_identifier, + content=response_content, + ) + # exit the loop - we have final result + + else: + context.extend([tool_requests, *responses]) + 
break # request lmm again with tool results using outer loop + else: + break # exit the loop with result + + if response_content: + # remember messages when finishing stream + await conversation_memory.remember( + request_message, + ConversationMessage( + identifier=response_identifier, + role="model", + created=datetime.now(UTC), + content=response_content.joining_texts(joiner=""), + ), + ) + else: + # fail if we have not provided a result until this point + raise RuntimeError("Failed to produce conversation completion") diff --git a/src/draive/conversation/message.py b/src/draive/conversation/message.py deleted file mode 100644 index f98c04e..0000000 --- a/src/draive/conversation/message.py +++ /dev/null @@ -1,16 +0,0 @@ -from datetime import datetime - -from draive.lmm import LMMMessage, LMMStreamingUpdate - -__all__ = [ - "ConversationMessage", - "ConversationStreamingUpdate", -] - - -class ConversationMessage(LMMMessage): - author: str | None = None - created: datetime | None = None - - -ConversationStreamingUpdate = LMMStreamingUpdate diff --git a/src/draive/conversation/model.py b/src/draive/conversation/model.py new file mode 100644 index 0000000..0ff57fb --- /dev/null +++ b/src/draive/conversation/model.py @@ -0,0 +1,67 @@ +from collections.abc import AsyncIterator +from datetime import datetime +from typing import Literal, Self +from uuid import uuid4 + +from draive.parameters import Field +from draive.types import ( + LMMCompletion, + LMMContextElement, + LMMInput, + Model, + MultimodalContent, + MultimodalContentElement, + ToolCallStatus, +) + +__all__ = [ + "ConversationMessage", + "ConversationMessageChunk", + "ConversationResponseStream", +] + + +class ConversationMessage(Model): + @classmethod + def user( + cls, + content: MultimodalContent | MultimodalContentElement, + identifier: str | None = None, + author: str | None = None, + created: datetime | None = None, + ) -> Self: + return cls( + identifier=identifier or uuid4().hex, + role="user", + author=author, + created=created, + content=MultimodalContent.of(content), + ) + + identifier: str = Field(default_factory=lambda: uuid4().hex) + role: Literal["user", "model"] + author: str | None = None + created: datetime | None = None + content: MultimodalContent + + def as_lmm_context_element(self) -> LMMContextElement: + match self.role: + case "user": + return LMMInput.of(self.content) + + case "model": + return LMMCompletion.of(self.content) + + def __bool__(self) -> bool: + return bool(self.content) + + +class ConversationMessageChunk(Model): + identifier: str + content: MultimodalContent + + def __bool__(self) -> bool: + return bool(self.content) + + +ConversationResponseStream = AsyncIterator[ConversationMessageChunk | ToolCallStatus] diff --git a/src/draive/conversation/state.py b/src/draive/conversation/state.py index bd6ca1a..cc543ce 100644 --- a/src/draive/conversation/state.py +++ b/src/draive/conversation/state.py @@ -1,9 +1,6 @@ from draive.conversation.completion import ConversationCompletion from draive.conversation.lmm import lmm_conversation_completion -from draive.conversation.message import ( - ConversationMessage, -) -from draive.tools import Toolbox +from draive.conversation.model import ConversationMessage from draive.types import Memory, State __all__: list[str] = [ @@ -14,4 +11,3 @@ class Conversation(State): completion: ConversationCompletion = lmm_conversation_completion memory: Memory[ConversationMessage] | None = None - tools: Toolbox | None = None diff --git a/src/draive/generation/image/call.py 
b/src/draive/generation/image/call.py index 7e7614c..637a72d 100644 --- a/src/draive/generation/image/call.py +++ b/src/draive/generation/image/call.py @@ -2,7 +2,7 @@ from draive.generation.image.state import ImageGeneration from draive.scope import ctx -from draive.types import ImageContent +from draive.types import ImageContent, Instruction, MultimodalContent, MultimodalContentElement __all__ = [ "generate_image", @@ -11,10 +11,12 @@ async def generate_image( *, - instruction: str, + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement | None = None, # noqa: A002 **extra: Any, ) -> ImageContent: return await ctx.state(ImageGeneration).generate( instruction=instruction, + input=input, **extra, ) diff --git a/src/draive/generation/image/generator.py b/src/draive/generation/image/generator.py index 0610238..b77e192 100644 --- a/src/draive/generation/image/generator.py +++ b/src/draive/generation/image/generator.py @@ -1,6 +1,6 @@ from typing import Any, Protocol, runtime_checkable -from draive.types import ImageContent +from draive.types import ImageContent, Instruction, MultimodalContent, MultimodalContentElement __all__ = [ "ImageGenerator", @@ -12,6 +12,7 @@ class ImageGenerator(Protocol): async def __call__( self, *, - instruction: str, + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement | None = None, # noqa: A002 **extra: Any, ) -> ImageContent: ... diff --git a/src/draive/generation/model/call.py b/src/draive/generation/model/call.py index 540b75e..f6e3fdb 100644 --- a/src/draive/generation/model/call.py +++ b/src/draive/generation/model/call.py @@ -1,10 +1,10 @@ -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import Any from draive.generation.model.state import ModelGeneration +from draive.lmm import AnyTool, Toolbox from draive.scope import ctx -from draive.tools import Toolbox -from draive.types import Model, MultimodalContent +from draive.types import Instruction, Model, MultimodalContent, MultimodalContentElement __all__ = [ "generate_model", @@ -15,20 +15,20 @@ async def generate_model[Generated: Model]( # noqa: PLR0913 generated: type[Generated], /, *, - instruction: str, - input: MultimodalContent, # noqa: A002 + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement, # noqa: A002 schema_variable: str | None = None, - tools: Toolbox | None = None, - examples: Iterable[tuple[MultimodalContent, Generated]] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + examples: Iterable[tuple[MultimodalContent | MultimodalContentElement, Generated]] + | None = None, **extra: Any, ) -> Generated: - model_generation: ModelGeneration = ctx.state(ModelGeneration) - return await model_generation.generate( + return await ctx.state(ModelGeneration).generate( generated, instruction=instruction, input=input, schema_variable=schema_variable, - tools=tools or model_generation.tools, + tools=tools, examples=examples, **extra, ) diff --git a/src/draive/generation/model/generator.py b/src/draive/generation/model/generator.py index 5727b53..e75729b 100644 --- a/src/draive/generation/model/generator.py +++ b/src/draive/generation/model/generator.py @@ -1,8 +1,8 @@ -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import Any, Protocol, runtime_checkable -from draive.tools import Toolbox -from draive.types import Model, MultimodalContent +from draive.lmm import AnyTool, Toolbox +from 
draive.types import Instruction, Model, MultimodalContent, MultimodalContentElement __all__ = [ "ModelGenerator", @@ -16,10 +16,11 @@ async def __call__[Generated: Model]( # noqa: PLR0913 generated: type[Generated], /, *, - instruction: str, - input: MultimodalContent, # noqa: A002 + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement, # noqa: A002 schema_variable: str | None = None, - tools: Toolbox | None = None, - examples: Iterable[tuple[MultimodalContent, Generated]] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + examples: Iterable[tuple[MultimodalContent | MultimodalContentElement, Generated]] + | None = None, **extra: Any, ) -> Generated: ... diff --git a/src/draive/generation/model/lmm.py b/src/draive/generation/model/lmm.py index 66dfea1..f3634f6 100644 --- a/src/draive/generation/model/lmm.py +++ b/src/draive/generation/model/lmm.py @@ -1,9 +1,20 @@ -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import Any -from draive.lmm import LMMMessage, lmm_completion -from draive.tools import Toolbox -from draive.types import Model, MultimodalContent +from draive.lmm import AnyTool, Toolbox, lmm_invocation +from draive.scope import ctx +from draive.types import ( + Instruction, + LMMCompletion, + LMMContextElement, + LMMInput, + LMMInstruction, + LMMToolRequests, + LMMToolResponse, + Model, + MultimodalContent, + MultimodalContentElement, +) __all__: list[str] = [ "lmm_generate_model", @@ -14,76 +25,97 @@ async def lmm_generate_model[Generated: Model]( # noqa: PLR0913 generated: type[Generated], /, *, - instruction: str, - input: MultimodalContent, # noqa: A002 + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement, # noqa: A002 schema_variable: str | None = None, - tools: Toolbox | None = None, - examples: Iterable[tuple[MultimodalContent, Generated]] | None = None, + tools: Toolbox | Sequence[AnyTool] | None = None, + examples: Iterable[tuple[MultimodalContent | MultimodalContentElement, Generated]] + | None = None, **extra: Any, ) -> Generated: - system_message: LMMMessage - if variable := schema_variable: - system_message = LMMMessage( - role="system", - content=instruction.format(**{variable: generated.specification()}), - ) - - else: - system_message = LMMMessage( - role="system", - content=DEFAULT_INSTRUCTION.format( - instruction=instruction, - schema=generated.specification(), - ), - ) - - input_message: LMMMessage = LMMMessage( - role="user", - content=input, - ) - - context: list[LMMMessage] - - if examples: - context = [ - system_message, + with ctx.nested("lmm_generate_model"): + toolbox: Toolbox + match tools: + case None: + toolbox = Toolbox() + + case Toolbox() as tools: + toolbox = tools + + case [*tools]: + toolbox = Toolbox(*tools) + + generation_instruction: Instruction + match instruction: + case str(instruction): + generation_instruction = Instruction(instruction) + + case Instruction() as instruction: + generation_instruction = instruction + + instruction_message: LMMContextElement + if variable := schema_variable: + instruction_message = LMMInstruction.of( + generation_instruction.updated( + **{variable: generated.specification()}, + ), + ) + + else: + instruction_message = LMMInstruction.of( + generation_instruction.extended( + DEFAULT_INSTRUCTION_EXTENSION, + joiner="\n\n", + schema=generated.specification(), + ) + ) + + context: list[LMMContextElement] = [ + instruction_message, *[ message - for example in examples + for 
example in examples or [] for message in [ - LMMMessage( - role="user", - content=example[0], - ), - LMMMessage( - role="assistant", - content=example[1].as_json(indent=2), - ), + LMMInput.of(example[0]), + LMMCompletion.of(example[1].as_json(indent=2)), ] ], - input_message, + LMMInput.of(input), ] - else: - context = [ - system_message, - input_message, - ] + for recursion_level in toolbox.call_range: + match await lmm_invocation( + context=context, + tools=toolbox.available_tools(recursion_level=recursion_level), + require_tool=toolbox.tool_suggestion(recursion_level=recursion_level), + output="json", + stream=False, + **extra, + ): + case LMMCompletion() as completion: + return generated.from_json(completion.content.as_string()) + + case LMMToolRequests() as tool_requests: + context.append(tool_requests) + responses: list[LMMToolResponse] = await toolbox.respond(tool_requests) - completion: LMMMessage = await lmm_completion( - context=context, - tools=tools, - output="json", - stream=False, - **extra, - ) + if direct_responses := [response for response in responses if response.direct]: + # TODO: check if this join makes any sense, + # perhaps we could merge json objects instead? + return generated.from_json( + "".join( + response.content.as_string() for response in direct_responses + ) + ) - return generated.from_json(completion.content_string) + else: + context.extend(responses) + # fail if we have not provided a result until this point + raise RuntimeError("Failed to produce model generation") -DEFAULT_INSTRUCTION: str = """\ -{instruction} +DEFAULT_INSTRUCTION_EXTENSION: str = """\ IMPORTANT! The result have to conform to the following JSON Schema: ``` diff --git a/src/draive/generation/model/state.py b/src/draive/generation/model/state.py index 651b9a3..40974fe 100644 --- a/src/draive/generation/model/state.py +++ b/src/draive/generation/model/state.py @@ -1,6 +1,5 @@ from draive.generation.model.generator import ModelGenerator from draive.generation.model.lmm import lmm_generate_model -from draive.tools import Toolbox from draive.types import State __all__ = [ @@ -10,4 +9,3 @@ class ModelGeneration(State): generate: ModelGenerator = lmm_generate_model - tools: Toolbox | None = None diff --git a/src/draive/generation/text/call.py b/src/draive/generation/text/call.py index 75ab8c1..5aa9b53 100644 --- a/src/draive/generation/text/call.py +++ b/src/draive/generation/text/call.py @@ -2,9 +2,9 @@ from typing import Any from draive.generation.text.state import TextGeneration +from draive.lmm import Toolbox from draive.scope import ctx -from draive.tools import Toolbox -from draive.types import MultimodalContent +from draive.types import Instruction, MultimodalContent, MultimodalContentElement __all__ = [ "generate_text", @@ -13,17 +13,16 @@ async def generate_text( *, - instruction: str, - input: MultimodalContent, # noqa: A002 + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement, # noqa: A002 tools: Toolbox | None = None, - examples: Iterable[tuple[MultimodalContent, str]] | None = None, + examples: Iterable[tuple[MultimodalContent | MultimodalContentElement, str]] | None = None, **extra: Any, ) -> str: - text_generation: TextGeneration = ctx.state(TextGeneration) - return await text_generation.generate( + return await ctx.state(TextGeneration).generate( instruction=instruction, input=input, - tools=tools or text_generation.tools, + tools=tools, examples=examples, **extra, ) diff --git a/src/draive/generation/text/generator.py
b/src/draive/generation/text/generator.py index 91a9bec..b20efde 100644 --- a/src/draive/generation/text/generator.py +++ b/src/draive/generation/text/generator.py @@ -1,8 +1,8 @@ from collections.abc import Iterable from typing import Any, Protocol, runtime_checkable -from draive.tools import Toolbox -from draive.types import MultimodalContent +from draive.lmm import Toolbox +from draive.types import Instruction, MultimodalContent, MultimodalContentElement __all__ = [ "TextGenerator", @@ -14,9 +14,9 @@ class TextGenerator(Protocol): async def __call__( self, *, - instruction: str, - input: MultimodalContent, # noqa: A002 + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement, # noqa: A002 tools: Toolbox | None = None, - examples: Iterable[tuple[MultimodalContent, str]] | None = None, + examples: Iterable[tuple[MultimodalContent | MultimodalContentElement, str]] | None = None, **extra: Any, ) -> str: ... diff --git a/src/draive/generation/text/lmm.py b/src/draive/generation/text/lmm.py index 503cd08..f99d5d2 100644 --- a/src/draive/generation/text/lmm.py +++ b/src/draive/generation/text/lmm.py @@ -1,9 +1,19 @@ -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import Any -from draive.lmm import LMMMessage, lmm_completion -from draive.tools import Toolbox -from draive.types import MultimodalContent +from draive.lmm import AnyTool, Toolbox, lmm_invocation +from draive.scope import ctx +from draive.types import ( + Instruction, + LMMCompletion, + LMMContextElement, + LMMInput, + LMMInstruction, + LMMToolRequests, + LMMToolResponse, + MultimodalContent, + MultimodalContentElement, +) __all__: list[str] = [ "lmm_generate_text", @@ -12,56 +22,60 @@ async def lmm_generate_text( *, - instruction: str, - input: MultimodalContent, # noqa: A002 - tools: Toolbox | None = None, - examples: Iterable[tuple[MultimodalContent, str]] | None = None, + instruction: Instruction | str, + input: MultimodalContent | MultimodalContentElement, # noqa: A002 + tools: Toolbox | Sequence[AnyTool] | None = None, + examples: Iterable[tuple[MultimodalContent | MultimodalContentElement, str]] | None = None, **extra: Any, ) -> str: - system_message: LMMMessage = LMMMessage( - role="system", - content=instruction, - ) - input_message: LMMMessage = LMMMessage( - role="user", - content=input, - ) + with ctx.nested("lmm_generate_text"): + toolbox: Toolbox + match tools: + case None: + toolbox = Toolbox() - context: list[LMMMessage] + case Toolbox() as tools: + toolbox = tools - if examples: - context = [ - system_message, + case [*tools]: + toolbox = Toolbox(*tools) + + context: list[LMMContextElement] = [ + LMMInstruction.of(instruction), *[ message - for example in examples + for example in examples or [] for message in [ - LMMMessage( - role="user", - content=example[0], - ), - LMMMessage( - role="assistant", - content=example[1], - ), + LMMInput.of(example[0]), + LMMCompletion.of(example[1]), ] ], - input_message, + LMMInput.of(input), ] - else: - context = [ - system_message, - input_message, - ] + for recursion_level in toolbox.call_range: + match await lmm_invocation( + context=context, + tools=toolbox.available_tools(recursion_level=recursion_level), + require_tool=toolbox.tool_suggestion(recursion_level=recursion_level), + output="text", + stream=False, + **extra, + ): + case LMMCompletion() as completion: + return completion.content.as_string() + + case LMMToolRequests() as tool_requests: + context.append(tool_requests) + responses: 
list[LMMToolResponse] = await toolbox.respond(tool_requests) + + if direct_responses := [response for response in responses if response.direct]: + return MultimodalContent.of( + *[response.content for response in direct_responses] + ).as_string() - completion: LMMMessage = await lmm_completion( - context=context, - tools=tools, - output="text", - stream=False, - **extra, - ) - generated: str = completion.content_string + else: + context.extend(responses) - return generated + # fail if we have not provided a result until this point + raise RuntimeError("Failed to produce text generation") diff --git a/src/draive/generation/text/state.py b/src/draive/generation/text/state.py index 30a8f03..dee5785 100644 --- a/src/draive/generation/text/state.py +++ b/src/draive/generation/text/state.py @@ -1,6 +1,5 @@ from draive.generation.text.generator import TextGenerator from draive.generation.text.lmm import lmm_generate_text -from draive.tools import Toolbox from draive.types import State __all__ = [ @@ -10,4 +9,3 @@ class TextGeneration(State): generate: TextGenerator = lmm_generate_text - tools: Toolbox | None = None diff --git a/src/draive/helpers/__init__.py b/src/draive/helpers/__init__.py index 1bdc349..ad17119 100644 --- a/src/draive/helpers/__init__.py +++ b/src/draive/helpers/__init__.py @@ -4,8 +4,10 @@ from draive.helpers.mimic import mimic_function from draive.helpers.missing import MISSING, Missing, is_missing, not_missing, when_missing from draive.helpers.split_sequence import split_sequence +from draive.helpers.stream import AsyncStream __all__ = [ + "AsyncStream", "freeze", "getenv_bool", "getenv_float", diff --git a/src/draive/helpers/mimic.py b/src/draive/helpers/mimic.py index 351be8d..1ab449a 100644 --- a/src/draive/helpers/mimic.py +++ b/src/draive/helpers/mimic.py @@ -40,7 +40,6 @@ def mimic( "__defaults__", "__kwdefaults__", "__globals__", - "__self__", ): try: setattr( diff --git a/src/draive/helpers/stream.py b/src/draive/helpers/stream.py new file mode 100644 index 0000000..e9fad8b --- /dev/null +++ b/src/draive/helpers/stream.py @@ -0,0 +1,82 @@ +from asyncio import AbstractEventLoop, CancelledError, Future, get_running_loop +from collections import deque +from collections.abc import AsyncIterator +from typing import Self + +__all__ = [ + "AsyncStream", +] + + +class AsyncStream[Element](AsyncIterator[Element]): + def __init__( + self, + loop: AbstractEventLoop | None = None, + ) -> None: + self._loop: AbstractEventLoop = loop or get_running_loop() + self._buffer: deque[Element] = deque() + self._waiting_queue: deque[Future[Element]] = deque() + self._finish_exception: BaseException | None = None + + def __del__(self) -> None: + while self._waiting_queue: + waiting: Future[Element] = self._waiting_queue.popleft() + if waiting.done(): + continue + else: + waiting.set_exception(CancelledError()) + + @property + def finished(self) -> bool: + return self._finish_exception is not None + + def send( + self, + element: Element, + ) -> None: + if self.finished: + raise RuntimeError("AsyncStream has been already finished") + + while self._waiting_queue: + assert not self._buffer # nosec: B101 + waiting: Future[Element] = self._waiting_queue.popleft() + if waiting.done(): + continue + else: + waiting.set_result(element) + break + else: + self._buffer.append(element) + + def finish( + self, + exception: BaseException | None = None, + ) -> None: + if self.finished: + raise RuntimeError("AsyncStream has been already finished") + self._finish_exception = exception or
StopAsyncIteration() + if self._buffer: + assert not self._waiting_queue # nosec: B101 + return # allow consuming buffer to the end + while self._waiting_queue: + waiting: Future[Element] = self._waiting_queue.popleft() + if waiting.done(): + continue + else: + waiting.set_exception(self._finish_exception) + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> Element: + if self._buffer: # use buffer first + return self._buffer.popleft() + if finish_exception := self._finish_exception: # check if finished + raise finish_exception + + # create new waiting future + future: Future[Element] = self._loop.create_future() + self._waiting_queue.append(future) + + # wait for the result + return await future diff --git a/src/draive/lmm/__init__.py b/src/draive/lmm/__init__.py index f0d399b..29fe2d0 100644 --- a/src/draive/lmm/__init__.py +++ b/src/draive/lmm/__init__.py @@ -1,16 +1,19 @@ -from draive.lmm.call import lmm_completion -from draive.lmm.completion import LMMCompletion, LMMCompletionStream -from draive.lmm.message import ( - LMMMessage, - LMMStreamingUpdate, -) -from draive.lmm.state import LMM +from draive.lmm.call import lmm_invocation +from draive.lmm.errors import ToolException +from draive.lmm.invocation import LMMInvocation +from draive.lmm.state import LMM, ToolCallContext, ToolStatusStream +from draive.lmm.tool import AnyTool, Tool, tool +from draive.lmm.toolbox import Toolbox __all__ = [ - "lmm_completion", + "AnyTool", + "lmm_invocation", "LMM", - "LMMCompletion", - "LMMMessage", - "LMMCompletionStream", - "LMMStreamingUpdate", + "LMMInvocation", + "Tool", + "Toolbox", + "ToolCallContext", + "ToolException", + "ToolStatusStream", + "tool", ] diff --git a/src/draive/lmm/call.py b/src/draive/lmm/call.py index a8c0587..59deab3 100644 --- a/src/draive/lmm/call.py +++ b/src/draive/lmm/call.py @@ -1,81 +1,67 @@ -from collections.abc import Callable +from collections.abc import Sequence from typing import Any, Literal, overload -from draive.lmm.completion import LMMCompletionStream -from draive.lmm.message import ( - LMMMessage, - LMMStreamingUpdate, -) +from draive.lmm.invocation import LMMOutputStream from draive.lmm.state import LMM +from draive.parameters import ToolSpecification from draive.scope import ctx -from draive.tools import Toolbox +from draive.types import LMMContextElement, LMMOutput __all__ = [ - "lmm_completion", + "lmm_invocation", ] @overload -async def lmm_completion( +async def lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", stream: Literal[True], **extra: Any, -) -> LMMCompletionStream: ... +) -> LMMOutputStream: ... @overload -async def lmm_completion( +async def lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, - stream: Callable[[LMMStreamingUpdate], None], + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", + stream: Literal[False] = False, **extra: Any, -) -> LMMMessage: ...
@overload -async def lmm_completion( +async def lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, output: Literal["text", "json"] = "text", - stream: Literal[False] = False, + stream: bool, **extra: Any, -) -> LMMMessage: ... +) -> LMMOutputStream | LMMOutput: ... -async def lmm_completion( +async def lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, output: Literal["text", "json"] = "text", - stream: Callable[[LMMStreamingUpdate], None] | bool = False, + stream: bool = False, **extra: Any, -) -> LMMCompletionStream | LMMMessage: - match stream: - case False: - return await ctx.state(LMM).completion( - context=context, - tools=tools, - output=output, - stream=False, - **extra, - ) - case True: - return await ctx.state(LMM).completion( - context=context, - tools=tools, - output=output, - stream=True, - **extra, - ) - case progress: - return await ctx.state(LMM).completion( - context=context, - tools=tools, - output=output, - stream=progress, - **extra, - ) +) -> LMMOutputStream | LMMOutput: + return await ctx.state(LMM).invocation( + context=context, + require_tool=require_tool, + tools=tools, + output=output, + stream=stream, + **extra, + ) diff --git a/src/draive/lmm/completion.py b/src/draive/lmm/completion.py deleted file mode 100644 index 23d5164..0000000 --- a/src/draive/lmm/completion.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Callable -from typing import Any, Literal, Protocol, Self, overload, runtime_checkable - -from draive.lmm.message import ( - LMMMessage, - LMMStreamingUpdate, -) -from draive.tools import Toolbox - -__all__ = [ - "LMMCompletion", - "LMMCompletionStream", -] - - -class LMMCompletionStream(Protocol): - def __aiter__(self) -> Self: ... - - async def __anext__(self) -> LMMStreamingUpdate: ... - - -@runtime_checkable -class LMMCompletion(Protocol): - @overload - async def __call__( - self, - *, - context: list[LMMMessage], - tools: Toolbox | None, - output: Literal["text", "json"], - stream: Literal[True], - **extra: Any, - ) -> LMMCompletionStream: ... - - @overload - async def __call__( - self, - *, - context: list[LMMMessage], - tools: Toolbox | None, - output: Literal["text", "json"], - stream: Callable[[LMMStreamingUpdate], None], - **extra: Any, - ) -> LMMMessage: ... - - @overload - async def __call__( - self, - *, - context: list[LMMMessage], - tools: Toolbox | None, - output: Literal["text", "json"], - stream: Literal[False], - **extra: Any, - ) -> LMMMessage: ... - - async def __call__( - self, - *, - context: list[LMMMessage], - tools: Toolbox | None = None, - output: Literal["text", "json"] = "text", - stream: Callable[[LMMStreamingUpdate], None] | bool = False, - **extra: Any, - ) -> LMMCompletionStream | LMMMessage: ... 
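
For orientation, the pieces above are meant to compose roughly as shown in the sketch below. This is only an illustrative example and not code from this patch: the helper name invoke_with_tools is invented, and the exact context bookkeeping (for instance, whether the tool requests themselves are appended back into the context alongside the responses) may differ from the conversation/lmm.py implementation earlier in this change.

from typing import Any

from draive.lmm import Toolbox, lmm_invocation
from draive.types import (
    LMMCompletion,
    LMMContextElement,
    LMMToolRequests,
    MultimodalContent,
)


async def invoke_with_tools(  # hypothetical helper, not part of this patch
    context: list[LMMContextElement],
    toolbox: Toolbox,
    **extra: Any,
) -> str:
    # loop until the model answers or the toolbox recursion limit is exhausted
    for recursion_level in toolbox.call_range:
        output = await lmm_invocation(
            context=context,
            tools=toolbox.available_tools(recursion_level),
            require_tool=toolbox.tool_suggestion(recursion_level),
            output="text",
            stream=False,
            **extra,
        )

        match output:
            case LMMCompletion() as completion:
                return completion.content.as_string()

            case LMMToolRequests() as tool_requests:
                responses = await toolbox.respond(tool_requests)
                # tools marked as direct short-circuit further model calls
                if direct := [r.content for r in responses if r.direct]:
                    return MultimodalContent.of(*direct).as_string()

                # otherwise feed the requests and responses back into the context
                context = [*context, tool_requests, *responses]

    raise RuntimeError("Failed to produce an LMM completion")
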
diff --git a/src/draive/tools/errors.py b/src/draive/lmm/errors.py similarity index 100% rename from src/draive/tools/errors.py rename to src/draive/lmm/errors.py diff --git a/src/draive/lmm/invocation.py b/src/draive/lmm/invocation.py new file mode 100644 index 0000000..befa613 --- /dev/null +++ b/src/draive/lmm/invocation.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +from typing import ( + Any, + Literal, + Protocol, + overload, + runtime_checkable, +) + +from draive.parameters import ToolSpecification +from draive.types import LMMContextElement, LMMOutput, LMMOutputStream + +__all__ = [ + "LMMInvocation", +] + + +@runtime_checkable +class LMMInvocation(Protocol): + @overload + async def __call__( + self, + *, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", + stream: Literal[True], + **extra: Any, + ) -> LMMOutputStream: ... + + @overload + async def __call__( + self, + *, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", + stream: Literal[False] = False, + **extra: Any, + ) -> LMMOutput: ... + + @overload + async def __call__( + self, + *, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", + stream: bool, + **extra: Any, + ) -> LMMOutputStream | LMMOutput: ... + + async def __call__( # noqa: PLR0913 + self, + *, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", + stream: bool = False, + **extra: Any, + ) -> LMMOutputStream | LMMOutput: ... 
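
Any async callable matching the final (non-overload) signature above satisfies the LMMInvocation protocol structurally, so custom backends can sit alongside the OpenAI and Mistral implementations reworked later in this patch. Below is a minimal sketch under the assumption that such a function is registered through the LMM state (e.g. LMM(invocation=...)); the name echo_lmm_invocation is made up and the backend does nothing useful beyond illustrating the expected shape.

from collections.abc import Sequence
from typing import Any, Literal

from draive.parameters import ToolSpecification
from draive.types import LMMCompletion, LMMContextElement, LMMOutput, LMMOutputStream


async def echo_lmm_invocation(
    *,
    context: Sequence[LMMContextElement],
    tools: Sequence[ToolSpecification] | None = None,
    require_tool: ToolSpecification | bool = False,
    output: Literal["text", "json"] = "text",
    stream: bool = False,
    **extra: Any,
) -> LMMOutputStream | LMMOutput:
    # toy backend: ignore tools and echo the last context element back
    if stream:
        raise NotImplementedError("streaming is not supported by this toy backend")

    return LMMCompletion.of(str(context[-1]))
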
diff --git a/src/draive/lmm/message.py b/src/draive/lmm/message.py deleted file mode 100644 index deae97a..0000000 --- a/src/draive/lmm/message.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Literal - -from draive.tools import ToolCallUpdate -from draive.types import ( - Model, - MultimodalContent, - has_media, - multimodal_content_string, -) - -__all__ = [ - "LMMMessage", - "LMMStreamingUpdate", -] - - -class LMMMessage(Model): - role: Literal["system", "assistant", "user"] - content: MultimodalContent - - @property - def has_media(self) -> bool: - return has_media(self.content) - - @property - def content_string(self) -> str: - return multimodal_content_string(self.content) - - -LMMStreamingUpdate = LMMMessage | ToolCallUpdate diff --git a/src/draive/lmm/state.py b/src/draive/lmm/state.py index 34ea365..c9de9ef 100644 --- a/src/draive/lmm/state.py +++ b/src/draive/lmm/state.py @@ -1,10 +1,63 @@ -from draive.lmm.completion import LMMCompletion -from draive.types import State +from collections.abc import Callable +from typing import Literal + +from draive.lmm.invocation import LMMInvocation +from draive.types import Model, State, ToolCallStatus __all__: list[str] = [ "LMM", + "ToolCallContext", + "ToolStatusStream", ] class LMM(State): - completion: LMMCompletion + invocation: LMMInvocation + + +class ToolStatusStream(State): + send: Callable[[ToolCallStatus], None] | None = None + + +class ToolCallContext(State): + call_id: str + tool: str + send_status: Callable[[ToolCallStatus], None] + + def report( + self, + status: Literal[ + "STARTED", + "RUNNING", + "FINISHED", + "FAILED", + ], + /, + content: Model | dict[str, object] | None = None, + ) -> None: + call_status: ToolCallStatus + match content: + case None: + call_status = ToolCallStatus( + identifier=self.call_id, + tool=self.tool, + status=status, + ) + + case Model() as model: + call_status = ToolCallStatus( + identifier=self.call_id, + tool=self.tool, + status=status, + content=model.as_dict(), + ) + + case content: + call_status = ToolCallStatus( + identifier=self.call_id, + tool=self.tool, + status=status, + content=content, + ) + + self.send_status(call_status) diff --git a/src/draive/tools/tool.py b/src/draive/lmm/tool.py similarity index 56% rename from src/draive/tools/tool.py rename to src/draive/lmm/tool.py index 51fc3b2..8175d6a 100644 --- a/src/draive/tools/tool.py +++ b/src/draive/lmm/tool.py @@ -2,30 +2,28 @@ from typing import ( Any, Protocol, - cast, final, overload, ) from uuid import uuid4 from draive.helpers import freeze +from draive.lmm.state import ToolCallContext, ToolStatusStream from draive.metrics import ArgumentsTrace, ResultTrace from draive.parameters import Function, ParametrizedTool from draive.scope import ctx -from draive.tools.errors import ToolException -from draive.tools.state import ToolCallContext, ToolsUpdatesContext -from draive.tools.update import ToolCallUpdate -from draive.types import MultimodalContent, MultimodalContentItem +from draive.types import MultimodalContent, MultimodalContentElement __all__ = [ + "AnyTool", "tool", "Tool", - "ToolAvailability", + "ToolAvailabilityCheck", ] -class ToolAvailability(Protocol): - def __call__(self) -> bool: ... +class ToolAvailabilityCheck(Protocol): + def __call__(self) -> None: ... 
@final @@ -37,121 +35,103 @@ def __init__( # noqa: PLR0913 *, function: Function[Args, Coroutine[None, None, Result]], description: str | None = None, - availability: ToolAvailability | None = None, - format_result: Callable[[Result], MultimodalContent], - require_direct_result: bool = False, + availability_check: ToolAvailabilityCheck | None = None, + format_result: Callable[[Result], MultimodalContent | MultimodalContentElement], + format_failure: Callable[[Exception], MultimodalContent | MultimodalContentElement], + direct_result: bool = False, ) -> None: super().__init__( name=name, function=function, description=description, ) - self._require_direct_result: bool = require_direct_result - self._availability: ToolAvailability = availability or ( - lambda: True # available by default + self._direct_result: bool = direct_result + self._check_availability: ToolAvailabilityCheck = availability_check or ( + lambda: None # available by default + ) + self.format_result: Callable[[Result], MultimodalContent | MultimodalContentElement] = ( + format_result + ) + self.format_failure: Callable[[Exception], MultimodalContent | MultimodalContentElement] = ( + format_failure ) - self.format_result: Callable[[Result], MultimodalContent] = format_result freeze(self) @property def available(self) -> bool: - return self._availability() + try: + self._check_availability() + return True + + except Exception: + return False @property def requires_direct_result(self) -> bool: - return self._require_direct_result + return self._direct_result - async def call( + # call from toolbox + async def _toolbox_call( self, call_id: str, /, - *args: Args.args, - **kwargs: Args.kwargs, + arguments: dict[str, Any], ) -> MultimodalContent: - return self.format_result( - await self._wrapped_call( - call_id, - *args, - **kwargs, - ) - ) - - async def __call__( - self, - *args: Args.args, - **kwargs: Args.kwargs, - ) -> Result: - return await self._wrapped_call( - uuid4().hex, - *args, - **kwargs, - ) - - async def _wrapped_call( - self, - call_id: str, - /, - *args: Args.args, - **kwargs: Args.kwargs, - ) -> Result: call_context: ToolCallContext = ToolCallContext( call_id=call_id, tool=self.name, + send_status=ctx.state(ToolStatusStream).send or (lambda _: None), ) - send_update: Callable[[ToolCallUpdate], None] = ctx.state( - ToolsUpdatesContext - ).send_update or (lambda _: None) - with ctx.nested( self.name, state=[call_context], - metrics=[ArgumentsTrace.of(*args, call_id=call_context.call_id, **kwargs)], + metrics=[ArgumentsTrace.of(**arguments)], ): - try: - send_update( # notify on start - ToolCallUpdate( - call_id=call_context.call_id, - tool=call_context.tool, - status="STARTED", - content=None, - ) - ) - if not self.available: - raise ToolException("Attempting to use unavailable tool", self.name) - - result: Result = await super().__call__( - *args, - **kwargs, - ) + call_context.report("STARTED") + try: + self._check_availability() + result: Result = await super().__call__(**arguments) # pyright: ignore[reportCallIssue] ctx.record(ResultTrace.of(result)) - send_update( # notify on finish - ToolCallUpdate( - call_id=call_context.call_id, - tool=call_context.tool, - status="FINISHED", - content=None, - ) - ) - return result + call_context.report("FINISHED") + + return MultimodalContent.of(self.format_result(result)) except Exception as exc: - send_update( # notify on fail - ToolCallUpdate( - call_id=call_context.call_id, - tool=call_context.tool, - status="FAILED", - content=None, - ) + call_context.report("FAILED") 
+ # do not blow up on tool call, return an error content instead + return MultimodalContent.of(self.format_failure(exc)) + + # regular call when using as a function + async def __call__( + self, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Result: + with ctx.nested( + self.name, + state=[ + ToolCallContext( + call_id=uuid4().hex, + tool=self.name, + send_status=lambda _: None, ) - raise ToolException( - "Tool call %s of %s failed due to an error: %s", - call_context.call_id, - call_context.tool, - exc, - ) from exc + ], + metrics=[ArgumentsTrace.of(*args, **kwargs)], + ): + result: Result = await super().__call__( + *args, + **kwargs, + ) + + ctx.record(ResultTrace.of(result)) + + return result + + +AnyTool = Tool[Any, Any] @overload @@ -182,8 +162,10 @@ def tool[**Args, Result]( *, name: str | None = None, description: str | None = None, - availability: ToolAvailability | None = None, - format_result: Callable[[Result], MultimodalContent] | None = None, + availability_check: ToolAvailabilityCheck | None = None, + format_result: Callable[[Result], MultimodalContent | MultimodalContentElement] | None = None, + format_failure: Callable[[Exception], MultimodalContent | MultimodalContentElement] + | None = None, direct_result: bool = False, ) -> Callable[[Function[Args, Coroutine[None, None, Result]]], Tool[Args, Result]]: """ @@ -202,13 +184,17 @@ def tool[**Args, Result]( description to be used in a tool specification. Allows to present the tool behavior to the external system. Default is empty. - availability: ToolAvailability + availability_check: ToolAvailabilityCheck function used to verify availability of the tool in given context. It can be used to check permissions or occurrence of a specific state to allow its usage. + Provided function should raise an Exception when the tool should not be available. Default is always available. format_result: Callable[[Result], MultimodalContent] function converting tool result to MultimodalContent. It is used to format the result for model processing. Default implementation converts the result to string if needed. + format_failure: Callable[[Exception], MultimodalContent] + function converting tool call exception to a fallback MultimodalContent. + Default implementation return "ERROR" string and logs the exception. direct_result: bool controls if tool result should break the ongoing processing and be the direct result of it. 
Note that during concurrent execution of multiple tools the call/result order defines @@ -227,8 +213,10 @@ def tool[**Args, Result]( # noqa: PLR0913 *, name: str | None = None, description: str | None = None, - availability: ToolAvailability | None = None, - format_result: Callable[[Result], MultimodalContent] | None = None, + availability_check: ToolAvailabilityCheck | None = None, + format_result: Callable[[Result], MultimodalContent | MultimodalContentElement] | None = None, + format_failure: Callable[[Exception], MultimodalContent | MultimodalContentElement] + | None = None, direct_result: bool = False, ) -> ( Callable[[Function[Args, Coroutine[None, None, Result]]], Tool[Args, Result]] @@ -241,9 +229,10 @@ def wrap( name=name or function.__name__, description=description, function=function, - availability=availability, + availability_check=availability_check, format_result=format_result or _default_result_format, - require_direct_result=direct_result, + format_failure=format_failure or _default_failure_result, + direct_result=direct_result, ) if function := function: @@ -253,14 +242,17 @@ def wrap( def _default_result_format(result: Any) -> MultimodalContent: - if isinstance(result, MultimodalContentItem): - return result + match result: + case MultimodalContent() as content: + return content - elif isinstance(result, tuple) and all( - isinstance(element, MultimodalContentItem) - for element in result # pyright: ignore[reportUnknownVariableType] - ): - return cast(MultimodalContent, result) + case element if isinstance(element, MultimodalContentElement): + return MultimodalContent.of(element) - else: - return str(result) # pyright: ignore[reportUnknownArgumentType] + case other: + return MultimodalContent.of(str(other)) + + +def _default_failure_result(exception: Exception) -> MultimodalContent: + ctx.log_error("Tool call failure", exception=exception) + return MultimodalContent.of("ERROR") diff --git a/src/draive/lmm/toolbox.py b/src/draive/lmm/toolbox.py new file mode 100644 index 0000000..9ce795b --- /dev/null +++ b/src/draive/lmm/toolbox.py @@ -0,0 +1,157 @@ +from asyncio import FIRST_COMPLETED, Task, gather, wait +from collections.abc import AsyncGenerator, Callable, Coroutine +from typing import Any, Literal, final + +from draive.helpers import freeze +from draive.lmm.errors import ToolException +from draive.lmm.state import ToolStatusStream +from draive.lmm.tool import AnyTool +from draive.parameters import ToolSpecification +from draive.scope import ctx +from draive.types import ( + LMMToolRequest, + LMMToolRequests, + LMMToolResponse, + MultimodalContent, + ToolCallStatus, +) +from draive.utils import AsyncStreamTask + +__all__ = [ + "Toolbox", +] + + +@final +class Toolbox: + def __init__( + self, + *tools: AnyTool, + suggest: AnyTool | Literal[True] | None = None, + recursive_calls_limit: int | None = None, + ) -> None: + self._tools: dict[str, AnyTool] = {tool.name: tool for tool in tools} + self.recursion_limit: int = recursive_calls_limit or 1 + self.suggest_tools: bool + self._suggested_tool: AnyTool | None + match suggest: + case None: + self.suggest_tools = False + self._suggested_tool = None + case True: + self.suggest_tools = True if self._tools else False + self._suggested_tool = None + case tool: + self.suggest_tools = True + self._suggested_tool = tool + self._tools[tool.name] = tool + + freeze(self) + + @property + def call_range(self) -> range: + return range(0, self.recursion_limit + 1) + + def tool_suggestion( + self, + recursion_level: int = 0, + ) -> 
ToolSpecification | bool: + if recursion_level != 0: + return False # suggest tools only for the first call + + elif self._suggested_tool is not None: + return self._suggested_tool.specification if self._suggested_tool.available else False + + else: + return self.suggest_tools + + def available_tools( + self, + recursion_level: int = 0, + ) -> list[ToolSpecification]: + if recursion_level <= self.recursion_limit: + return [tool.specification for tool in self._tools.values() if tool.available] + else: + ctx.log_warning("Reached tool calls recursion limit, ignoring tools...") + return [] # disable tools when reached recursion limit + + async def call_tool( + self, + name: str, + /, + call_id: str, + arguments: dict[str, Any], + ) -> MultimodalContent: + if tool := self._tools.get(name): + return await tool._toolbox_call( # pyright: ignore[reportPrivateUsage] + call_id, + arguments=arguments, + ) + + else: + raise ToolException("Requested tool is not defined", name) + + async def respond( + self, + requests: LMMToolRequests, + /, + ) -> list[LMMToolResponse]: + return await gather( + *[self._respond(request) for request in requests.requests], + return_exceptions=False, + ) + + async def _respond( + self, + request: LMMToolRequest, + /, + ) -> LMMToolResponse: + if tool := self._tools.get(request.tool): + return LMMToolResponse( + identifier=request.identifier, + tool=request.tool, + content=await tool._toolbox_call( # pyright: ignore[reportPrivateUsage] + request.identifier, + arguments=request.arguments or {}, + ), + direct=tool.requires_direct_result, + ) + + else: + # log error and provide fallback result to avoid blowing out the execution + ctx.log_error("Requested tool (%s) is not defined", request.tool) + return LMMToolResponse( + identifier=request.identifier, + tool=request.tool, + content=MultimodalContent.of("ERROR"), + direct=False, + ) + + async def stream( + self, + requests: LMMToolRequests, + /, + ) -> AsyncGenerator[LMMToolResponse | ToolCallStatus, None]: + async for element in AsyncStreamTask(job=self._stream_task(requests)): + yield element + + def _stream_task( + self, + requests: LMMToolRequests, + /, + ) -> Callable[ + [Callable[[LMMToolResponse | ToolCallStatus], None]], Coroutine[None, None, None] + ]: + async def tools_stream( + send: Callable[[LMMToolResponse | ToolCallStatus], None], + ) -> None: + with ctx.updated(ToolStatusStream(send=send)): + pending_tasks: set[Task[LMMToolResponse]] = { + ctx.spawn_subtask(self._respond, request) for request in requests.requests + } + while pending_tasks: + done, pending_tasks = await wait(pending_tasks, return_when=FIRST_COMPLETED) + for task in done: + send(task.result()) + + return tools_stream diff --git a/src/draive/metrics/function.py b/src/draive/metrics/function.py index 3429fcb..d73ea16 100644 --- a/src/draive/metrics/function.py +++ b/src/draive/metrics/function.py @@ -70,8 +70,8 @@ def __add__(self, other: Self) -> Self: exceptions = [] exception_messages = [] for exception in (*self.exception.exceptions, other.exception): # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - exception_messages.append(exception) # pyright: ignore[reportArgumentType] - exceptions.append(f"{type(exception).__qualname__}:{exception}") # pyright: ignore[reportArgumentType, reportUnknownMemberType, reportUnknownArgumentType] + exceptions.append(exception) # pyright: ignore[reportUnknownArgumentType] + exception_messages.append(f"{type(exception).__qualname__}:{exception}") # pyright: ignore[reportUnknownArgumentType] 
else: exceptions = (self.exception, other.exception) diff --git a/src/draive/metrics/trace.py b/src/draive/metrics/trace.py index 0646f27..f39f4c6 100644 --- a/src/draive/metrics/trace.py +++ b/src/draive/metrics/trace.py @@ -33,7 +33,7 @@ def __init__( type(metric): metric for metric in metrics or [] } self._nested_traces: list[MetricsTrace] = [] - self.log_info("%s started", self) + self.log_info("started...") # - STATE - @@ -196,8 +196,7 @@ def exit(self) -> None: self._end = monotonic() self.log_info( - "%s finished after %.2fs", - self, + "...finished after %.2fs", self._end - self._start, ) diff --git a/src/draive/mistral/__init__.py b/src/draive/mistral/__init__.py index 0be4fca..24a1e99 100644 --- a/src/draive/mistral/__init__.py +++ b/src/draive/mistral/__init__.py @@ -2,11 +2,11 @@ from draive.mistral.config import MistralChatConfig, MistralEmbeddingConfig from draive.mistral.embedding import mistral_embed_text from draive.mistral.errors import MistralException -from draive.mistral.lmm import mistral_lmm_completion +from draive.mistral.lmm import mistral_lmm_invocation __all__ = [ "mistral_embed_text", - "mistral_lmm_completion", + "mistral_lmm_invocation", "MistralChatConfig", "MistralClient", "MistralEmbeddingConfig", diff --git a/src/draive/mistral/chat_response.py b/src/draive/mistral/chat_response.py deleted file mode 100644 index 0954a0b..0000000 --- a/src/draive/mistral/chat_response.py +++ /dev/null @@ -1,123 +0,0 @@ -from typing import cast - -from draive.metrics import ArgumentsTrace, ResultTrace, TokenUsage -from draive.mistral.chat_tools import ( - _execute_chat_tool_calls, # pyright: ignore[reportPrivateUsage] -) -from draive.mistral.client import MistralClient -from draive.mistral.config import MistralChatConfig -from draive.mistral.errors import MistralException -from draive.mistral.models import ChatCompletionResponse, ChatMessage, ChatMessageResponse -from draive.scope import ctx -from draive.tools import Toolbox - -__all__ = [ - "_chat_response", -] - - -async def _chat_response( # noqa: C901, PLR0912 - *, - client: MistralClient, - config: MistralChatConfig, - messages: list[ChatMessage], - tools: Toolbox, - recursion_level: int = 0, -) -> str: - with ctx.nested( - "chat_response", - metrics=[ArgumentsTrace.of(messages=messages.copy())], - ): - completion: ChatCompletionResponse - - if recursion_level == config.recursion_limit: - ctx.log_warning("Reaching limit of recursive Mistral calls, ignoring tools...") - completion = await client.chat_completion( - config=config, - messages=messages, - ) - - elif recursion_level != 0: # suggest/require tool call only initially - completion = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[dict[str, object]], - tools.available_tools, - ), - ) - - elif suggested_tool := tools.suggested_tool: - completion = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[dict[str, object]], - [suggested_tool], - ), - suggest_tools=True, - ) - - else: - completion = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[dict[str, object]], - tools.available_tools, - ), - suggest_tools=tools.suggest_tools, - ) - - if usage := completion.usage: - ctx.record( - TokenUsage.for_model( - config.model, - input_tokens=usage.prompt_tokens, - output_tokens=usage.completion_tokens, - ), - ) - - if not completion.choices: - raise MistralException("Invalid Mistral completion - missing messages!", completion) - - completion_message: 
ChatMessageResponse = completion.choices[0].message - - if (tool_calls := completion_message.tool_calls) and (tools := tools): - ctx.record(ResultTrace.of(tool_calls)) - - tools_result: list[ChatMessage] | str = await _execute_chat_tool_calls( - tool_calls=tool_calls, - tools=tools, - ) - - if isinstance(tools_result, str): - return tools_result - else: - messages.extend(tools_result) - - elif message := completion_message.content: - ctx.record(ResultTrace.of(message)) - match message: - case str(content): - return content - - # API docs say that it can be only a string in response - # however library model allows list as well - case other: - return str(other) - - else: - raise MistralException("Invalid Mistral completion", completion) - - # recursion outside of context - if recursion_level >= config.recursion_limit: - raise MistralException("Reached limit of recursive calls of %d", config.recursion_limit) - - return await _chat_response( - client=client, - config=config, - messages=messages, - tools=tools, - recursion_level=recursion_level + 1, - ) diff --git a/src/draive/mistral/chat_stream.py b/src/draive/mistral/chat_stream.py deleted file mode 100644 index cc192fc..0000000 --- a/src/draive/mistral/chat_stream.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Callable - -from draive.mistral.chat_response import _chat_response # pyright: ignore[reportPrivateUsage] -from draive.mistral.client import MistralClient -from draive.mistral.config import MistralChatConfig -from draive.mistral.models import ChatMessage -from draive.scope import ctx -from draive.tools import Toolbox, ToolCallUpdate - -__all__ = [ - "_chat_stream", -] - - -async def _chat_stream( # noqa: PLR0913 - *, - client: MistralClient, - config: MistralChatConfig, - messages: list[ChatMessage], - tools: Toolbox, - send_update: Callable[[ToolCallUpdate | str], None], - recursion_level: int = 0, -) -> str: - ctx.log_warning("Mistral streaming api is not supported yet, using regular response...") - message: str = await _chat_response( - client=client, - config=config, - messages=messages, - tools=tools, - recursion_level=recursion_level, - ) - send_update(message) - return message diff --git a/src/draive/mistral/chat_tools.py b/src/draive/mistral/chat_tools.py deleted file mode 100644 index 5dcd044..0000000 --- a/src/draive/mistral/chat_tools.py +++ /dev/null @@ -1,132 +0,0 @@ -import json -from asyncio import gather -from collections.abc import Awaitable -from typing import Any, Literal, overload - -from draive.mistral.models import ChatMessage, ToolCallResponse -from draive.tools import Toolbox - -__all__ = [ - "_execute_chat_tool_calls", -] - - -async def _execute_chat_tool_calls( - *, - tool_calls: list[ToolCallResponse], - tools: Toolbox, -) -> list[ChatMessage] | str: - direct_result: Awaitable[str] | None = None - tool_call_results: list[Awaitable[ChatMessage]] = [] - for call in tool_calls: - # use only the first "direct result tool" requested, can't return more than one anyways - # despite of that all tools will be called to ensure that all desired actions were executed - if direct_result is None and tools.requires_direct_result(tool_name=call.function.name): - direct_result = _execute_chat_tool_call( - call_id=call.id, - name=call.function.name, - arguments=call.function.arguments, - tools=tools, - message_result=False, - ) - else: - tool_call_results.append( - _execute_chat_tool_call( - call_id=call.id, - name=call.function.name, - arguments=call.function.arguments, - tools=tools, - message_result=True, 
- ) - ) - if direct_result is not None: - results: tuple[str, ...] = await gather( - direct_result, - *tool_call_results, - return_exceptions=False, - ) - return results[0] # return only the requested direct result - else: - return [ - ChatMessage( - role="assistant", - content="", - tool_calls=[ - { - "id": call.id, - "type": "function", - "function": { - "name": call.function.name, - "arguments": call.function.arguments - if isinstance(call.function.arguments, str) - else json.dumps(call.function.arguments), - }, - } - for call in tool_calls - ], - ), - *await gather( - *tool_call_results, - return_exceptions=False, - ), - ] - - -@overload -async def _execute_chat_tool_call( - *, - call_id: str, - name: str, - arguments: dict[str, Any] | str, - tools: Toolbox, - message_result: Literal[True], -) -> ChatMessage: ... - - -@overload -async def _execute_chat_tool_call( - *, - call_id: str, - name: str, - arguments: dict[str, Any] | str, - tools: Toolbox, - message_result: Literal[False], -) -> str: ... - - -async def _execute_chat_tool_call( - *, - call_id: str, - name: str, - arguments: dict[str, Any] | str, - tools: Toolbox, - message_result: bool, -) -> ChatMessage | str: - try: # make sure that tool error won't blow up whole chain - result: str = str( - await tools.call_tool( - name, - call_id=call_id, - arguments=arguments, - ) - ) - if message_result: - return ChatMessage( - role="tool", - name=name, - content=str(result), - ) - else: - return result - - # error should be already logged by ScopeContext - except BaseException as exc: - if message_result: - return ChatMessage( - role="tool", - name=name, - content="Error", - ) - - else: # TODO: think about allowing the error chat message - raise exc diff --git a/src/draive/mistral/client.py b/src/draive/mistral/client.py index c971a4c..4986771 100644 --- a/src/draive/mistral/client.py +++ b/src/draive/mistral/client.py @@ -79,6 +79,7 @@ async def chat_completion( # noqa: PLR0913 ) -> AsyncIterable[ChatCompletionStreamResponse] | ChatCompletionResponse: if stream: raise NotImplementedError("Mistral streaming is not supported yet") + else: return await self._create_chat_completion( messages=messages, diff --git a/src/draive/mistral/config.py b/src/draive/mistral/config.py index f627935..14d936f 100644 --- a/src/draive/mistral/config.py +++ b/src/draive/mistral/config.py @@ -35,7 +35,6 @@ class MistralChatConfig(Model): max_tokens: int | None = None response_format: ResponseFormat | None = None timeout: float | None = None - recursion_limit: int = 4 class MistralEmbeddingConfig(Model): diff --git a/src/draive/mistral/lmm.py b/src/draive/mistral/lmm.py index f3b8c7a..a99c3de 100644 --- a/src/draive/mistral/lmm.py +++ b/src/draive/mistral/lmm.py @@ -1,221 +1,274 @@ -from collections.abc import Callable -from typing import Any, Literal, overload +import json +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Literal, cast, overload -from draive.lmm import LMMCompletionStream, LMMMessage, LMMStreamingUpdate -from draive.mistral.chat_response import _chat_response # pyright: ignore[reportPrivateUsage] -from draive.mistral.chat_stream import _chat_stream # pyright: ignore[reportPrivateUsage] +from draive.metrics import ArgumentsTrace, ResultTrace, TokenUsage from draive.mistral.client import MistralClient from draive.mistral.config import MistralChatConfig -from draive.mistral.models import ChatMessage +from draive.mistral.errors import MistralException +from draive.mistral.models import ChatCompletionResponse, 
ChatMessage, ChatMessageResponse +from draive.parameters import ToolSpecification from draive.scope import ctx -from draive.tools import Toolbox, ToolCallUpdate, ToolsUpdatesContext -from draive.types import ImageBase64Content, ImageURLContent, Model -from draive.utils import AsyncStreamTask +from draive.types import ( + LMMCompletion, + LMMCompletionChunk, + LMMContextElement, + LMMInput, + LMMInstruction, + LMMOutput, + LMMOutputStream, + LMMOutputStreamChunk, + LMMToolRequest, + LMMToolRequests, + LMMToolResponse, +) __all__ = [ - "mistral_lmm_completion", + "mistral_lmm_invocation", ] @overload -async def mistral_lmm_completion( +async def mistral_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", stream: Literal[True], **extra: Any, -) -> LMMCompletionStream: ... +) -> LMMOutputStream: ... @overload -async def mistral_lmm_completion( +async def mistral_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, - stream: Callable[[LMMStreamingUpdate], None], + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", + stream: Literal[False] = False, **extra: Any, -) -> LMMMessage: ... +) -> LMMOutput: ... @overload -async def mistral_lmm_completion( +async def mistral_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, output: Literal["text", "json"] = "text", - stream: Literal[False] = False, + stream: bool = False, **extra: Any, -) -> LMMMessage: ... +) -> LMMOutputStream | LMMOutput: ... -async def mistral_lmm_completion( +async def mistral_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, output: Literal["text", "json"] = "text", - stream: Callable[[LMMStreamingUpdate], None] | bool = False, + stream: bool = False, **extra: Any, -) -> LMMCompletionStream | LMMMessage: - client: MistralClient = ctx.dependency(MistralClient) - config: MistralChatConfig = ctx.state(MistralChatConfig).updated(**extra) - match output: - case "text": - config = config.updated(response_format={"type": "text"}) - case "json": - if tools is None: - config = config.updated(response_format={"type": "json_object"}) - else: - ctx.log_warning( - "Attempting to use Mistral in JSON mode with tools which is not supported." - " Using text mode instead..." - ) +) -> LMMOutputStream | LMMOutput: + with ctx.nested( + "mistral_lmm_completion", + metrics=[ + ArgumentsTrace.of( + context=context, + tools=tools, + require_tool=require_tool, + output=output, + stream=stream, + **extra, + ), + ], + ): + client: MistralClient = ctx.dependency(MistralClient) + config: MistralChatConfig = ctx.state(MistralChatConfig).updated(**extra) + match output: + case "text": config = config.updated(response_format={"type": "text"}) + case "json": + if tools is None: + config = config.updated(response_format={"type": "json_object"}) + else: + ctx.log_warning( + "Attempting to use Mistral in JSON mode with tools which is not supported." 
+ " Using text mode instead..." + ) + config = config.updated(response_format={"type": "text"}) - messages: list[ChatMessage] = [_convert_message(message=message) for message in context] + messages: list[ChatMessage] = [ + _convert_context_element(element=element) for element in context + ] - match stream: - case False: - with ctx.nested("mistral_lmm_completion", metrics=[config]): - message: str = await _chat_response( + if stream: + return ctx.stream( + generator=_chat_completion_stream( client=client, config=config, messages=messages, - tools=tools or Toolbox(), - ) - if tools and output == "json": - # workaround for json mode with tools - # most common mistral mistake is to not escape newlines - # we can't fix it easily - code below removes all newlines - # which may cause missing newlines in the json content - message = message.replace("\n", "") - return LMMMessage( - role="assistant", - content=message, - ) + tools=tools, + require_tool=require_tool, + ), + ) - case True: - - async def stream_task( - streaming_update: Callable[[LMMStreamingUpdate], None], - ) -> None: - with ctx.nested( - "mistral_lmm_completion", - state=[ToolsUpdatesContext(send_update=streaming_update)], - metrics=[config], - ): - - def send_update(update: ToolCallUpdate | str) -> None: - if isinstance(update, str): - streaming_update( - LMMMessage( - role="assistant", - content=update, - ) - ) - else: - streaming_update(update) - - await _chat_stream( - client=client, - config=config, - messages=messages, - tools=tools or Toolbox(), - send_update=send_update, - ) + else: + return await _chat_completion( + client=client, + config=config, + messages=messages, + tools=tools, + require_tool=require_tool, + ) - return AsyncStreamTask(job=stream_task) - case streaming_update: +def _convert_context_element( + element: LMMContextElement, +) -> ChatMessage: + match element: + case LMMInstruction() as instruction: + return ChatMessage( + role="system", + content=instruction.content, + ) - def send_update(update: ToolCallUpdate | str) -> None: - if isinstance(update, str): - streaming_update( - LMMMessage( - role="assistant", - content=update, - ) - ) - else: - streaming_update(update) - - with ctx.nested( - "mistral_lmm_completion", - state=[ToolsUpdatesContext(send_update=streaming_update)], - metrics=[config], - ): - return LMMMessage( - role="assistant", - content=await _chat_stream( - client=client, - config=config, - messages=messages, - tools=tools or Toolbox(), - send_update=send_update, - ), - ) + case LMMInput() as input: + return ChatMessage( + role="user", + content=input.content.as_string(), + ) + case LMMCompletion() as completion: + return ChatMessage( + role="assistant", + content=completion.content.as_string(), + ) -def _convert_message( # noqa: PLR0912, C901, PLR0911 - message: LMMMessage, -) -> ChatMessage: - match message.role: - case "user": - if isinstance(message.content, str): - return ChatMessage( - role="user", - content=message.content, - ) - elif isinstance(message.content, ImageURLContent): - raise ValueError("Unsupported message content", message) - elif isinstance(message.content, ImageBase64Content): - raise ValueError("Unsupported message content", message) - elif isinstance(message.content, Model): - return ChatMessage( - role="user", - content=str(message.content), - ) - else: - content_parts: list[str] = [] - for part in message.content: - if isinstance(part, str): - content_parts.append(part) - elif isinstance(part, ImageURLContent): - raise ValueError("Unsupported message content", 
message) - elif isinstance(message.content, ImageBase64Content): - raise ValueError("Unsupported message content", message) - elif isinstance(message.content, Model): - content_parts.append(str(message.content)) - else: - raise ValueError("Unsupported message content", message) - return ChatMessage( - role="user", - content=content_parts, - ) + case LMMToolRequests() as tool_requests: + return ChatMessage( + role="assistant", + content="", + tool_calls=[ + { + "id": request.identifier, + "type": "function", + "function": { + "name": request.tool, + "arguments": json.dumps(request.arguments), + }, + } + for request in tool_requests.requests + ], + ) - case "assistant": - if isinstance(message.content, str): - return ChatMessage( - role="assistant", - content=message.content, - ) - elif isinstance(message.content, Model): - return ChatMessage( - role="assistant", - content=str(message.content), - ) - else: - raise ValueError("Invalid assistant message", message) - - case "system": - if isinstance(message.content, str): - return ChatMessage( - role="system", - content=message.content, - ) - elif isinstance(message.content, Model): - return ChatMessage( - role="system", - content=str(message.content), + case LMMToolResponse() as tool_response: + return ChatMessage( + role="tool", + name=tool_response.tool, + content=tool_response.content.as_string(), + ) + + +async def _chat_completion( + *, + client: MistralClient, + config: MistralChatConfig, + messages: list[ChatMessage], + tools: Sequence[ToolSpecification] | None, + require_tool: ToolSpecification | bool, +) -> LMMOutput: + completion: ChatCompletionResponse + match require_tool: + case bool(required): + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[dict[str, object]], + tools, + ), + suggest_tools=required, + ) + + case ToolSpecification() as tool: + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[dict[str, object]], + [tool], # mistral can't be suggested with concrete tool + ), + suggest_tools=True, + ) + + if usage := completion.usage: + ctx.record( + TokenUsage.for_model( + config.model, + input_tokens=usage.prompt_tokens, + output_tokens=usage.completion_tokens, + ), + ) + + if not completion.choices: + raise MistralException("Invalid Mistral completion - missing messages!", completion) + + completion_message: ChatMessageResponse = completion.choices[0].message + + if (tool_calls := completion_message.tool_calls) and (tools := tools): + ctx.record(ResultTrace.of(tool_calls)) + + return LMMToolRequests( + requests=[ + LMMToolRequest( + identifier=call.id, + tool=call.function.name, + arguments=json.loads(call.function.arguments) + if isinstance(call.function.arguments, str) + else call.function.arguments, ) - else: - raise ValueError("Invalid system message", message) + for call in tool_calls + ] + ) + + elif message := completion_message.content: + ctx.record(ResultTrace.of(message)) + match message: + case str(content): + return LMMCompletion.of(content) + + # API docs say that it can be only a string in response + # however library model allows list as well + case other: + return LMMCompletion.of(*other) + + else: + raise MistralException("Invalid Mistral completion", completion) + + +async def _chat_completion_stream( + *, + client: MistralClient, + config: MistralChatConfig, + messages: list[ChatMessage], + tools: Sequence[ToolSpecification] | None, + require_tool: ToolSpecification | bool, +) -> 
AsyncGenerator[LMMOutputStreamChunk, None]: + ctx.log_warning("Mistral streaming api is not supported yet, using regular response...") + output: LMMOutput = await _chat_completion( + client=client, + config=config, + messages=messages, + tools=tools, + require_tool=require_tool, + ) + match output: + case LMMCompletion() as completion: + yield LMMCompletionChunk.of(completion.content) + + case other: + yield other diff --git a/src/draive/mistral/models.py b/src/draive/mistral/models.py index ab839e0..8ce0379 100644 --- a/src/draive/mistral/models.py +++ b/src/draive/mistral/models.py @@ -6,11 +6,11 @@ "UsageInfo", "EmbeddingObject", "EmbeddingResponse", - "FunctionCall", - "ToolCall", + "ChatFunctionCall", + "ChatToolCallRequest", "ChatMessage", - "FunctionCallResponse", - "ToolCallResponse", + "ChatFunctionCallResponse", + "ChatToolCallResponse", "ChatMessageResponse", "ChatCompletionResponseChoice", "ChatCompletionResponse", @@ -40,39 +40,39 @@ class EmbeddingResponse(Model): usage: UsageInfo -class FunctionCall(TypedDict, total=False): +class ChatFunctionCall(TypedDict, total=False): name: Required[str] arguments: Required[str] -class ToolCall(TypedDict, total=False): +class ChatToolCallRequest(TypedDict, total=False): id: Required[str] type: Required[Literal["function"]] - function: Required[FunctionCall] + function: Required[ChatFunctionCall] class ChatMessage(TypedDict, total=False): role: Required[str] content: Required[str | list[str]] name: NotRequired[str] - tool_calls: NotRequired[list[ToolCall]] + tool_calls: NotRequired[list[ChatToolCallRequest]] -class FunctionCallResponse(Model): +class ChatFunctionCallResponse(Model): name: str arguments: dict[str, Any] | str -class ToolCallResponse(Model): +class ChatToolCallResponse(Model): id: str type: Literal["function"] - function: FunctionCallResponse + function: ChatFunctionCallResponse class ChatDeltaMessageResponse(Model): role: str | None = None content: str | None = None - tool_calls: list[ToolCallResponse] | None = None + tool_calls: list[ChatToolCallResponse] | None = None class ChatCompletionResponseStreamChoice(Model): @@ -92,7 +92,7 @@ class ChatCompletionStreamResponse(Model): class ChatMessageResponse(Model): role: str content: list[str] | str | None = None - tool_calls: list[ToolCallResponse] | None = None + tool_calls: list[ChatToolCallResponse] | None = None class ChatCompletionResponseChoice(Model): diff --git a/src/draive/openai/__init__.py b/src/draive/openai/__init__.py index 62fa433..a8ce83c 100644 --- a/src/draive/openai/__init__.py +++ b/src/draive/openai/__init__.py @@ -7,13 +7,13 @@ from draive.openai.embedding import openai_embed_text from draive.openai.errors import OpenAIException from draive.openai.images import openai_generate_image -from draive.openai.lmm import openai_lmm_completion +from draive.openai.lmm import openai_lmm_invocation from draive.openai.tokenization import openai_tokenize_text __all__ = [ "openai_embed_text", "openai_generate_image", - "openai_lmm_completion", + "openai_lmm_invocation", "openai_tokenize_text", "OpenAIChatConfig", "OpenAIClient", diff --git a/src/draive/openai/chat_response.py b/src/draive/openai/chat_response.py deleted file mode 100644 index 2ffdd8c..0000000 --- a/src/draive/openai/chat_response.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import cast - -from openai.types.chat import ( - ChatCompletion, - ChatCompletionMessage, - ChatCompletionMessageParam, - ChatCompletionToolParam, -) - -from draive.metrics import ArgumentsTrace, ResultTrace, TokenUsage -from 
draive.openai.chat_tools import ( - _execute_chat_tool_calls, # pyright: ignore[reportPrivateUsage] -) -from draive.openai.client import OpenAIClient -from draive.openai.config import OpenAIChatConfig -from draive.openai.errors import OpenAIException -from draive.scope import ctx -from draive.tools import Toolbox - -__all__ = [ - "_chat_response", -] - - -async def _chat_response( - *, - client: OpenAIClient, - config: OpenAIChatConfig, - messages: list[ChatCompletionMessageParam], - tools: Toolbox, - recursion_level: int = 0, -) -> str: - with ctx.nested( - "chat_response", - metrics=[ArgumentsTrace.of(messages=messages.copy())], - ): - completion: ChatCompletion - if recursion_level == config.recursion_limit: - ctx.log_warning("Reaching limit of recursive OpenAI calls, ignoring tools...") - completion = await client.chat_completion( - config=config, - messages=messages, - ) - - elif recursion_level != 0: # suggest/require tool call only initially - completion = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[ChatCompletionToolParam], - tools.available_tools, - ), - ) - - elif suggested_tool_name := tools.suggested_tool_name: - completion = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[ChatCompletionToolParam], - tools.available_tools, - ), - tools_suggestion={ - "type": "function", - "function": { - "name": suggested_tool_name, - }, - }, - ) - - else: - completion = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[ChatCompletionToolParam], - tools.available_tools, - ), - tools_suggestion=tools.suggest_tools, - ) - - if usage := completion.usage: - ctx.record( - TokenUsage.for_model( - config.model, - input_tokens=usage.prompt_tokens, - output_tokens=usage.completion_tokens, - ), - ) - - if not completion.choices: - raise OpenAIException("Invalid OpenAI completion - missing messages!", completion) - - completion_message: ChatCompletionMessage = completion.choices[0].message - - if (tool_calls := completion_message.tool_calls) and (tools := tools): - ctx.record(ResultTrace.of(tool_calls)) - - tools_result: list[ChatCompletionMessageParam] | str = await _execute_chat_tool_calls( - tool_calls=tool_calls, - tools=tools, - ) - - if isinstance(tools_result, str): - return tools_result - else: - messages.extend(tools_result) - - elif message := completion_message.content: - ctx.record(ResultTrace.of(message)) - return message - - else: - raise OpenAIException("Invalid OpenAI completion", completion) - - # recursion outside of context - if recursion_level >= config.recursion_limit: - raise OpenAIException("Reached limit of recursive calls of %d", config.recursion_limit) - - return await _chat_response( - client=client, - config=config, - messages=messages, - tools=tools, - recursion_level=recursion_level + 1, - ) diff --git a/src/draive/openai/chat_stream.py b/src/draive/openai/chat_stream.py deleted file mode 100644 index 1c67714..0000000 --- a/src/draive/openai/chat_stream.py +++ /dev/null @@ -1,161 +0,0 @@ -from collections.abc import Callable -from typing import cast - -from openai import AsyncStream as OpenAIAsyncStream -from openai.types.chat import ( - ChatCompletionChunk, - ChatCompletionMessageParam, - ChatCompletionMessageToolCall, - ChatCompletionNamedToolChoiceParam, - ChatCompletionToolParam, -) -from openai.types.chat.chat_completion_chunk import ChoiceDelta - -from draive.metrics import ArgumentsTrace, ResultTrace, TokenUsage -from draive.openai.chat_tools 
import ( - _execute_chat_tool_calls, # pyright: ignore[reportPrivateUsage] - _flush_chat_tool_calls, # pyright: ignore[reportPrivateUsage] -) -from draive.openai.client import OpenAIClient -from draive.openai.config import OpenAIChatConfig -from draive.openai.errors import OpenAIException -from draive.scope import ctx -from draive.tools import Toolbox, ToolCallUpdate - -__all__ = [ - "_chat_stream", -] - - -async def _chat_stream( # noqa: PLR0913, C901, PLR0915, PLR0912 - *, - client: OpenAIClient, - config: OpenAIChatConfig, - messages: list[ChatCompletionMessageParam], - tools: Toolbox, - send_update: Callable[[ToolCallUpdate | str], None], - recursion_level: int = 0, -) -> str: - with ctx.nested( - "chat_stream", - metrics=[ArgumentsTrace.of(messages=messages.copy())], - ): - completion_stream: OpenAIAsyncStream[ChatCompletionChunk] - if recursion_level == config.recursion_limit: - ctx.log_warning("Reaching limit of recursive OpenAI calls, ignoring tools...") - completion_stream = await client.chat_completion( - config=config, - messages=messages, - tools=None, - stream=True, - ) - - else: - tools_suggestion: ChatCompletionNamedToolChoiceParam | bool - if recursion_level != 0: # suggest/require tool call only initially - tools_suggestion = False - - elif suggested_tool_name := tools.suggested_tool_name: - tools_suggestion = { - "type": "function", - "function": { - "name": suggested_tool_name, - }, - } - - else: - tools_suggestion = tools.suggest_tools - - completion_stream = await client.chat_completion( - config=config, - messages=messages, - tools=cast( - list[ChatCompletionToolParam], - tools.available_tools if tools else [], - ), - tools_suggestion=tools_suggestion, - stream=True, - ) - - while True: # load chunks to decide what to do next - head: ChatCompletionChunk - try: - head = await anext(completion_stream) - - except StopAsyncIteration as exc: - # could not decide what to do before stream end - raise OpenAIException("Invalid OpenAI completion stream") from exc - - if not head.choices: - raise OpenAIException("Invalid OpenAI completion - missing deltas!", head) - - completion_head: ChoiceDelta = head.choices[0].delta - - if completion_head.tool_calls is not None and (tools := tools): - tool_calls: list[ChatCompletionMessageToolCall] = await _flush_chat_tool_calls( - model=config.model, # model for token usage tracking - tool_calls=completion_head.tool_calls, - completion_stream=completion_stream, - ) - ctx.record(ResultTrace.of(tool_calls)) - - tools_result: ( - list[ChatCompletionMessageParam] | str - ) = await _execute_chat_tool_calls( - tool_calls=tool_calls, - tools=tools, - ) - - if isinstance(tools_result, str): - send_update(tools_result) - return tools_result - else: - messages.extend(tools_result) - - break # after processing tool calls continue with recursion in outer context - - elif completion_head.content is not None: - result: str = completion_head.content - if result: # provide head / first part if not empty - send_update(result) - - async for part in completion_stream: - if part.choices: # usage part does not contain choices - # we are always requesting single result - no need to take care of indices - part_text: str = part.choices[0].delta.content or "" - if not part_text: - continue # skip empty parts - result += part_text - send_update(result) - - elif usage := part.usage: # record usage if able (expected in last part) - ctx.record( - TokenUsage.for_model( - config.model, - input_tokens=usage.prompt_tokens, - output_tokens=usage.completion_tokens, - ), - ) 
- - else: - ctx.log_warning("Unexpected OpenAI streaming part: %s", part) - continue - - ctx.record(ResultTrace.of(result)) - return result # we hav final result here - - else: - continue # iterate over the stream until can decide what to do or reach the end - - # recursion outside of context - if recursion_level >= config.recursion_limit: - raise OpenAIException("Reached limit of recursive calls of %d", config.recursion_limit) - - return await _chat_stream( - client=client, - config=config, - messages=messages, - tools=tools, - send_update=send_update, - recursion_level=recursion_level + 1, - ) diff --git a/src/draive/openai/chat_tools.py b/src/draive/openai/chat_tools.py deleted file mode 100644 index 827a564..0000000 --- a/src/draive/openai/chat_tools.py +++ /dev/null @@ -1,209 +0,0 @@ -from asyncio import gather -from collections.abc import Awaitable -from typing import Literal, cast, overload - -from openai import AsyncStream -from openai.types.chat import ( - ChatCompletionChunk, - ChatCompletionMessageParam, - ChatCompletionMessageToolCall, - ChatCompletionMessageToolCallParam, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCall, -) - -from draive.metrics import TokenUsage -from draive.scope import ctx -from draive.tools import Toolbox - -__all__ = [ - "_execute_chat_tool_calls", - "_flush_chat_tool_calls", -] - - -async def _execute_chat_tool_calls( - *, - tool_calls: list[ChatCompletionMessageToolCall], - tools: Toolbox, -) -> list[ChatCompletionMessageParam] | str: - direct_result: Awaitable[str] | None = None - tool_call_params: list[ChatCompletionMessageToolCallParam] = [] - tool_call_results: list[Awaitable[ChatCompletionMessageParam]] = [] - for call in tool_calls: - # use only the first "direct result tool" requested, can't return more than one anyways - # despite of that all tools will be called to ensure that all desired actions were executed - if direct_result is None and tools.requires_direct_result(tool_name=call.function.name): - direct_result = _execute_chat_tool_call( - call_id=call.id, - name=call.function.name, - arguments=call.function.arguments, - tools=tools, - message_result=False, - ) - else: - tool_call_results.append( - _execute_chat_tool_call( - call_id=call.id, - name=call.function.name, - arguments=call.function.arguments, - tools=tools, - message_result=True, - ), - ) - tool_call_params.append( - { - "id": call.id, - "type": "function", - "function": { - "name": call.function.name, - "arguments": call.function.arguments, - }, - }, - ) - - if direct_result is not None: - results: tuple[str, ...] = await gather( - direct_result, - *tool_call_results, - return_exceptions=False, - ) - return results[0] # return only the requested direct result - else: - return [ - { - "role": "assistant", - "tool_calls": tool_call_params, - }, - *await gather( - *tool_call_results, - return_exceptions=False, - ), - ] - - -@overload -async def _execute_chat_tool_call( - *, - call_id: str, - name: str, - arguments: str, - tools: Toolbox, - message_result: Literal[True], -) -> ChatCompletionMessageParam: ... - - -@overload -async def _execute_chat_tool_call( - *, - call_id: str, - name: str, - arguments: str, - tools: Toolbox, - message_result: Literal[False], -) -> str: ... 
- - -async def _execute_chat_tool_call( - *, - call_id: str, - name: str, - arguments: str, - tools: Toolbox, - message_result: bool, -) -> ChatCompletionMessageParam | str: - try: # make sure that tool error won't blow up whole chain - result: str = str( - await tools.call_tool( - name, - call_id=call_id, - arguments=arguments, - ) - ) - if message_result: - return { - "role": "tool", - "tool_call_id": call_id, - "content": str(result), - } - else: - return result - - # error should be already logged by ScopeContext - except Exception as exc: - if message_result: - return { - "role": "tool", - "tool_call_id": call_id, - "content": "Error", - } - - else: # TODO: think about allowing the error chat message - raise exc - - -async def _flush_chat_tool_calls( # noqa: C901, PLR0912 - *, - model: str, - tool_calls: list[ChoiceDeltaToolCall], - completion_stream: AsyncStream[ChatCompletionChunk], -) -> list[ChatCompletionMessageToolCall]: - # iterate over the stream to get full list of tool calls - async for part in completion_stream: - if part.choices: # usage part does not contain choices - for call in part.choices[0].delta.tool_calls or []: - try: - tool_call: ChoiceDeltaToolCall = next( - tool_call for tool_call in tool_calls if tool_call.index == call.index - ) - - if call.id: - if tool_call.id is not None: - tool_call.id += call.id - else: - tool_call.id = call.id - else: - pass - - if call.function is None: - continue - - if tool_call.function is None: - tool_call.function = call.function - continue - - if call.function.name: - if tool_call.function.name is not None: - tool_call.function.name += call.function.name - else: - tool_call.function.name = call.function.name - else: - pass - - if call.function.arguments: - if tool_call.function.arguments is not None: - tool_call.function.arguments += call.function.arguments - else: - tool_call.function.arguments = call.function.arguments - else: - pass - - except (StopIteration, StopAsyncIteration): - tool_calls.append(call) - - elif usage := part.usage: # record usage if able (expected in last part) - ctx.record( - TokenUsage.for_model( - model, - input_tokens=usage.prompt_tokens, - output_tokens=usage.completion_tokens, - ), - ) - - else: - ctx.log_warning("Unexpected OpenAI streaming part: %s", part) - continue - - # completed calls have exactly the same model - return cast(list[ChatCompletionMessageToolCall], tool_calls) diff --git a/src/draive/openai/config.py b/src/draive/openai/config.py index b0bf4d1..0283b42 100644 --- a/src/draive/openai/config.py +++ b/src/draive/openai/config.py @@ -43,7 +43,6 @@ class OpenAIChatConfig(Model): response_format: ResponseFormat | None = None vision_details: Literal["auto", "low", "high"] | None = None timeout: float | None = None - recursion_limit: int = 4 class OpenAIEmbeddingConfig(Model): diff --git a/src/draive/openai/lmm.py b/src/draive/openai/lmm.py index c7bb7ab..5015f4b 100644 --- a/src/draive/openai/lmm.py +++ b/src/draive/openai/lmm.py @@ -1,239 +1,460 @@ -from collections.abc import Callable -from typing import Any, Literal, overload +import json +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Literal, cast, overload +from uuid import uuid4 -from openai.types.chat import ChatCompletionContentPartParam, ChatCompletionMessageParam +from openai import AsyncStream as OpenAIAsyncStream +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionContentPartParam, + ChatCompletionMessage, + ChatCompletionMessageParam, + 
ChatCompletionToolParam, +) +from openai.types.chat.chat_completion_chunk import Choice, ChoiceDeltaToolCall -from draive.lmm import LMMCompletionStream, LMMMessage, LMMStreamingUpdate -from draive.openai.chat_response import _chat_response # pyright: ignore[reportPrivateUsage] -from draive.openai.chat_stream import _chat_stream # pyright: ignore[reportPrivateUsage] +from draive.metrics import ArgumentsTrace, ResultTrace, TokenUsage from draive.openai.client import OpenAIClient from draive.openai.config import OpenAIChatConfig +from draive.openai.errors import OpenAIException +from draive.parameters import ToolSpecification from draive.scope import ctx -from draive.tools import Toolbox, ToolCallUpdate, ToolsUpdatesContext -from draive.types import ImageBase64Content, ImageURLContent, Model -from draive.utils import AsyncStreamTask +from draive.types import ( + ImageBase64Content, + ImageURLContent, + LMMCompletion, + LMMCompletionChunk, + LMMContextElement, + LMMInput, + LMMInstruction, + LMMOutput, + LMMOutputStream, + LMMOutputStreamChunk, + LMMToolRequest, + LMMToolRequests, + LMMToolResponse, +) +from draive.types.audio import AudioBase64Content, AudioDataContent, AudioURLContent +from draive.types.images import ImageDataContent +from draive.types.multimodal import MultimodalContentElement +from draive.types.video import VideoBase64Content, VideoDataContent, VideoURLContent __all__ = [ - "openai_lmm_completion", + "openai_lmm_invocation", ] @overload -async def openai_lmm_completion( +async def openai_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", stream: Literal[True], **extra: Any, -) -> LMMCompletionStream: ... +) -> LMMOutputStream: ... @overload -async def openai_lmm_completion( +async def openai_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, - stream: Callable[[LMMStreamingUpdate], None], + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, + output: Literal["text", "json"] = "text", + stream: Literal[False] = False, **extra: Any, -) -> LMMMessage: ... +) -> LMMOutput: ... @overload -async def openai_lmm_completion( +async def openai_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, output: Literal["text", "json"] = "text", - stream: Literal[False] = False, + stream: bool = False, **extra: Any, -) -> LMMMessage: ... +) -> LMMOutputStream | LMMOutput: ... 
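A minimal caller sketch for the invocation-style interface declared by the overloads above, assuming it runs inside a draive scope that already provides the OpenAIClient dependency and an OpenAIChatConfig state; the instruction and question strings are placeholder values and tool dispatching is intentionally left out.

from draive.openai.lmm import openai_lmm_invocation
from draive.types import LMMCompletion, LMMInput, LMMInstruction, LMMToolRequests


async def ask_once(question: str) -> str:
    # non-streaming call: returns either a completion or a batch of tool requests
    result = await openai_lmm_invocation(
        context=[
            LMMInstruction.of("You are a helpful assistant."),
            LMMInput.of(question),
        ],
        tools=None,  # optionally a sequence of ToolSpecification values
        require_tool=False,
        output="text",
        stream=False,
    )
    match result:
        case LMMCompletion() as completion:
            return completion.content.as_string()

        case LMMToolRequests() as requests:
            # handling tool requests is out of scope for this sketch
            return f"requested tools: {[request.tool for request in requests.requests]}"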
-async def openai_lmm_completion( +async def openai_lmm_invocation( *, - context: list[LMMMessage], - tools: Toolbox | None = None, + context: Sequence[LMMContextElement], + tools: Sequence[ToolSpecification] | None = None, + require_tool: ToolSpecification | bool = False, output: Literal["text", "json"] = "text", - stream: Callable[[LMMStreamingUpdate], None] | bool = False, + stream: bool = False, **extra: Any, -) -> LMMCompletionStream | LMMMessage: - client: OpenAIClient = ctx.dependency(OpenAIClient) - config: OpenAIChatConfig = ctx.state(OpenAIChatConfig).updated(**extra) - match output: - case "text": - config = config.updated(response_format={"type": "text"}) - case "json": - config = config.updated(response_format={"type": "json_object"}) - messages: list[ChatCompletionMessageParam] = [ - _convert_message(config=config, message=message) for message in context - ] - - match stream: - case False: - with ctx.nested("openai_lmm_completion", metrics=[config]): - return LMMMessage( - role="assistant", - content=await _chat_response( - client=client, - config=config, - messages=messages, - tools=tools or Toolbox(), - ), - ) +) -> LMMOutputStream | LMMOutput: + with ctx.nested( + "openai_lmm_completion", + metrics=[ + ArgumentsTrace.of( + context=context, + tools=tools, + require_tool=require_tool, + output=output, + stream=stream, + **extra, + ), + ], + ): + client: OpenAIClient = ctx.dependency(OpenAIClient) + config: OpenAIChatConfig = ctx.state(OpenAIChatConfig).updated(**extra) + match output: + case "text": + config = config.updated(response_format={"type": "text"}) + case "json": + config = config.updated(response_format={"type": "json_object"}) - case True: - - async def stream_task( - streaming_update: Callable[[LMMStreamingUpdate], None], - ) -> None: - with ctx.nested( - "openai_lmm_completion", - state=[ToolsUpdatesContext(send_update=streaming_update)], - metrics=[config], - ): - - def send_update(update: ToolCallUpdate | str) -> None: - if isinstance(update, str): - streaming_update( - LMMMessage( - role="assistant", - content=update, - ) - ) - else: - streaming_update(update) + messages: list[ChatCompletionMessageParam] = [ + _convert_context_element(config=config, element=message) for message in context + ] + + if stream: + return ctx.stream( + generator=_chat_completion_stream( + client=client, + config=config, + messages=messages, + tools=tools, + require_tool=require_tool, + ), + ) + else: + return await _chat_completion( + client=client, + config=config, + messages=messages, + tools=tools, + require_tool=require_tool, + ) + + +def _convert_content_element( + element: MultimodalContentElement, + config: OpenAIChatConfig, +) -> ChatCompletionContentPartParam: + match element: + case str() as string: + return { + "type": "text", + "text": string, + } + + case ImageURLContent() as image: + return { + "type": "image_url", + "image_url": { + "url": image.image_url, + "detail": config.vision_details or "auto", + }, + } + + case ImageBase64Content(): + # TODO: we could upload media using openAI endpoint to have url instead + raise ValueError("Unsupported message content", element) + + case ImageDataContent(): + # TODO: we could upload media using openAI endpoint to have url instead + raise ValueError("Unsupported message content", element) + + case AudioURLContent(): + # TODO: OpenAI models with audio? 
+ raise ValueError("Unsupported message content", element) + + case AudioBase64Content(): + # TODO: we could upload media using openAI endpoint to have url instead + raise ValueError("Unsupported message content", element) - await _chat_stream( - client=client, + case AudioDataContent(): + # TODO: we could upload media using openAI endpoint to have url instead + raise ValueError("Unsupported message content", element) + + case VideoURLContent(): + # TODO: OpenAI models with video? + raise ValueError("Unsupported message content", element) + + case VideoBase64Content(): + # TODO: we could upload media using openAI endpoint to have url instead + raise ValueError("Unsupported message content", element) + + case VideoDataContent(): + # TODO: we could upload media using openAI endpoint to have url instead + raise ValueError("Unsupported message content", element) + + +def _convert_context_element( + element: LMMContextElement, + config: OpenAIChatConfig, +) -> ChatCompletionMessageParam: + match element: + case LMMInstruction() as instruction: + return { + "role": "system", + "content": instruction.content, + } + + case LMMInput() as input: + return { + "role": "user", + "content": [ + _convert_content_element( + element=element, config=config, - messages=messages, - tools=tools or Toolbox(), - send_update=send_update, ) + for element in input.content.elements + ], + } + + case LMMCompletion() as completion: + # TODO: OpenAI models generating media? + return { + "role": "assistant", + "content": completion.content.as_string(), + } + + case LMMToolRequests() as tool_requests: + return { + "role": "assistant", + "tool_calls": [ + { + "id": request.identifier, + "type": "function", + "function": { + "name": request.tool, + "arguments": json.dumps(request.arguments), + }, + } + for request in tool_requests.requests + ], + } + + case LMMToolResponse() as tool_response: + return { + "role": "tool", + "tool_call_id": tool_response.identifier, + "content": tool_response.content.as_string(), + } + + +async def _chat_completion( + *, + client: OpenAIClient, + config: OpenAIChatConfig, + messages: list[ChatCompletionMessageParam], + tools: Sequence[ToolSpecification] | None, + require_tool: ToolSpecification | bool, +) -> LMMOutput: + completion: ChatCompletion + match require_tool: + case bool(required): + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools, + ), + tools_suggestion=required, + ) + + case ToolSpecification() as tool: + completion = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools, + ), + tools_suggestion={ + "type": "function", + "function": { + "name": tool["function"]["name"], + }, + }, + ) - return AsyncStreamTask(job=stream_task) + if usage := completion.usage: + ctx.record( + TokenUsage.for_model( + config.model, + input_tokens=usage.prompt_tokens, + output_tokens=usage.completion_tokens, + ), + ) - case streaming_update: + if not completion.choices: + raise OpenAIException("Invalid OpenAI completion - missing messages!", completion) - def send_update(update: ToolCallUpdate | str) -> None: - if isinstance(update, str): - streaming_update( - LMMMessage( - role="assistant", - content=update, + completion_message: ChatCompletionMessage = completion.choices[0].message + match completion.choices[0].finish_reason: + case "tool_calls": + if (tool_calls := completion_message.tool_calls) and (tools := tools): + 
ctx.record(ResultTrace.of(tool_calls)) + return LMMToolRequests( + requests=[ + LMMToolRequest( + identifier=call.id, + tool=call.function.name, + arguments=json.loads(call.function.arguments), ) - ) - else: - streaming_update(update) - - with ctx.nested( - "openai_lmm_completion", - state=[ToolsUpdatesContext(send_update=streaming_update)], - metrics=[config], - ): - return LMMMessage( - role="assistant", - content=await _chat_stream( - client=client, - config=config, - messages=messages, - tools=tools or Toolbox(), - send_update=send_update, - ), + for call in tool_calls + ] ) + else: + raise OpenAIException("Invalid OpenAI completion", completion) + + case "stop": + if content := completion_message.content: + ctx.record(ResultTrace.of(content)) + # TODO: OpenAI models generating media? + return LMMCompletion.of(content) -def _convert_message( # noqa: PLR0912, C901, PLR0911 - config: OpenAIChatConfig, - message: LMMMessage, -) -> ChatCompletionMessageParam: - match message.role: - case "user": - if isinstance(message.content, str): - return { - "role": "user", - "content": message.content, - } - elif isinstance(message.content, ImageURLContent): - return { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": message.content.image_url, - "detail": config.vision_details or "auto", - }, - } - ], - } - elif isinstance(message.content, ImageBase64Content): - raise ValueError("Unsupported message content", message) - elif isinstance(message.content, Model): - return { - "role": "user", - "content": str(message.content), - } else: - content_parts: list[ChatCompletionContentPartParam] = [] - for part in message.content: - if isinstance(part, str): - content_parts.append( - { - "type": "text", - "text": part, - } - ) - elif isinstance(part, ImageURLContent): - content_parts.append( - { - "type": "image_url", - "image_url": { - "url": part.image_url, - "detail": config.vision_details or "auto", - }, - } + raise OpenAIException("Invalid OpenAI completion", completion) + + case other: + raise OpenAIException(f"Unexpected finish reason: {other}") + + +async def _chat_completion_stream( # noqa: C901, PLR0912, PLR0915 + *, + client: OpenAIClient, + config: OpenAIChatConfig, + messages: list[ChatCompletionMessageParam], + tools: Sequence[ToolSpecification] | None, + require_tool: ToolSpecification | bool, +) -> AsyncGenerator[LMMOutputStreamChunk, None]: + completion_stream: OpenAIAsyncStream[ChatCompletionChunk] + match require_tool: + case bool(required): + completion_stream = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools, + ), + tools_suggestion=required, + stream=True, + ) + + case ToolSpecification() as tool: + completion_stream = await client.chat_completion( + config=config, + messages=messages, + tools=cast( + list[ChatCompletionToolParam], + tools, + ), + tools_suggestion={ + "type": "function", + "function": { + "name": tool["function"]["name"], + }, + }, + stream=True, + ) + + accumulated_completion: str = "" + requested_tool_calls: list[ChoiceDeltaToolCall] = [] + async for part in completion_stream: + if choices := part.choices: # usage part does not contain choices + # we are always requesting single result - no need to take care of indices + element: Choice = choices[0] + if element.delta.content is not None: + part_text: str = element.delta.content + if not part_text: + continue # skip empty parts + accumulated_completion += part_text + # TODO: OpenAI models generating media? 
+ yield LMMCompletionChunk.of(part_text) + + elif tool_calls := element.delta.tool_calls: + # tool calls come in parts, we have to merge them manually + for call in tool_calls: + try: + tool_call: ChoiceDeltaToolCall = next( + tool_call + for tool_call in requested_tool_calls + if tool_call.index == call.index ) - elif isinstance(message.content, ImageBase64Content): - raise ValueError("Unsupported message content", message) - elif isinstance(message.content, Model): - content_parts.append( - { - "type": "text", - "text": str(message.content), - } + + if call.id: + if tool_call.id is not None: + tool_call.id += call.id + else: + tool_call.id = call.id + else: + pass + + if call.function is None: + continue + + if tool_call.function is None: + tool_call.function = call.function + continue + + if call.function.name: + if tool_call.function.name is not None: + tool_call.function.name += call.function.name + else: + tool_call.function.name = call.function.name + else: + pass + + if call.function.arguments: + if tool_call.function.arguments is not None: + tool_call.function.arguments += call.function.arguments + else: + tool_call.function.arguments = call.function.arguments + else: + pass + + except (StopIteration, StopAsyncIteration): + requested_tool_calls.append(call) + + elif finish_reason := element.finish_reason: + match finish_reason: + case "tool_calls": + ctx.record(ResultTrace.of(requested_tool_calls)) + yield LMMToolRequests( + requests=[ + LMMToolRequest( + identifier=call.id or uuid4().hex, + tool=call.function.name, + arguments=json.loads(call.function.arguments) + if call.function.arguments + else {}, + ) + for call in requested_tool_calls + if call.function and call.function.name + ] ) - else: - raise ValueError("Unsupported message content", message) - return { - "role": "user", - "content": content_parts, - } - - case "assistant": - if isinstance(message.content, str): - return { - "role": "assistant", - "content": message.content, - } - elif isinstance(message.content, Model): - return { - "role": "assistant", - "content": str(message.content), - } - else: - raise ValueError("Invalid assistant message", message) - - case "system": - if isinstance(message.content, str): - return { - "role": "system", - "content": message.content, - } - elif isinstance(message.content, Model): - return { - "role": "system", - "content": str(message.content), - } + + case "stop": + ctx.record(ResultTrace.of(accumulated_completion)) + + case other: + raise OpenAIException(f"Unexpected finish reason: {other}") + else: - raise ValueError("Invalid system message", message) + ctx.log_warning("Unexpected OpenAI streaming part: %s", part) + + elif usage := part.usage: # record usage if able (expected in the last part) + ctx.record( + TokenUsage.for_model( + config.model, + input_tokens=usage.prompt_tokens, + output_tokens=usage.completion_tokens, + ), + ) + + else: + ctx.log_warning("Unexpected OpenAI streaming part: %s", part) diff --git a/src/draive/scope/access.py b/src/draive/scope/access.py index 4bd61a3..339899d 100644 --- a/src/draive/scope/access.py +++ b/src/draive/scope/access.py @@ -1,11 +1,11 @@ from asyncio import Task, TaskGroup, current_task, shield -from collections.abc import Callable, Coroutine, Iterable -from contextvars import Context, ContextVar, Token, copy_context +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine, Iterable +from contextvars import ContextVar, Token from logging import Logger, getLogger from types import TracebackType from typing 
import Any, final -from draive.helpers import getenv_bool, mimic_function +from draive.helpers import AsyncStream, getenv_bool, mimic_function from draive.metrics import ( ExceptionTrace, Metric, @@ -78,13 +78,14 @@ async def __aexit__( exc=exc_val, tb=exc_tb, ) + except BaseException as exc: # record task group exceptions self._metrics.record(ExceptionTrace.of(exc_val or exc)) else: # or context exception - if exception := exc_val: + if (exception := exc_val) and exc_type is not GeneratorExit: self._metrics.record(ExceptionTrace.of(exception)) finally: @@ -150,7 +151,7 @@ def __exit__( ) -> None: if (token := self._metrics_token) and (metrics := self._metrics): _MetricsScope_Var.reset(self._metrics_token) - if exception := exc_val: + if (exception := exc_val) and exc_type is not GeneratorExit: metrics.record(ExceptionTrace.of(exception)) metrics.exit() @@ -333,13 +334,21 @@ def spawn_task[**Args, Result]( **kwargs: Args.kwargs, ) -> Task[Result]: nested_context: _PartialContext = ctx.nested(function.__name__) - current_context: Context = copy_context() async def wrapped(*args: Args.args, **kwargs: Args.kwargs) -> Result: with nested_context: return await function(*args, **kwargs) - return ctx._current_task_group().create_task(current_context.run(wrapped, *args, **kwargs)) + return ctx._current_task_group().create_task(wrapped(*args, **kwargs)) + + @staticmethod + def spawn_subtask[**Args, Result]( + function: Callable[Args, Coroutine[None, None, Result]], + /, + *args: Args.args, + **kwargs: Args.kwargs, + ) -> Task[Result]: + return ctx._current_task_group().create_task(function(*args, **kwargs)) @staticmethod def cancel() -> None: @@ -369,6 +378,30 @@ def nested( state=nested_state, ) + @staticmethod + def stream[Element]( + generator: AsyncGenerator[Element, None], + ) -> AsyncIterator[Element]: + # TODO: find better solution for streaming without spawning tasks if able + stream: AsyncStream[Element] = AsyncStream() + current_metrics: MetricsTrace = ctx._current_metrics() + current_metrics.enter() # ensure valid metrics scope closing + + async def iterate() -> None: + try: + async for element in generator: + stream.send(element) + + except BaseException as exc: + stream.finish(exception=exc) + else: + stream.finish() + finally: + current_metrics.exit() + + ctx.spawn_subtask(iterate) + return stream + @staticmethod def updated( *state: ParametrizedData, @@ -404,7 +437,18 @@ def read[Metric_T: Metric]( def record( *metrics: Metric, ) -> None: - ctx._current_metrics().record(*metrics) + try: + ctx._current_metrics().record(*metrics) + + # ignoring metrics record when using out of metrics context + # using default logger as fallback as we already know that we are missing metrics + except MissingScopeContext as exc: + logger: Logger = getLogger() + logger.error("Attempting to record metrics outside of metrics context") + logger.error( + exc, + exc_info=True, + ) @staticmethod def log_error( @@ -413,11 +457,25 @@ def log_error( *args: Any, exception: BaseException | None = None, ) -> None: - ctx._current_metrics().log_error( - message, - *args, - exception=exception, - ) + try: + ctx._current_metrics().log_error( + message, + *args, + exception=exception, + ) + + # using default logger as fallback when using out of metrics context + except MissingScopeContext: + logger: Logger = getLogger() + logger.error( + message, + *args, + ) + if exception := exception: + logger.error( + exception, + exc_info=True, + ) @staticmethod def log_warning( @@ -426,11 +484,25 @@ def log_warning( *args: Any, 
exception: Exception | None = None, ) -> None: - ctx._current_metrics().log_warning( - message, - *args, - exception=exception, - ) + try: + ctx._current_metrics().log_warning( + message, + *args, + exception=exception, + ) + + # using default logger as fallback when using out of metrics context + except MissingScopeContext: + logger: Logger = getLogger() + logger.warning( + message, + *args, + ) + if exception := exception: + logger.error( + exception, + exc_info=True, + ) @staticmethod def log_info( @@ -438,10 +510,18 @@ def log_info( /, *args: Any, ) -> None: - ctx._current_metrics().log_info( - message, - *args, - ) + try: + ctx._current_metrics().log_info( + message, + *args, + ) + + # using default logger as fallback when using out of metrics context + except MissingScopeContext: + getLogger().info( + message, + *args, + ) @staticmethod def log_debug( @@ -450,8 +530,22 @@ def log_debug( *args: Any, exception: Exception | None = None, ) -> None: - ctx._current_metrics().log_debug( - message, - *args, - exception=exception, - ) + try: + ctx._current_metrics().log_debug( + message, + *args, + exception=exception, + ) + + # using default logger as fallback when using out of metrics context + except MissingScopeContext: + logger: Logger = getLogger() + logger.debug( + message, + *args, + ) + if exception := exception: + logger.error( + exception, + exc_info=True, + ) diff --git a/src/draive/similarity/__init__.py b/src/draive/similarity/__init__.py index 369227c..7a0e549 100644 --- a/src/draive/similarity/__init__.py +++ b/src/draive/similarity/__init__.py @@ -1,7 +1,9 @@ -from draive.similarity.mmr import mmr_similarity -from draive.similarity.similarity import similarity +from draive.similarity.mmr import mmr_similarity_search +from draive.similarity.score import similarity_score +from draive.similarity.search import similarity_search __all__ = [ - "mmr_similarity", - "similarity", + "mmr_similarity_search", + "similarity_search", + "similarity_score", ] diff --git a/src/draive/similarity/mmr.py b/src/draive/similarity/mmr.py index bfbfb4f..b8b2c58 100644 --- a/src/draive/similarity/mmr.py +++ b/src/draive/similarity/mmr.py @@ -6,38 +6,38 @@ from draive.similarity.cosine import cosine __all__ = [ - "mmr_similarity", + "mmr_similarity_search", ] -def mmr_similarity( - query_embedding: NDArray[Any] | list[float], - alternatives_embeddings: list[NDArray[Any]] | list[list[float]], +def mmr_similarity_search( + query_vector: NDArray[Any] | list[float], + values_vectors: list[NDArray[Any]] | list[list[float]], limit: int, lambda_multiplier: float = 0.5, ) -> list[int]: assert limit > 0 # nosec: B101 - if not alternatives_embeddings: + if not values_vectors: return [] - query: NDArray[Any] = np.array(query_embedding) + query: NDArray[Any] = np.array(query_vector) if query.ndim == 1: - query = np.expand_dims(query_embedding, axis=0) - alternatives: NDArray[Any] = np.array(alternatives_embeddings) + query = np.expand_dims(query_vector, axis=0) + values: NDArray[Any] = np.array(values_vectors) # count similarity - similarity: NDArray[Any] = cosine(alternatives, query) + similarity: NDArray[Any] = cosine(values, query) # find most similar match for query most_similar: int = int(np.argmax(similarity)) selected_indices: list[int] = [most_similar] - selected: NDArray[Any] = np.array([alternatives[most_similar]]) + selected: NDArray[Any] = np.array([values[most_similar]]) # then look one by one next best matches until the limit or end of alternatives - while len(selected_indices) < limit and 
len(selected_indices) < len(alternatives_embeddings):
+    while len(selected_indices) < limit and len(selected_indices) < len(values_vectors):
         best_score: float = -np.inf
         best_index: int = -1
         # count similarity to already selected results
-        similarity_to_selected: NDArray[Any] = cosine(alternatives, selected)
+        similarity_to_selected: NDArray[Any] = cosine(values, selected)
         # then find the next best score
         # (balancing between similarity to query and uniqueness of result)
@@ -60,7 +60,7 @@ def mmr_similarity(
         selected_indices.append(best_index)
         selected = np.append(
             selected,
-            [alternatives[best_index]],  # pyright: ignore[reportUnknownArgumentType]
+            [values[best_index]],  # pyright: ignore[reportUnknownArgumentType]
             axis=0,
         )
diff --git a/src/draive/similarity/score.py b/src/draive/similarity/score.py
new file mode 100644
index 0000000..c8cdd3f
--- /dev/null
+++ b/src/draive/similarity/score.py
@@ -0,0 +1,24 @@
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from draive.similarity.cosine import cosine
+
+__all__ = [
+    "similarity_score",
+]
+
+
+async def similarity_score(
+    value_vector: NDArray[Any] | list[float],
+    reference_vector: NDArray[Any] | list[float],
+) -> float:
+    reference: NDArray[Any] = np.array(reference_vector)
+    if reference.ndim == 1:
+        reference = np.expand_dims(reference_vector, axis=0)
+    value: NDArray[Any] = np.array(value_vector)
+    if value.ndim == 1:
+        value = np.expand_dims(value_vector, axis=0)
+
+    return cosine(value, reference)[0]
diff --git a/src/draive/similarity/search.py b/src/draive/similarity/search.py
new file mode 100644
index 0000000..565b3bf
--- /dev/null
+++ b/src/draive/similarity/search.py
@@ -0,0 +1,38 @@
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from draive.similarity.cosine import cosine
+
+__all__ = [
+    "similarity_search",
+]
+
+
+def similarity_search(
+    query_vector: NDArray[Any] | list[float],
+    values_vectors: list[NDArray[Any]] | list[list[float]],
+    limit: int,
+    score_threshold: float | None = None,
+) -> list[int]:
+    assert limit > 0  # nosec: B101
+    if not values_vectors:
+        return []
+    query: NDArray[Any] = np.array(query_vector)
+    if query.ndim == 1:
+        query = np.expand_dims(query_vector, axis=0)
+    values: NDArray[Any] = np.array(values_vectors)
+    matching_scores: NDArray[Any] = cosine(values, query)
+    sorted_indices: list[int] = list(reversed(np.argsort(matching_scores)))
+    if score_threshold:
+        return [
+            int(idx)
+            for idx in sorted_indices  # pyright: ignore[reportUnknownVariableType]
+            if matching_scores[idx] > score_threshold  # pyright: ignore[reportUnknownArgumentType]
+        ][:limit]
+    else:
+        return [
+            int(idx)
+            for idx in sorted_indices  # pyright: ignore[reportUnknownVariableType]
+        ][:limit]
diff --git a/src/draive/similarity/similarity.py b/src/draive/similarity/similarity.py
deleted file mode 100644
index 4553199..0000000
--- a/src/draive/similarity/similarity.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from typing import Any
-
-import numpy as np
-from numpy.typing import NDArray
-
-from draive.similarity.cosine import cosine
-
-__all__ = [
-    "similarity",
-]
-
-
-def similarity(
-    query_embedding: NDArray[Any] | list[float],
-    alternatives_embeddings: list[NDArray[Any]] | list[list[float]],
-    limit: int,
-    score_threshold: float,
-) -> list[int]:
-    assert limit > 0  # nosec: B101
-    if not alternatives_embeddings:
-        return []
-    query: NDArray[Any] = np.array(query_embedding)
-    if query.ndim == 1:
-        query = np.expand_dims(query_embedding, axis=0)
-
alternatives: NDArray[Any] = np.array(alternatives_embeddings) - matching_scores: NDArray[Any] = cosine(alternatives, query) - sorted_indices: list[int] = list(reversed(np.argsort(matching_scores))) - return [ - int(idx) - for idx in sorted_indices # pyright: ignore[reportUnknownVariableType] - if matching_scores[idx] > score_threshold # pyright: ignore[reportUnknownArgumentType] - ][:limit] diff --git a/src/draive/tools/__init__.py b/src/draive/tools/__init__.py deleted file mode 100644 index f88d46f..0000000 --- a/src/draive/tools/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from draive.tools.errors import ToolException -from draive.tools.state import ( - ToolCallContext, - ToolsUpdatesContext, -) -from draive.tools.tool import Tool, tool -from draive.tools.toolbox import Toolbox -from draive.tools.update import ToolCallStatus, ToolCallUpdate - -__all__ = [ - "tool", - "Tool", - "Toolbox", - "ToolCallContext", - "ToolCallStatus", - "ToolCallUpdate", - "ToolException", - "ToolsUpdatesContext", -] diff --git a/src/draive/tools/state.py b/src/draive/tools/state.py deleted file mode 100644 index c45d015..0000000 --- a/src/draive/tools/state.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Callable - -from draive.scope import ctx -from draive.tools.update import ToolCallUpdate -from draive.types import Model, State - -__all__ = [ - "ToolCallContext", - "ToolsUpdatesContext", -] - - -class ToolsUpdatesContext(State): - send_update: Callable[[ToolCallUpdate], None] | None = None - - -class ToolCallContext(State): - call_id: str - tool: str - - def send_update( - self, - content: Model, - ) -> None: - if send_update := ctx.state(ToolsUpdatesContext).send_update: - send_update( - ToolCallUpdate( - call_id=self.call_id, - tool=self.tool, - status="RUNNING", - content=content, - ) - ) diff --git a/src/draive/tools/toolbox.py b/src/draive/tools/toolbox.py deleted file mode 100644 index 3f10a6c..0000000 --- a/src/draive/tools/toolbox.py +++ /dev/null @@ -1,81 +0,0 @@ -from json import loads -from typing import Any, Literal, final - -from draive.helpers import freeze -from draive.parameters import ToolSpecification -from draive.tools import Tool -from draive.tools.errors import ToolException -from draive.types import MultimodalContent - -__all__ = [ - "Toolbox", -] - -AnyTool = Tool[Any, Any] - - -@final -class Toolbox: - def __init__( - self, - *tools: AnyTool, - suggest: AnyTool | Literal[True] | None = None, - ) -> None: - self._tools: dict[str, AnyTool] = {tool.name: tool for tool in tools} - self.suggest_tools: bool - self._suggested_tool: AnyTool | None - match suggest: - case None: - self.suggest_tools = False - self._suggested_tool = None - case True: - self.suggest_tools = True if self._tools else False - self._suggested_tool = None - case tool: - self.suggest_tools = True - self._suggested_tool = tool - self._tools[tool.name] = tool - - freeze(self) - - @property - def suggested_tool_name(self) -> str | None: - if self._suggested_tool is not None and self._suggested_tool.available: - return self._suggested_tool.name - else: - return None - - @property - def suggested_tool(self) -> ToolSpecification | None: - if self._suggested_tool is not None and self._suggested_tool.available: - return self._suggested_tool.specification - else: - return None - - @property - def available_tools(self) -> list[ToolSpecification]: - return [tool.specification for tool in self._tools.values() if tool.available] - - def requires_direct_result( - self, - tool_name: str, - ) -> bool: - if tool := 
self._tools.get(tool_name): - return tool.requires_direct_result - else: - return False - - async def call_tool( - self, - name: str, - /, - call_id: str, - arguments: dict[str, Any] | str | bytes | None, - ) -> MultimodalContent: - if tool := self._tools.get(name): - return await tool.call( - call_id, - **loads(arguments) if isinstance(arguments, str | bytes) else arguments or {}, - ) - else: - raise ToolException("Requested tool is not defined", name) diff --git a/src/draive/tools/update.py b/src/draive/tools/update.py deleted file mode 100644 index 890b0de..0000000 --- a/src/draive/tools/update.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Literal - -from draive.types.model import Model - -__all__ = [ - "ToolCallStatus", - "ToolCallUpdate", -] - - -ToolCallStatus = Literal["STARTED", "RUNNING", "FINISHED", "FAILED"] - - -class ToolCallUpdate(Model): - call_id: str - tool: str - status: ToolCallStatus - content: Model | None diff --git a/src/draive/types/__init__.py b/src/draive/types/__init__.py index bd7d0f5..f87be9a 100644 --- a/src/draive/types/__init__.py +++ b/src/draive/types/__init__.py @@ -1,38 +1,56 @@ -from draive.types.audio import AudioBase64Content, AudioContent, AudioURLContent -from draive.types.images import ImageBase64Content, ImageContent, ImageURLContent +from draive.types.audio import AudioBase64Content, AudioContent, AudioDataContent, AudioURLContent +from draive.types.images import ImageBase64Content, ImageContent, ImageDataContent, ImageURLContent +from draive.types.instruction import Instruction +from draive.types.lmm import ( + LMMCompletion, + LMMCompletionChunk, + LMMContextElement, + LMMInput, + LMMInstruction, + LMMOutput, + LMMOutputStream, + LMMOutputStreamChunk, + LMMToolRequest, + LMMToolRequests, + LMMToolResponse, +) from draive.types.memory import Memory, ReadOnlyMemory from draive.types.model import Model -from draive.types.multimodal import ( - MultimodalContent, - MultimodalContentItem, - has_media, - is_multimodal_content, - merge_multimodal_content, - multimodal_content_string, -) +from draive.types.multimodal import MultimodalContent, MultimodalContentElement from draive.types.state import State -from draive.types.video import VideoBase64Content, VideoContent, VideoURLContent +from draive.types.tool_status import ToolCallStatus +from draive.types.video import VideoBase64Content, VideoContent, VideoDataContent, VideoURLContent __all__ = [ "AudioBase64Content", "AudioContent", + "AudioDataContent", "AudioURLContent", - "has_media", + "Instruction", "ImageBase64Content", "ImageContent", + "ImageDataContent", "ImageURLContent", - "is_multimodal_content", - "is_multimodal_content", "Memory", - "merge_multimodal_content", - "merge_multimodal_content", "Model", - "multimodal_content_string", "MultimodalContent", - "MultimodalContentItem", + "MultimodalContentElement", "ReadOnlyMemory", "State", + "ToolCallStatus", "VideoBase64Content", "VideoContent", + "VideoDataContent", "VideoURLContent", + "LMMCompletion", + "LMMCompletionChunk", + "LMMContextElement", + "LMMInput", + "LMMInstruction", + "LMMOutput", + "LMMOutputStream", + "LMMOutputStreamChunk", + "LMMToolRequest", + "LMMToolRequests", + "LMMToolResponse", ] diff --git a/src/draive/types/audio.py b/src/draive/types/audio.py index d7cc789..b74fbe7 100644 --- a/src/draive/types/audio.py +++ b/src/draive/types/audio.py @@ -3,6 +3,7 @@ __all__ = [ "AudioBase64Content", "AudioContent", + "AudioDataContent", "AudioURLContent", ] @@ -17,4 +18,9 @@ class AudioBase64Content(Model): 
audio_transcription: str | None = None -AudioContent = AudioURLContent | AudioBase64Content +class AudioDataContent(Model): + audio_data: bytes + audio_transcription: str | None = None + + +AudioContent = AudioURLContent | AudioBase64Content | AudioDataContent diff --git a/src/draive/types/images.py b/src/draive/types/images.py index cdd3548..9277cfa 100644 --- a/src/draive/types/images.py +++ b/src/draive/types/images.py @@ -3,6 +3,7 @@ __all__ = [ "ImageBase64Content", "ImageContent", + "ImageDataContent", "ImageURLContent", ] @@ -17,4 +18,9 @@ class ImageBase64Content(Model): image_description: str | None = None -ImageContent = ImageURLContent | ImageBase64Content +class ImageDataContent(Model): + image_data: bytes + image_description: str | None = None + + +ImageContent = ImageURLContent | ImageBase64Content | ImageDataContent diff --git a/src/draive/types/instruction.py b/src/draive/types/instruction.py new file mode 100644 index 0000000..2589b69 --- /dev/null +++ b/src/draive/types/instruction.py @@ -0,0 +1,86 @@ +from typing import Self +from uuid import uuid4 + +from draive.helpers import freeze + +__all__ = [ + "Instruction", +] + + +class Instruction: + def __init__( + self, + instruction: str, + /, + identifier: str | None = None, + **variables: object, + ) -> None: + self.instruction: str = instruction + self.identifier: str = identifier or uuid4().hex + self.variables: dict[str, object] = variables + + freeze(self) + + def format( + self, + **variables: object, + ) -> str: + if variables: + return self.instruction.format_map( + { + **self.variables, + **variables, + }, + ) + + else: + return self.instruction.format_map(self.variables) + + def extended( + self, + instruction: str, + /, + joiner: str | None = None, + **variables: object, + ) -> Self: + if variables: + return self.__class__( + (joiner or " ").join((self.instruction, instruction)), + identifier=self.identifier, + **{ + **self.variables, + **variables, + }, + ) + + else: + return self.__class__( + (joiner or " ").join((self.instruction, instruction)), + identifier=self.identifier, + **self.variables, + ) + + def updated( + self, + **variables: object, + ) -> Self: + if variables: + return self.__class__( + self.instruction, + identifier=self.identifier, + **{ + **self.variables, + **variables, + }, + ) + + else: + return self + + def __str__(self) -> str: + try: + return self.format() + + except KeyError: + return self.instruction diff --git a/src/draive/types/lmm.py b/src/draive/types/lmm.py new file mode 100644 index 0000000..496c82e --- /dev/null +++ b/src/draive/types/lmm.py @@ -0,0 +1,110 @@ +from collections.abc import AsyncIterator +from typing import Any, Self + +from draive.parameters import Field +from draive.types.instruction import Instruction +from draive.types.model import Model +from draive.types.multimodal import MultimodalContent, MultimodalContentElement + +__all__ = [ + "LMMCompletion", + "LMMCompletionChunk", + "LMMContextElement", + "LMMInput", + "LMMInstruction", + "LMMOutput", + "LMMOutputStream", + "LMMOutputStreamChunk", + "LMMToolRequest", + "LMMToolRequests", + "LMMToolResponse", +] + + +class LMMInstruction(Model): + @classmethod + def of( + cls, + instruction: Instruction | str, + /, + **variables: object, + ) -> Self: + match instruction: + case str(content): + return cls(content=content.format_map(variables) if variables else content) + + case instruction: + return cls(content=instruction.format(**variables)) + + content: str + + def __bool__(self) -> bool: + return bool(self.content) 
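A small illustrative sketch of the Instruction type introduced above, using placeholder text: variables are stored alongside the template and resolved through str.format_map, extended appends text while keeping the identifier, and LMMInstruction.of accepts either a plain string or an Instruction.

from draive.types import Instruction, LMMInstruction

base = Instruction(
    "You are {role}. Answer in {language}.",
    role="a helpful assistant",
)
print(base.format(language="English"))
# "You are a helpful assistant. Answer in English."

detailed = base.extended(
    "Keep answers under {limit} words.",
    limit=50,
)
print(detailed.format(language="English"))
# "You are a helpful assistant. Answer in English. Keep answers under 50 words."

# LMMInstruction.of accepts either a plain string or an Instruction instance
element = LMMInstruction.of(detailed, language="English")
print(element.content)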
+ + +class LMMInput(Model): + @classmethod + def of( + cls, + content: MultimodalContent | MultimodalContentElement, + /, + ) -> Self: + return cls(content=MultimodalContent.of(content)) + + content: MultimodalContent + + def __bool__(self) -> bool: + return bool(self.content) + + +class LMMCompletion(Model): + @classmethod + def of( + cls, + content: MultimodalContent | MultimodalContentElement, + /, + ) -> Self: + return cls(content=MultimodalContent.of(content)) + + content: MultimodalContent + + def __bool__(self) -> bool: + return bool(self.content) + + +class LMMCompletionChunk(Model): + @classmethod + def of( + cls, + content: MultimodalContent | MultimodalContentElement, + /, + ) -> Self: + return cls(content=MultimodalContent.of(content)) + + content: MultimodalContent + + def __bool__(self) -> bool: + return bool(self.content) + + +class LMMToolResponse(Model): + identifier: str + tool: str + content: MultimodalContent + direct: bool + + +class LMMToolRequest(Model): + identifier: str + tool: str + arguments: dict[str, Any] = Field(default_factory=dict) + + +class LMMToolRequests(Model): + requests: list[LMMToolRequest] + + +LMMContextElement = LMMInstruction | LMMInput | LMMCompletion | LMMToolRequests | LMMToolResponse +LMMOutput = LMMCompletion | LMMToolRequests +LMMOutputStreamChunk = LMMCompletionChunk | LMMToolRequests +LMMOutputStream = AsyncIterator[LMMOutputStreamChunk] diff --git a/src/draive/types/memory.py b/src/draive/types/memory.py index 7092d20..0c26799 100644 --- a/src/draive/types/memory.py +++ b/src/draive/types/memory.py @@ -15,8 +15,7 @@ async def recall(self) -> list[Element]: ... @abstractmethod async def remember( self, - elements: Iterable[Element], - /, + *elements: Element, ) -> None: ... @@ -24,13 +23,13 @@ async def remember( class ReadOnlyMemory[Element](Memory[Element]): def __init__( self, - elements: (Callable[[], Awaitable[list[Element]]] | Iterable[Element]), + elements: Callable[[], Awaitable[list[Element]]] | Iterable[Element] | None = None, ) -> None: self._elements: Callable[[], Awaitable[list[Element]]] if callable(elements): self._elements = elements else: - messages_list: list[Element] = list(elements) + messages_list: list[Element] = list(elements) if elements is not None else [] async def constant() -> list[Element]: return messages_list @@ -42,7 +41,6 @@ async def recall(self) -> list[Element]: async def remember( self, - elements: Iterable[Element], - /, + *elements: Element, ) -> None: pass # ignore diff --git a/src/draive/types/multimodal.py b/src/draive/types/multimodal.py index 3c4ba12..bde0a1b 100644 --- a/src/draive/types/multimodal.py +++ b/src/draive/types/multimodal.py @@ -1,128 +1,158 @@ -from typing import Any, TypeGuard +from itertools import chain +from typing import Self, final -from draive.types.audio import AudioBase64Content, AudioContent, AudioURLContent -from draive.types.images import ImageBase64Content, ImageContent, ImageURLContent -from draive.types.video import VideoBase64Content, VideoContent, VideoURLContent +from draive.types.audio import AudioBase64Content, AudioContent, AudioDataContent, AudioURLContent +from draive.types.images import ImageBase64Content, ImageContent, ImageDataContent, ImageURLContent +from draive.types.model import Model +from draive.types.video import VideoBase64Content, VideoContent, VideoDataContent, VideoURLContent __all__ = [ - "is_multimodal_content", - "merge_multimodal_content", - "multimodal_content_string", - "has_media", "MultimodalContent", - "MultimodalContentItem", + 
"MultimodalContentElement", ] -MultimodalContentItem = VideoContent | ImageContent | AudioContent | str -MultimodalContent = tuple[MultimodalContentItem, ...] | MultimodalContentItem - - -def is_multimodal_content( # noqa: PLR0911 - candidate: Any, +MultimodalContentElement = VideoContent | ImageContent | AudioContent | str + + +@final +class MultimodalContent(Model): + @classmethod + def of( + cls, + *elements: Self | MultimodalContentElement, + ) -> Self: + match elements: + case [MultimodalContent() as content]: + return content + + case elements: + return cls( + elements=tuple(chain.from_iterable(_extract(element) for element in elements)), + ) + + elements: tuple[MultimodalContentElement, ...] + + @property + def has_media(self) -> bool: + return any(_is_media(element) for element in self.elements) + + def as_string( + self, + joiner: str | None = None, + ) -> str: + return (joiner or "\n").join(_as_string(element) for element in self.elements) + + def appending( + self, + *elements: MultimodalContentElement, + ) -> Self: + return self.__class__( + elements=( + *self.elements, + *elements, + ) + ) + + def extending( + self, + *other: Self, + ) -> Self: + return self.__class__( + elements=( + *self.elements, + *(element for content in other for element in content.elements), + ) + ) + + def joining_texts( + self, + joiner: str | None = None, + ) -> Self: + joined_elements: list[MultimodalContentElement] = [] + current_text: str | None = None + for element in self.elements: + match element: + case str() as string: + if current_text: + current_text = (joiner or "\n").join((current_text, string)) + + else: + current_text = string + + case other: + if current_text: + joined_elements.append(current_text) + current_text = None + + joined_elements.append(other) + + return self.__class__( + elements=tuple(joined_elements), + ) + + def __bool__(self) -> bool: + return bool(self.elements) and any(self.elements) + + +def _extract( + element: MultimodalContent | MultimodalContentElement, /, -) -> TypeGuard[MultimodalContent]: - match candidate: - case str(): - return True - - case ImageURLContent(): - return True - - case ImageBase64Content(): - return True - - case AudioURLContent(): - return True +) -> tuple[MultimodalContentElement, ...]: + match element: + case MultimodalContent() as content: + return content.elements - case AudioBase64Content(): - return True - - case VideoURLContent(): - return True - - case VideoBase64Content(): - return True + case element: + return (element,) - case [*elements] if isinstance(candidate, tuple): - return all(is_multimodal_content(element) for element in elements) - case _: - return False - - -def has_media( # noqa: PLR0911 - content: MultimodalContent, - /, +def _is_media( + element: MultimodalContentElement, ) -> bool: - match content: + match element: case str(): return False - case ImageURLContent(): - return True - - case ImageBase64Content(): - return True - - case AudioURLContent(): - return True - - case AudioBase64Content(): - return True - - case VideoURLContent(): - return True - - case VideoBase64Content(): + case _: return True - case [*elements]: - return any(has_media(element) for element in elements) - -def multimodal_content_string( # noqa: PLR0911 - content: MultimodalContent, - /, +def _as_string( # noqa: PLR0911 + element: MultimodalContentElement, ) -> str: - match content: + match element: case str() as string: return string case ImageURLContent() as image_url: - return image_url.image_description or f"[IMAGE]({image_url.image_url})" + 
return f"![{image_url.image_description or 'IMAGE'}]({image_url.image_url})" case ImageBase64Content() as image_base64: - return image_base64.image_description or "[IMAGE]()" + # we might want to use base64 content directly, but it would make a lot of tokens... + return f"![{image_base64.image_description or 'IMAGE'}]()" + + case ImageDataContent() as image_data: + # we might want to convert to base64 content, but it would make a lot of tokens... + return f"![{image_data.image_description or 'IMAGE'}]()" case AudioURLContent() as audio_url: - return audio_url.audio_transcription or f"[AUDIO]({audio_url.audio_url})" + return f"![{audio_url.audio_transcription or 'AUDIO'}]({audio_url.audio_url})" case AudioBase64Content() as audio_base64: - return audio_base64.audio_transcription or "[AUDIO]()" + # we might want to use base64 content directly, but it would make a lot of tokens... + return f"![{audio_base64.audio_transcription or 'AUDIO'}]()" + + case AudioDataContent() as audio_data: + # we might want to convert to base64 content, but it would make a lot of tokens... + return f"![{audio_data.audio_transcription or 'AUDIO'}]()" case VideoURLContent() as video_url: - return video_url.video_transcription or f"[VIDEO]({video_url.video_url})" + return f"![{video_url.video_transcription or 'VIDEO'}]({video_url.video_url})" case VideoBase64Content() as video_base64: - return video_base64.video_transcription or "[VIDEO]()" - - case [*elements]: - return "\n".join(multimodal_content_string(element) for element in elements) - - -def merge_multimodal_content( - *content: MultimodalContent | None, -) -> tuple[MultimodalContentItem, ...]: - result: list[MultimodalContentItem] = [] - for part in content: - match part: - case None: - continue # skip none - - case [*parts]: - result.extend(parts) - - case part: - result.append(part) + # we might want to use base64 content directly, but it would make a lot of tokens... + return f"![{video_base64.video_transcription or 'VIDEO'}]()" - return tuple(result) + case VideoDataContent() as video_data: + # we might want to convert to base64 content, but it would make a lot of tokens... 
+ return f"![{video_data.video_transcription or 'VIDEO'}]()" diff --git a/src/draive/types/tool_status.py b/src/draive/types/tool_status.py new file mode 100644 index 0000000..bea9e07 --- /dev/null +++ b/src/draive/types/tool_status.py @@ -0,0 +1,19 @@ +from typing import Literal + +from draive.types.model import Model + +__all__ = [ + "ToolCallStatus", +] + + +class ToolCallStatus(Model): + identifier: str + tool: str + status: Literal[ + "STARTED", + "RUNNING", + "FINISHED", + "FAILED", + ] + content: dict[str, object] | None = None diff --git a/src/draive/types/video.py b/src/draive/types/video.py index 6f6e826..5947bb7 100644 --- a/src/draive/types/video.py +++ b/src/draive/types/video.py @@ -3,6 +3,7 @@ __all__ = [ "VideoBase64Content", "VideoContent", + "VideoDataContent", "VideoURLContent", ] @@ -17,4 +18,9 @@ class VideoBase64Content(Model): video_transcription: str | None = None -VideoContent = VideoURLContent | VideoBase64Content +class VideoDataContent(Model): + video_data: bytes + video_transcription: str | None = None + + +VideoContent = VideoURLContent | VideoBase64Content | VideoDataContent diff --git a/src/draive/utils/__init__.py b/src/draive/utils/__init__.py index da4ca21..3884b90 100644 --- a/src/draive/utils/__init__.py +++ b/src/draive/utils/__init__.py @@ -1,15 +1,11 @@ from draive.utils.cache import cache -from draive.utils.early_exit import allowing_early_exit, with_early_exit from draive.utils.retry import auto_retry -from draive.utils.stream import AsyncStream, AsyncStreamTask +from draive.utils.stream import AsyncStreamTask from draive.utils.trace import traced __all__ = [ - "allowing_early_exit", - "AsyncStream", "AsyncStreamTask", "auto_retry", "cache", "traced", - "with_early_exit", ] diff --git a/src/draive/utils/early_exit.py b/src/draive/utils/early_exit.py deleted file mode 100644 index f2f5109..0000000 --- a/src/draive/utils/early_exit.py +++ /dev/null @@ -1,89 +0,0 @@ -from asyncio import ( - FIRST_COMPLETED, - Future, - InvalidStateError, - get_running_loop, - wait, -) -from collections.abc import Callable, Coroutine -from contextvars import ContextVar, Token -from typing import Any, Protocol, Self - -from draive.scope import ctx -from draive.types import Model - -__all__ = [ - "allowing_early_exit", - "with_early_exit", -] - - -async def allowing_early_exit[**Args, Result, EarlyResult]( - result: type[EarlyResult], - call: Callable[Args, Coroutine[None, None, Result]], - /, - *args: Args.args, - **kwargs: Args.kwargs, -) -> Result | EarlyResult: - early_exit_future: Future[EarlyResult] = get_running_loop().create_future() - - async def exit_early(early_result: EarlyResult) -> None: - if not isinstance(early_result, result): - return ctx.log_debug( - "Ignored attempt to early exit with unexpected result: %s", - type(early_result), - ) - - try: - early_exit_future.set_result(early_result) - ctx.record(_EarlyExitResultTrace.of(early_result)) - except InvalidStateError as exc: - ctx.log_debug("Ignored redundant attempt to early exit: %s", exc) - - early_exit_token: Token[_RequestEarlyExit] = _EarlyExit_Var.set(exit_early) - try: - finished, running = await wait( - [ - ctx.spawn_task(call, *args, **kwargs), - early_exit_future, - ], - return_when=FIRST_COMPLETED, - ) - - for task in running: # pyright: ignore[reportUnknownVariableType] - task.cancel() - - return finished.pop().result() - - finally: - _EarlyExit_Var.reset(early_exit_token) - - -async def with_early_exit[Result](result: Result) -> Result: - try: - await 
_EarlyExit_Var.get()(early_result=result) - except LookupError as exc: - ctx.log_debug("Requested early exit in context not allowing it: %s", exc) - return result - - -class _RequestEarlyExit(Protocol): - async def __call__( - self, - early_result: Any, - ) -> None: ... - - -class _EarlyExitResultTrace(Model): - @classmethod - def of( - cls, - value: Any, - /, - ) -> Self: - return cls(result=value) - - result: Any - - -_EarlyExit_Var = ContextVar[_RequestEarlyExit]("_EarlyExit_Var") diff --git a/src/draive/utils/stream.py b/src/draive/utils/stream.py index 8eeb6da..2847eaa 100644 --- a/src/draive/utils/stream.py +++ b/src/draive/utils/stream.py @@ -1,90 +1,15 @@ -from asyncio import AbstractEventLoop, CancelledError, Future, Task, get_running_loop -from collections import deque +from asyncio import Task from collections.abc import AsyncIterator, Callable, Coroutine from typing import Self +from draive.helpers import AsyncStream from draive.scope import ctx __all__ = [ - "AsyncStream", "AsyncStreamTask", ] -class AsyncStream[Element](AsyncIterator[Element]): - def __init__( - self, - loop: AbstractEventLoop | None = None, - ) -> None: - self._loop: AbstractEventLoop = loop or get_running_loop() - self._buffer: deque[Element] = deque() - self._waiting_queue: deque[Future[Element]] = deque() - self._finish_exception: BaseException | None = None - - def __del__(self) -> None: - while self._waiting_queue: - waiting: Future[Element] = self._waiting_queue.popleft() - if waiting.done(): - continue - else: - waiting.set_exception(CancelledError()) - - @property - def finished(self) -> bool: - return self._finish_exception is not None - - def send( - self, - element: Element, - ) -> None: - if self.finished: - raise RuntimeError("AsyncStream has been already finished") - - while self._waiting_queue: - assert not self._buffer # nosec: B101 - waiting: Future[Element] = self._waiting_queue.popleft() - if waiting.done(): - continue - else: - waiting.set_result(element) - break - else: - self._buffer.append(element) - - def finish( - self, - exception: BaseException | None = None, - ) -> None: - if self.finished: - raise RuntimeError("AsyncStream has been already finished") - self._finish_exception = exception or StopAsyncIteration() - if self._buffer: - assert self._waiting_queue is None # nosec: B101 - return # allow consuming buffer to the end - while self._waiting_queue: - waiting: Future[Element] = self._waiting_queue.popleft() - if waiting.done(): - continue - else: - waiting.set_exception(self._finish_exception) - - def __aiter__(self) -> Self: - return self - - async def __anext__(self) -> Element: - if self._buffer: # use buffer first - return self._buffer.popleft() - if finish_exception := self._finish_exception: # check if finished - raise finish_exception - - # create new waiting future - future: Future[Element] = self._loop.create_future() - self._waiting_queue.append(future) - - # wait for the result - return await future - - class AsyncStreamTask[Element](AsyncIterator[Element]): def __init__( self, @@ -93,7 +18,7 @@ def __init__( stream: AsyncStream[Element] = AsyncStream() self._stream: AsyncStream[Element] = stream - async def streaming_job() -> None: + async def streaming() -> None: try: await job(stream.send) except Exception as exc: @@ -101,7 +26,7 @@ async def streaming_job() -> None: else: stream.finish() - self._task: Task[None] = ctx.spawn_task(streaming_job) + self._task: Task[None] = ctx.spawn_task(streaming) def __del__(self) -> None: self._task.cancel() diff --git 
a/tests/test_model.py b/tests/test_model.py index 111b373..fb921e5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -6,11 +6,12 @@ from draive import ( MISSING, AudioURLContent, + ConversationMessage, Field, ImageURLContent, - LMMMessage, Missing, Model, + MultimodalContent, ) @@ -151,59 +152,98 @@ def test_basic_decoding() -> None: assert BasicsModel.from_json(basic_model_json) == basic_model_instance -basic_lmm_message_instance: LMMMessage = LMMMessage( - role="assistant", - content="string", +basic_conversation_message_instance: ConversationMessage = ConversationMessage( + identifier="identifier", + role="model", + content=MultimodalContent.of("string"), +) +basic_conversation_message_json: str = ( + "{" + '"identifier": "identifier", ' + '"role": "model", ' + '"author": null, ' + '"created": null, ' + '"content": {' + '"elements": ["string"]' + "}}" ) -basic_lmm_message_json: str = '{"role": "assistant", "content": "string"}' -image_lmm_message_instance: LMMMessage = LMMMessage( - role="assistant", - content=ImageURLContent(image_url="https://miquido.com/image"), +image_conversation_message_instance: ConversationMessage = ConversationMessage( + identifier="identifier", + role="model", + content=MultimodalContent.of(ImageURLContent(image_url="https://miquido.com/image")), ) -image_lmm_message_json: str = ( - '{"role": "assistant",' - ' "content": {"image_url": "https://miquido.com/image", "image_description": null}' - "}" +image_conversation_message_json: str = ( + "{" + '"identifier": "identifier", ' + '"role": "model", ' + '"author": null, ' + '"created": null, ' + '"content": {' + '"elements": [{"image_url": "https://miquido.com/image", "image_description": null}]' + "}}" ) -audio_lmm_message_instance: LMMMessage = LMMMessage( - role="assistant", - content=AudioURLContent(audio_url="https://miquido.com/audio"), +audio_conversation_message_instance: ConversationMessage = ConversationMessage( + identifier="identifier", + role="model", + content=MultimodalContent.of(AudioURLContent(audio_url="https://miquido.com/audio")), ) -audio_lmm_message_json: str = ( - '{"role": "assistant",' - ' "content": {"audio_url": "https://miquido.com/audio", "audio_transcription": null}' - "}" +audio_conversation_message_json: str = ( + "{" + '"identifier": "identifier", ' + '"role": "model", ' + '"author": null, ' + '"created": null, ' + '"content": {' + '"elements": [{"audio_url": "https://miquido.com/audio", "audio_transcription": null}]' + "}}" ) -mixed_lmm_message_instance: LMMMessage = LMMMessage( - role="assistant", - content=( +mixed_conversation_message_instance: ConversationMessage = ConversationMessage( + identifier="identifier", + role="model", + content=MultimodalContent.of( AudioURLContent(audio_url="https://miquido.com/audio"), "string", ImageURLContent(image_url="https://miquido.com/image"), "content", ), ) -mixed_lmm_message_json: str = ( - '{"role": "assistant",' - ' "content": [' - '{"audio_url": "https://miquido.com/audio", "audio_transcription": null},' - ' "string",' - ' {"image_url": "https://miquido.com/image", "image_description": null},' - ' "content"' - "]}" +mixed_conversation_message_json: str = ( + "{" + '"identifier": "identifier", ' + '"role": "model", ' + '"author": null, ' + '"created": null, ' + '"content": {"elements": [' + '{"audio_url": "https://miquido.com/audio", "audio_transcription": null}, ' + '"string", ' + '{"image_url": "https://miquido.com/image", "image_description": null}, ' + '"content"' + "]}}" ) def test_llm_message_decoding() -> None: - 
assert LMMMessage.from_json(basic_lmm_message_json) == basic_lmm_message_instance - assert LMMMessage.from_json(image_lmm_message_json) == image_lmm_message_instance - assert LMMMessage.from_json(audio_lmm_message_json) == audio_lmm_message_instance - assert LMMMessage.from_json(mixed_lmm_message_json) == mixed_lmm_message_instance + assert ( + ConversationMessage.from_json(basic_conversation_message_json) + == basic_conversation_message_instance + ) + assert ( + ConversationMessage.from_json(image_conversation_message_json) + == image_conversation_message_instance + ) + assert ( + ConversationMessage.from_json(audio_conversation_message_json) + == audio_conversation_message_instance + ) + assert ( + ConversationMessage.from_json(mixed_conversation_message_json) + == mixed_conversation_message_instance + ) def test_llm_message_encoding() -> None: - assert basic_lmm_message_instance.as_json() == basic_lmm_message_json - assert image_lmm_message_instance.as_json() == image_lmm_message_json - assert audio_lmm_message_instance.as_json() == audio_lmm_message_json - assert mixed_lmm_message_instance.as_json() == mixed_lmm_message_json + assert basic_conversation_message_instance.as_json() == basic_conversation_message_json + assert image_conversation_message_instance.as_json() == image_conversation_message_json + assert audio_conversation_message_instance.as_json() == audio_conversation_message_json + assert mixed_conversation_message_instance.as_json() == mixed_conversation_message_json diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py index 85b0120..106e53b 100644 --- a/tests/test_tool_call.py +++ b/tests/test_tool_call.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from draive import ToolException, auto_retry, cache, ctx, tool +from draive import MultimodalContent, auto_retry, cache, ctx, tool from pytest import mark, raises @@ -34,14 +34,14 @@ async def compute(value: int) -> int: executions += 1 raise FakeException() - with raises(ToolException): + with raises(FakeException): await compute(value=42) assert executions == 1 @mark.asyncio @ctx.wrap("test") -async def test_formatted_call_returns_multimodal_content(): +async def test_toolbox_call_returns_multimodal_content(): executions: int = 0 @tool @@ -50,13 +50,18 @@ async def compute(value: int) -> int: executions += 1 return value - assert await compute.call("call_id", value=42) == "42" + assert await compute._toolbox_call( + "call_id", + arguments={ + "value": 42, + }, + ) == MultimodalContent.of("42") assert executions == 1 @mark.asyncio @ctx.wrap("test") -async def test_formatted_call_returns_custom_content(): +async def test_toolbox_call_returns_custom_content(): executions: int = 0 def custom_format(value: int) -> str: @@ -68,7 +73,12 @@ async def compute(value: int) -> int: executions += 1 return value - assert await compute.call("call_id", value=42) == "Value:42" + assert await compute._toolbox_call( + "call_id", + arguments={ + "value": 42, + }, + ) == MultimodalContent.of("Value:42") assert executions == 1
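For completeness, a brief sketch of the reworked MultimodalContent container that the updated tests above rely on; the URL is a placeholder value.

from draive.types import ImageURLContent, MultimodalContent

content = MultimodalContent.of(
    "Here is the chart:",
    ImageURLContent(image_url="https://example.com/chart.png"),
)
assert content.has_media  # any non-text element counts as media
assert content.as_string() == "Here is the chart:\n![IMAGE](https://example.com/chart.png)"

extended = content.appending("Anything else?")
assert len(extended.elements) == 3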