diff --git a/constraints b/constraints
index 922e67c..b203f0b 100644
--- a/constraints
+++ b/constraints
@@ -22,10 +22,8 @@ h11==0.14.0
     # via httpcore
 httpcore==1.0.5
     # via httpx
-httpx==0.25.2
-    # via
-    #   mistralai
-    #   openai
+httpx==0.27.0
+    # via openai
 idna==3.7
     # via
     #   anyio
@@ -37,13 +35,10 @@ markdown-it-py==3.0.0
     # via rich
 mdurl==0.1.2
     # via markdown-it-py
-mistralai==0.1.8
 nodeenv==1.8.0
     # via pyright
 numpy==1.26.4
 openai==1.23.3
-orjson==3.10.1
-    # via mistralai
 packaging==24.0
     # via pytest
 pbr==6.0.0
@@ -51,9 +46,7 @@ pbr==6.0.0
 pluggy==1.5.0
     # via pytest
 pydantic==2.7.1
-    # via
-    #   mistralai
-    #   openai
+    # via openai
 pydantic-core==2.18.2
     # via pydantic
 pygments==2.17.2
diff --git a/pyproject.toml b/pyproject.toml
index d2cb5cb..0e0eb78 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,6 @@ classifiers = [
 license = {file = "LICENSE"}
 dependencies = [
     "openai~=1.16",
-    "mistralai~=0.1",
     "numpy~=1.26",
     "tiktoken~=0.6",
     "pydantic~=2.6",
diff --git a/src/draive/mistral/chat_response.py b/src/draive/mistral/chat_response.py
index 932b34e..e9622ab 100644
--- a/src/draive/mistral/chat_response.py
+++ b/src/draive/mistral/chat_response.py
@@ -1,7 +1,5 @@
 from typing import cast
 
-from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage
-
 from draive.metrics import ArgumentsTrace, ResultTrace, TokenUsage
 from draive.mistral.chat_tools import (
     _execute_chat_tool_calls,  # pyright: ignore[reportPrivateUsage]
@@ -9,6 +7,7 @@
 from draive.mistral.client import MistralClient
 from draive.mistral.config import MistralChatConfig
 from draive.mistral.errors import MistralException
+from draive.mistral.models import ChatCompletionResponse, ChatMessage, ChatMessageResponse
 from draive.scope import ctx
 from draive.tools import Toolbox
 
@@ -22,7 +21,7 @@ async def _chat_response(
     client: MistralClient,
     config: MistralChatConfig,
     messages: list[ChatMessage],
-    tools: Toolbox | None,
+    tools: Toolbox,
     recursion_level: int = 0,
 ) -> str:
     if recursion_level > config.recursion_limit:
@@ -32,14 +31,24 @@ async def _chat_response(
         "chat_response",
         metrics=[ArgumentsTrace.of(messages=messages.copy())],
     ):
+        suggest_tools: bool
+        available_tools: list[dict[str, object]]
+        if recursion_level == 0 and (suggested := tools.suggested_tool):
+            # suggest/require tool call only initially
+            suggest_tools = True
+            available_tools = cast(list[dict[str, object]], [suggested])
+        else:
+            suggest_tools = False
+            available_tools = cast(
+                list[dict[str, object]],
+                tools.available_tools if tools else [],
+            )
+
         completion: ChatCompletionResponse = await client.chat_completion(
             config=config,
             messages=messages,
-            tools=cast(
-                list[dict[str, object]],
-                tools.available_tools if tools else [],
-            ),
-            suggest_tools=tools is not None and tools.suggested_tool_name is not None,
+            tools=available_tools,
+            suggest_tools=suggest_tools,
         )
 
         if usage := completion.usage:
@@ -54,7 +63,7 @@ async def _chat_response(
         if not completion.choices:
             raise MistralException("Invalid Mistral completion - missing messages!", completion)
 
-        completion_message: ChatMessage = completion.choices[0].message
+        completion_message: ChatMessageResponse = completion.choices[0].message
 
         if (tool_calls := completion_message.tool_calls) and (tools := tools):
             messages.extend(
@@ -68,7 +77,7 @@ async def _chat_response(
         elif message := completion_message.content:
             ctx.record(ResultTrace.of(message))
             match message:
-                case str() as content:
+                case str(content):
                     return content
 
                 # API docs say that it can be only a string in response
diff --git a/src/draive/mistral/chat_stream.py b/src/draive/mistral/chat_stream.py
index ebd2f12..cc192fc 100644
--- a/src/draive/mistral/chat_stream.py
+++ b/src/draive/mistral/chat_stream.py
@@ -1,22 +1,9 @@
-from collections.abc import AsyncIterable, AsyncIterator, Callable
-from typing import cast
+from collections.abc import Callable
 
-from mistralai.models.chat_completion import (
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
-    ToolCall,
-)
-
-from draive.metrics import ArgumentsTrace, ResultTrace
 from draive.mistral.chat_response import _chat_response  # pyright: ignore[reportPrivateUsage]
-from draive.mistral.chat_tools import (
-    _execute_chat_tool_calls,  # pyright: ignore[reportPrivateUsage]
-    _flush_chat_tool_calls,  # pyright: ignore[reportPrivateUsage]
-)
 from draive.mistral.client import MistralClient
 from draive.mistral.config import MistralChatConfig
-from draive.mistral.errors import MistralException
+from draive.mistral.models import ChatMessage
 from draive.scope import ctx
 from draive.tools import Toolbox, ToolCallUpdate
 
@@ -25,106 +12,22 @@
 ]
 
 
-async def _chat_stream(  # noqa: C901, PLR0913
+async def _chat_stream(  # noqa: PLR0913
     *,
     client: MistralClient,
     config: MistralChatConfig,
     messages: list[ChatMessage],
-    tools: Toolbox | None,
+    tools: Toolbox,
     send_update: Callable[[ToolCallUpdate | str], None],
     recursion_level: int = 0,
 ) -> str:
-    if recursion_level > config.recursion_limit:
-        raise MistralException("Reached limit of recursive calls of %d", config.recursion_limit)
-
-    if tools is not None:
-        ctx.log_warning(
-            "Mistral streaming api is broken - can't properly call tools, waiting for full response"
-        )
-        message: str = await _chat_response(
-            client=client,
-            config=config,
-            messages=messages,
-            tools=tools,
-            recursion_level=recursion_level,
-        )
-        send_update(message)
-        return message
-
-    with ctx.nested(
-        "chat_stream",
-        metrics=[ArgumentsTrace.of(messages=messages.copy())],
-    ):
-        completion_stream: AsyncIterable[
-            ChatCompletionStreamResponse
-        ] = await client.chat_completion(
-            config=config,
-            messages=messages,
-            tools=cast(
-                list[dict[str, object]],
-                tools.available_tools if tools else [],
-            ),
-            suggest_tools=False,  # type: ignore - no tools allowed in streaming
-            stream=True,
-        )
-        completion_stream_iterator: AsyncIterator[ChatCompletionStreamResponse] = (
-            completion_stream.__aiter__()
-        )
-
-        while True:  # load chunks to decide what to do next
-            head: ChatCompletionStreamResponse
-            try:
-                head = await anext(completion_stream_iterator)
-            except StopAsyncIteration as exc:
-                # could not decide what to do before stream end
-                raise MistralException("Invalid Mistral completion stream") from exc
-
-            if not head.choices:
-                raise MistralException("Invalid Mistral completion - missing deltas!", head)
-
-            completion_head: DeltaMessage = head.choices[0].delta
-
-            # TODO: record token usage
-
-            if completion_head.tool_calls is not None and (tools := tools):
-                tool_calls: list[ToolCall] = await _flush_chat_tool_calls(
-                    tool_calls=completion_head.tool_calls,
-                    completion_stream=completion_stream_iterator,
-                )
-                messages.extend(
-                    await _execute_chat_tool_calls(
-                        tool_calls=tool_calls,
-                        tools=tools,
-                    )
-                )
-                ctx.record(ResultTrace.of(tool_calls))
-                break  # after processing tool calls continue with recursion in outer context
-
-            elif completion_head.content is not None:
-                result: str = completion_head.content
-                if result:  # provide head / first part if not empty
-                    send_update(result)
-
-                async for part in completion_stream:
-                    # we are always requesting single result - no need to take care of indices
-                    part_text: str = part.choices[0].delta.content or ""
-                    if not part_text:
-                        continue  # skip empty parts
-                    result += part_text
-                    send_update(result)
-
-                ctx.record(ResultTrace.of(result))
-                return result  # we hav final result here
-
-            else:
-                continue  # iterate over the stream until can decide what to do or reach the end
-
-    # recursion outside of context
-    return await _chat_stream(
+    ctx.log_warning("Mistral streaming api is not supported yet, using regular response...")
+    message: str = await _chat_response(
         client=client,
         config=config,
         messages=messages,
         tools=tools,
-        send_update=send_update,
-        recursion_level=recursion_level + 1,
+        recursion_level=recursion_level,
     )
+    send_update(message)
+    return message
diff --git a/src/draive/mistral/chat_tools.py b/src/draive/mistral/chat_tools.py
index 2656c23..3b085a5 100644
--- a/src/draive/mistral/chat_tools.py
+++ b/src/draive/mistral/chat_tools.py
@@ -1,19 +1,19 @@
+import json
 from asyncio import gather
-from collections.abc import AsyncIterable, Awaitable
-
-from mistralai.models.chat_completion import ChatCompletionStreamResponse, ChatMessage, ToolCall
+from collections.abc import Awaitable
+from typing import Any
 
+from draive.mistral.models import ChatMessage, ToolCallResponse
 from draive.tools import Toolbox
 
 __all__ = [
     "_execute_chat_tool_calls",
-    "_flush_chat_tool_calls",
 ]
 
 
 async def _execute_chat_tool_calls(
     *,
-    tool_calls: list[ToolCall],
+    tool_calls: list[ToolCallResponse],
     tools: Toolbox,
 ) -> list[ChatMessage]:
     tool_call_results: list[Awaitable[ChatMessage]] = []
@@ -26,12 +26,23 @@ async def _execute_chat_tool_calls(
                 tools=tools,
             )
         )
-
     return [
         ChatMessage(
             role="assistant",
             content="",
-            tool_calls=tool_calls,
+            tool_calls=[
+                {
+                    "id": call.id,
+                    "type": "function",
+                    "function": {
+                        "name": call.function.name,
+                        "arguments": call.function.arguments
+                        if isinstance(call.function.arguments, str)
+                        else json.dumps(call.function.arguments),
+                    },
+                }
+                for call in tool_calls
+            ],
         ),
         *await gather(
             *tool_call_results,
@@ -44,7 +55,7 @@ async def _execute_chat_tool_call(
     *,
     call_id: str,
     name: str,
-    arguments: str,
+    arguments: dict[str, Any] | str,
     tools: Toolbox,
 ) -> ChatMessage:
     try:  # make sure that tool error won't blow up whole chain
@@ -66,46 +77,3 @@ async def _execute_chat_tool_call(
             name=name,
             content="Error",
         )
-
-
-async def _flush_chat_tool_calls(  # noqa: PLR0912
-    *,
-    tool_calls: list[ToolCall],
-    completion_stream: AsyncIterable[ChatCompletionStreamResponse],
-) -> list[ToolCall]:
-    # iterate over the stream to get full list of tool calls
-    async for chunk in completion_stream:
-        for call in chunk.choices[0].delta.tool_calls or []:
-            try:
-                tool_call: ToolCall = next(
-                    tool_call for tool_call in tool_calls if tool_call.id == call.id
-                )
-
-                if call.id:
-                    if tool_call.id != "null":
-                        tool_call.id += call.id
-                    else:
-                        tool_call.id = call.id
-                else:
-                    pass
-
-                if call.function.name:
-                    if tool_call.function.name:
-                        tool_call.function.name += call.function.name
-                    else:
-                        tool_call.function.name = call.function.name
-                else:
-                    pass
-
-                if call.function.arguments:
-                    if tool_call.function.arguments:
-                        tool_call.function.arguments += call.function.arguments
-                    else:
-                        tool_call.function.arguments = call.function.arguments
-                else:
-                    pass
-
-            except (StopIteration, StopAsyncIteration):
-                tool_calls.append(call)
-
-    return tool_calls
diff --git a/src/draive/mistral/client.py b/src/draive/mistral/client.py
index 46d37c2..cc6d93b 100644
--- a/src/draive/mistral/client.py
+++ b/src/draive/mistral/client.py
@@ -1,19 +1,22 @@
+import json
 from asyncio import gather
 from collections.abc import AsyncIterable, Iterable
 from itertools import chain
-from typing import Literal, Self, cast, final, overload
+from typing import Any, Literal, Self, cast, final, overload
 
-from mistralai.async_client import MistralAsyncClient
-from mistralai.models.chat_completion import (
+from httpx import AsyncClient, Response
+
+from draive.helpers import getenv_str, when_missing
+from draive.mistral.config import MistralChatConfig, MistralEmbeddingConfig
+from draive.mistral.errors import MistralException
+from draive.mistral.models import (
     ChatCompletionResponse,
     ChatCompletionStreamResponse,
     ChatMessage,
+    EmbeddingResponse,
 )
-from mistralai.models.embeddings import EmbeddingResponse
-
-from draive.helpers import getenv_str, when_missing
-from draive.mistral.config import MistralChatConfig, MistralEmbeddingConfig
 from draive.scope import ScopeDependency
+from draive.types import Model
 
 __all__ = [
     "MistralClient",
@@ -25,24 +28,24 @@ class MistralClient(ScopeDependency):
     @classmethod
     def prepare(cls) -> Self:
         return cls(
+            endpoint=getenv_str("MISTRAL_ENDPOINT", "https://api.mistral.ai"),
             api_key=getenv_str("MISTRAL_API_KEY"),
-            endpoint=getenv_str("MISTRAL_ENDPOINT"),
+            timeout=90,
         )
 
     def __init__(
         self,
+        endpoint: str,
         api_key: str | None,
-        endpoint: str | None = None,
+        timeout: float | None = None,
     ) -> None:
-        if endpoint:
-            self._client: MistralAsyncClient = MistralAsyncClient(
-                api_key=api_key,
-                endpoint=endpoint,
-            )
-        else:
-            self._client: MistralAsyncClient = MistralAsyncClient(
-                api_key=api_key,
-            )
+        self._client: AsyncClient = AsyncClient(
+            base_url=endpoint,
+            headers={
+                "Authorization": f"Bearer {api_key}",
+            },
+            timeout=timeout,
+        )
 
     @overload
     async def chat_completion(
@@ -85,27 +88,18 @@ async def chat_completion(  # noqa: PLR0913
         stream: bool = False,
     ) -> AsyncIterable[ChatCompletionStreamResponse] | ChatCompletionResponse:
         if stream:
-            return self._client.chat_stream(
+            raise NotImplementedError("Mistral streaming is not supported yet")
+        else:
+            return await self._create_chat_completion(
                 messages=messages,
                 model=config.model,
-                max_tokens=when_missing(config.max_tokens, default=None),
-                response_format=cast(dict[str, str], config.response_format),
-                random_seed=when_missing(config.seed, default=None),
                 temperature=config.temperature,
-                tools=tools,
                 top_p=when_missing(config.top_p, default=None),
-            )
-        else:
-            return await self._client.chat(
-                messages=messages,
-                model=config.model,
                 max_tokens=when_missing(config.max_tokens, default=None),
                 response_format=cast(dict[str, str], config.response_format),
-                random_seed=when_missing(config.seed, default=None),
-                temperature=config.temperature,
+                seed=when_missing(config.seed, default=None),
                 tools=tools,
                 tool_choice=("any" if suggest_tools else "auto") if tools else None,
-                top_p=when_missing(config.top_p, default=None),
             )
 
     async def embedding(
@@ -128,16 +122,115 @@ async def embedding(
             )
         )
 
+    async def _create_chat_completion(  # noqa: PLR0913
+        self,
+        model: str,
+        temperature: float,
+        top_p: float | None,
+        seed: int | None,
+        max_tokens: int | None,
+        response_format: dict[str, str] | None,
+        messages: list[ChatMessage],
+        tools: list[dict[str, object]],
+        tool_choice: str | None,
+    ) -> ChatCompletionResponse:
+        request_body: dict[str, Any] = {
+            "model": model,
+            "temperature": temperature,
+            "messages": messages,
+        }
+
+        if tools:
+            request_body["tools"] = tools
+        if tool_choice is not None:
+            request_body["tool_choice"] = tool_choice
+        elif tools:
+            request_body["tool_choice"] = "auto"
+        if max_tokens:
+            request_body["max_tokens"] = max_tokens
+        if top_p is not None:
+            request_body["top_p"] = top_p
+        if seed is not None:
+            request_body["random_seed"] = seed
+        if response_format is not None:
+            request_body["response_format"] = response_format
+
+        return await self._request(
+            model=ChatCompletionResponse,
+            method="POST",
+            url="v1/chat/completions",
+            body=request_body,
+        )
+
     async def _create_text_embedding(
         self,
         model: str,
         texts: list[str],
     ) -> list[list[float]]:
-        response: EmbeddingResponse = await self._client.embeddings(
-            model=model,
-            input=texts,
+        response: EmbeddingResponse = await self._request(
+            model=EmbeddingResponse,
+            method="POST",
+            url="v1/embeddings",
+            body={
+                "model": model,
+                "input": texts,
+            },
         )
         return [element.embedding for element in response.data]
 
     async def dispose(self) -> None:
-        await self._client.close()
+        await self._client.aclose()
+
+    async def _request[Requested: Model](  # noqa: PLR0913
+        self,
+        model: type[Requested],
+        method: str,
+        url: str,
+        query: dict[str, str] | None = None,
+        headers: dict[str, str] | None = None,
+        body: Model | dict[str, Any] | None = None,
+        follow_redirects: bool | None = None,
+        timeout: float | None = None,
+    ) -> Requested:
+        request_headers: dict[str, str]
+        if headers:
+            request_headers = headers
+        else:
+            request_headers = {
+                "Accept": "application/json",
+            }
+        body_content: str | None
+        match body:
+            case None:
+                body_content = None
+
+            case body_model if isinstance(body_model, Model):
+                body_content = body_model.as_json()
+
+            case values:
+                body_content = json.dumps(values)
+
+        if body_content:
+            request_headers["Content-Type"] = "application/json"
+
+        response: Response
+        try:
+            response = await self._client.request(
+                method=method,
+                url=url,
+                headers=request_headers,
+                params=query,
+                content=body_content,
+                follow_redirects=follow_redirects or False,
+                timeout=timeout,
+            )
+        except Exception as exc:
+            raise MistralException("Network request failed") from exc
+
+        if response.status_code in range(200, 299):
+            try:
+                return model.from_json(await response.aread())
+            except Exception as exc:
+                raise MistralException("Failed to decode mistral response", response) from exc
+        else:
+            raise MistralException("Network request failed")
diff --git a/src/draive/mistral/lmm.py b/src/draive/mistral/lmm.py
index e520c39..0d79277 100644
--- a/src/draive/mistral/lmm.py
+++ b/src/draive/mistral/lmm.py
@@ -1,13 +1,12 @@
 from collections.abc import Callable
 from typing import Literal, overload
 
-from mistralai.models.chat_completion import ChatMessage
-
 from draive.lmm import LMMCompletionMessage, LMMCompletionStream, LMMCompletionStreamingUpdate
 from draive.mistral.chat_response import _chat_response  # pyright: ignore[reportPrivateUsage]
 from draive.mistral.chat_stream import _chat_stream  # pyright: ignore[reportPrivateUsage]
 from draive.mistral.client import MistralClient
 from draive.mistral.config import MistralChatConfig
+from draive.mistral.models import ChatMessage
 from draive.scope import ctx
 from draive.tools import Toolbox, ToolCallUpdate, ToolsUpdatesContext
 from draive.types import ImageBase64Content, ImageURLContent, Model
@@ -73,14 +72,21 @@ async def mistral_lmm_completion(
     match stream:
         case False:
             with ctx.nested("mistral_lmm_completion", metrics=[config]):
+                message: str = await _chat_response(
+                    client=client,
+                    config=config,
+                    messages=messages,
+                    tools=tools or Toolbox(),
+                )
+                if tools and output == "json":
+                    # workaround for json mode with tools
+                    # most common mistral mistake is to not escape newlines
+                    # we can't fix it easily - code below removes all newlines
+                    # which may cause missing newlines in the json content
+                    message = message.replace("\n", "")
                 return LMMCompletionMessage(
                     role="assistant",
-                    content=await _chat_response(
-                        client=client,
-                        config=config,
-                        messages=messages,
-                        tools=tools,
-                    ),
+                    content=message,
                 )
 
         case True:
@@ -94,7 +100,7 @@ async def stream_task(
                 metrics=[config],
             ):
 
-                def send_update(update: ToolCallUpdate | str):
+                def send_update(update: ToolCallUpdate | str) -> None:
                     if isinstance(update, str):
                         streaming_update(
                             LMMCompletionMessage(
@@ -109,7 +115,7 @@ def send_update(update: ToolCallUpdate | str):
                     client=client,
                    config=config,
                    messages=messages,
-                    tools=tools,
+                    tools=tools or Toolbox(),
                     send_update=send_update,
                 )
 
@@ -117,7 +123,7 @@ def send_update(update: ToolCallUpdate | str):
 
         case streaming_update:
 
-            def send_update(update: ToolCallUpdate | str):
+            def send_update(update: ToolCallUpdate | str) -> None:
                 if isinstance(update, str):
                     streaming_update(
                         LMMCompletionMessage(
@@ -139,7 +145,7 @@ def send_update(update: ToolCallUpdate | str):
                     client=client,
                     config=config,
                     messages=messages,
-                    tools=tools,
+                    tools=tools or Toolbox(),
                     send_update=send_update,
                 ),
             )
diff --git a/src/draive/mistral/models.py b/src/draive/mistral/models.py
new file mode 100644
index 0000000..ab839e0
--- /dev/null
+++ b/src/draive/mistral/models.py
@@ -0,0 +1,110 @@
+from typing import Any, Literal, NotRequired, Required, TypedDict
+
+from draive.types import Model
+
+__all__ = [
+    "UsageInfo",
+    "EmbeddingObject",
+    "EmbeddingResponse",
+    "FunctionCall",
+    "ToolCall",
+    "ChatMessage",
+    "FunctionCallResponse",
+    "ToolCallResponse",
+    "ChatMessageResponse",
+    "ChatCompletionResponseChoice",
+    "ChatCompletionResponse",
+    "ChatDeltaMessageResponse",
+    "ChatCompletionResponseStreamChoice",
+    "ChatCompletionStreamResponse",
+]
+
+
+class UsageInfo(Model):
+    prompt_tokens: int
+    total_tokens: int
+    completion_tokens: int | None = None
+
+
+class EmbeddingObject(Model):
+    object: str
+    embedding: list[float]
+    index: int
+
+
+class EmbeddingResponse(Model):
+    id: str
+    object: str
+    data: list[EmbeddingObject]
+    model: str
+    usage: UsageInfo
+
+
+class FunctionCall(TypedDict, total=False):
+    name: Required[str]
+    arguments: Required[str]
+
+
+class ToolCall(TypedDict, total=False):
+    id: Required[str]
+    type: Required[Literal["function"]]
+    function: Required[FunctionCall]
+
+
+class ChatMessage(TypedDict, total=False):
+    role: Required[str]
+    content: Required[str | list[str]]
+    name: NotRequired[str]
+    tool_calls: NotRequired[list[ToolCall]]
+
+
+class FunctionCallResponse(Model):
+    name: str
+    arguments: dict[str, Any] | str
+
+
+class ToolCallResponse(Model):
+    id: str
+    type: Literal["function"]
+    function: FunctionCallResponse
+
+
+class ChatDeltaMessageResponse(Model):
+    role: str | None = None
+    content: str | None = None
+    tool_calls: list[ToolCallResponse] | None = None
+
+
+class ChatCompletionResponseStreamChoice(Model):
+    index: int
+    delta: ChatDeltaMessageResponse
+    finish_reason: Literal["stop", "length", "error", "tool_calls"] | None = None
+
+
+class ChatCompletionStreamResponse(Model):
+    id: str
+    model: str
+    choices: list[ChatCompletionResponseStreamChoice]
+    created: int | None = None
+    usage: UsageInfo | None = None
+
+
+class ChatMessageResponse(Model):
+    role: str
+    content: list[str] | str | None = None
+    tool_calls: list[ToolCallResponse] | None = None
+
+
+class ChatCompletionResponseChoice(Model):
+    index: int
+    message: ChatMessageResponse
+    finish_reason: Literal["stop", "length", "error", "tool_calls"] | None = None
+
+
+class ChatCompletionResponse(Model):
+    id: str
+    object: str
+    created: int
+    model: str
+    choices: list[ChatCompletionResponseChoice]
+    usage: UsageInfo
diff --git a/src/draive/openai/chat_response.py b/src/draive/openai/chat_response.py
index b3a44f1..8e484f0 100644
--- a/src/draive/openai/chat_response.py
+++ b/src/draive/openai/chat_response.py
@@ -27,7 +27,7 @@ async def _chat_response(
     client: OpenAIClient,
     config: OpenAIChatConfig,
     messages: list[ChatCompletionMessageParam],
-    tools: Toolbox | None,
+    tools: Toolbox,
     recursion_level: int = 0,
 ) -> str:
     if recursion_level > config.recursion_limit:
@@ -50,7 +50,7 @@ async def _chat_response(
                     "name": tools.suggested_tool_name,
                 },
             }  # suggest/require tool call only initially
-            if recursion_level == 0 and tools is not None and tools.suggested_tool_name
+            if recursion_level == 0 and tools.suggested_tool_name
            else None,
        )
 
diff --git a/src/draive/openai/chat_stream.py b/src/draive/openai/chat_stream.py
index 56a79e0..fd7262e 100644
--- a/src/draive/openai/chat_stream.py
+++ b/src/draive/openai/chat_stream.py
@@ -31,7 +31,7 @@ async def _chat_stream(  # noqa: PLR0913
     client: OpenAIClient,
     config: OpenAIChatConfig,
     messages: list[ChatCompletionMessageParam],
-    tools: Toolbox | None,
+    tools: Toolbox,
     send_update: Callable[[ToolCallUpdate | str], None],
     recursion_level: int = 0,
 ) -> str:
@@ -55,7 +55,7 @@ async def _chat_stream(  # noqa: PLR0913
                     "name": tools.suggested_tool_name,
                 },
             }  # suggest/require tool call only initially
-            if recursion_level == 0 and tools is not None and tools.suggested_tool_name
+            if recursion_level == 0 and tools.suggested_tool_name
            else None,
            stream=True,
        )
diff --git a/src/draive/openai/lmm.py b/src/draive/openai/lmm.py
index 2f38e2c..f82ed6d 100644
--- a/src/draive/openai/lmm.py
+++ b/src/draive/openai/lmm.py
@@ -73,7 +73,7 @@ async def openai_lmm_completion(
                     client=client,
                     config=config,
                     messages=messages,
-                    tools=tools,
+                    tools=tools or Toolbox(),
                ),
            )
 
@@ -103,7 +103,7 @@ def send_update(update: ToolCallUpdate | str) -> None:
                     client=client,
                     config=config,
                     messages=messages,
-                    tools=tools,
+                    tools=tools or Toolbox(),
                    send_update=send_update,
                )
 
@@ -133,7 +133,7 @@ def send_update(update: ToolCallUpdate | str) -> None:
                     client=client,
                     config=config,
                     messages=messages,
-                    tools=tools,
+                    tools=tools or Toolbox(),
                    send_update=send_update,
                ),
            )
diff --git a/src/draive/parameters/validation.py b/src/draive/parameters/validation.py
index c5c148a..bbe1d61 100644
--- a/src/draive/parameters/validation.py
+++ b/src/draive/parameters/validation.py
@@ -819,6 +819,9 @@ def parameter_validator[Value](  # noqa: PLR0911, C901
         case types.NoneType | None:
             return _none_validator(verifier=verifier)
 
+        case typing.Any:
+            return _any_validator(verifier=verifier)
+
         case types.UnionType | typing.Union:
             return _union_validator(
                 alternatives=get_args(annotation),
@@ -1011,8 +1014,5 @@ def validated_missing(value: Any) -> Any:
             verifier=verifier,
         )
 
-        case typing.Any:
-            return _any_validator(verifier=verifier)
-
         case other:
             raise TypeError("Unsupported type annotation %s", other)
diff --git a/src/draive/splitters/basic.py b/src/draive/splitters/basic.py
index 22ccd99..f6ed8ea 100644
--- a/src/draive/splitters/basic.py
+++ b/src/draive/splitters/basic.py
@@ -39,10 +39,10 @@ def _split(
     splitter: str
     alt_splitter: str
     match separators:
-        case (str() as primary, str() as secondary):
+        case (str(primary), str(secondary)):
             splitter = primary
             alt_splitter = secondary
-        case str() as primary:
+        case str(primary):
             splitter = primary
             alt_splitter = "\n"
         case None:
diff --git a/src/draive/tools/toolbox.py b/src/draive/tools/toolbox.py
index a12962a..0ac9f49 100644
--- a/src/draive/tools/toolbox.py
+++ b/src/draive/tools/toolbox.py
@@ -26,10 +26,17 @@ def __init__(
 
     @property
     def suggested_tool_name(self) -> str | None:
-        if self._suggested_tool is None or not self._suggested_tool.available:
-            return None
-        elif self._suggested_tool.available:
+        if self._suggested_tool is not None and self._suggested_tool.available:
             return self._suggested_tool.name
+        else:
+            return None
+
+    @property
+    def suggested_tool(self) -> ToolSpecification | None:
+        if self._suggested_tool is not None and self._suggested_tool.available:
+            return self._suggested_tool.specification
+        else:
+            return None
 
     @property
     def available_tools(self) -> list[ToolSpecification]:
@@ -40,12 +47,12 @@ async def call_tool(
         name: str,
         /,
         call_id: str,
-        arguments: str | bytes | None,
+        arguments: dict[str, Any] | str | bytes | None,
     ) -> Any:
         if tool := self._tools[name]:
             return await tool(
                 tool_call_id=call_id,
-                **loads(arguments) if arguments else {},
+                **loads(arguments) if isinstance(arguments, str | bytes) else arguments or {},
             )
         else:
             raise ToolException("Requested tool is not defined", name)