Skip to content

Commit

Permalink
Remove mistral dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ committed Apr 25, 2024
1 parent 36e868f commit c41e13e
Show file tree
Hide file tree
Showing 14 changed files with 330 additions and 242 deletions.
13 changes: 3 additions & 10 deletions constraints
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ h11==0.14.0
# via httpcore
httpcore==1.0.5
# via httpx
httpx==0.25.2
# via
# mistralai
# openai
httpx==0.27.0
# via openai
idna==3.7
# via
# anyio
Expand All @@ -37,23 +35,18 @@ markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
mistralai==0.1.8
nodeenv==1.8.0
# via pyright
numpy==1.26.4
openai==1.23.3
orjson==3.10.1
# via mistralai
packaging==24.0
# via pytest
pbr==6.0.0
# via stevedore
pluggy==1.5.0
# via pytest
pydantic==2.7.1
# via
# mistralai
# openai
# via openai
pydantic-core==2.18.2
# via pydantic
pygments==2.17.2
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ classifiers = [
license = {file = "LICENSE"}
dependencies = [
"openai~=1.16",
"mistralai~=0.1",
"numpy~=1.26",
"tiktoken~=0.6",
"pydantic~=2.6",
Expand Down
29 changes: 19 additions & 10 deletions src/draive/mistral/chat_response.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import cast

from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage

from draive.metrics import ArgumentsTrace, ResultTrace, TokenUsage
from draive.mistral.chat_tools import (
_execute_chat_tool_calls, # pyright: ignore[reportPrivateUsage]
)
from draive.mistral.client import MistralClient
from draive.mistral.config import MistralChatConfig
from draive.mistral.errors import MistralException
from draive.mistral.models import ChatCompletionResponse, ChatMessage, ChatMessageResponse
from draive.scope import ctx
from draive.tools import Toolbox

Expand All @@ -22,7 +21,7 @@ async def _chat_response(
client: MistralClient,
config: MistralChatConfig,
messages: list[ChatMessage],
tools: Toolbox | None,
tools: Toolbox,
recursion_level: int = 0,
) -> str:
if recursion_level > config.recursion_limit:
Expand All @@ -32,14 +31,24 @@ async def _chat_response(
"chat_response",
metrics=[ArgumentsTrace.of(messages=messages.copy())],
):
suggest_tools: bool
available_tools: list[dict[str, object]]
if recursion_level == 0 and (suggested := tools.suggested_tool):
# suggest/require tool call only initially
suggest_tools = True
available_tools = cast(list[dict[str, object]], [suggested])
else:
suggest_tools = False
available_tools = cast(
list[dict[str, object]],
tools.available_tools if tools else [],
)

completion: ChatCompletionResponse = await client.chat_completion(
config=config,
messages=messages,
tools=cast(
list[dict[str, object]],
tools.available_tools if tools else [],
),
suggest_tools=tools is not None and tools.suggested_tool_name is not None,
tools=available_tools,
suggest_tools=suggest_tools,
)

if usage := completion.usage:
Expand All @@ -54,7 +63,7 @@ async def _chat_response(
if not completion.choices:
raise MistralException("Invalid Mistral completion - missing messages!", completion)

completion_message: ChatMessage = completion.choices[0].message
completion_message: ChatMessageResponse = completion.choices[0].message

if (tool_calls := completion_message.tool_calls) and (tools := tools):
messages.extend(
Expand All @@ -68,7 +77,7 @@ async def _chat_response(
elif message := completion_message.content:
ctx.record(ResultTrace.of(message))
match message:
case str() as content:
case str(content):
return content

# API docs say that it can be only a string in response
Expand Down
115 changes: 9 additions & 106 deletions src/draive/mistral/chat_stream.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
from collections.abc import AsyncIterable, AsyncIterator, Callable
from typing import cast
from collections.abc import Callable

from mistralai.models.chat_completion import (
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ToolCall,
)

from draive.metrics import ArgumentsTrace, ResultTrace
from draive.mistral.chat_response import _chat_response # pyright: ignore[reportPrivateUsage]
from draive.mistral.chat_tools import (
_execute_chat_tool_calls, # pyright: ignore[reportPrivateUsage]
_flush_chat_tool_calls, # pyright: ignore[reportPrivateUsage]
)
from draive.mistral.client import MistralClient
from draive.mistral.config import MistralChatConfig
from draive.mistral.errors import MistralException
from draive.mistral.models import ChatMessage
from draive.scope import ctx
from draive.tools import Toolbox, ToolCallUpdate

Expand All @@ -25,106 +12,22 @@
]


async def _chat_stream( # noqa: C901, PLR0913
async def _chat_stream( # noqa: PLR0913
*,
client: MistralClient,
config: MistralChatConfig,
messages: list[ChatMessage],
tools: Toolbox | None,
tools: Toolbox,
send_update: Callable[[ToolCallUpdate | str], None],
recursion_level: int = 0,
) -> str:
if recursion_level > config.recursion_limit:
raise MistralException("Reached limit of recursive calls of %d", config.recursion_limit)

if tools is not None:
ctx.log_warning(
"Mistral streaming api is broken - can't properly call tools, waiting for full response"
)
message: str = await _chat_response(
client=client,
config=config,
messages=messages,
tools=tools,
recursion_level=recursion_level,
)
send_update(message)
return message

with ctx.nested(
"chat_stream",
metrics=[ArgumentsTrace.of(messages=messages.copy())],
):
completion_stream: AsyncIterable[
ChatCompletionStreamResponse
] = await client.chat_completion(
config=config,
messages=messages,
tools=cast(
list[dict[str, object]],
tools.available_tools if tools else [],
),
suggest_tools=False, # type: ignore - no tools allowed in streaming
stream=True,
)
completion_stream_iterator: AsyncIterator[ChatCompletionStreamResponse] = (
completion_stream.__aiter__()
)

while True: # load chunks to decide what to do next
head: ChatCompletionStreamResponse
try:
head = await anext(completion_stream_iterator)
except StopAsyncIteration as exc:
# could not decide what to do before stream end
raise MistralException("Invalid Mistral completion stream") from exc

if not head.choices:
raise MistralException("Invalid Mistral completion - missing deltas!", head)

completion_head: DeltaMessage = head.choices[0].delta

# TODO: record token usage

if completion_head.tool_calls is not None and (tools := tools):
tool_calls: list[ToolCall] = await _flush_chat_tool_calls(
tool_calls=completion_head.tool_calls,
completion_stream=completion_stream_iterator,
)
messages.extend(
await _execute_chat_tool_calls(
tool_calls=tool_calls,
tools=tools,
)
)
ctx.record(ResultTrace.of(tool_calls))
break # after processing tool calls continue with recursion in outer context

elif completion_head.content is not None:
result: str = completion_head.content
if result: # provide head / first part if not empty
send_update(result)

async for part in completion_stream:
# we are always requesting single result - no need to take care of indices
part_text: str = part.choices[0].delta.content or ""
if not part_text:
continue # skip empty parts
result += part_text
send_update(result)

ctx.record(ResultTrace.of(result))
return result # we hav final result here

else:
continue # iterate over the stream until can decide what to do or reach the end

# recursion outside of context
return await _chat_stream(
ctx.log_warning("Mistral streaming api is not supported yet, using regular response...")
message: str = await _chat_response(
client=client,
config=config,
messages=messages,
tools=tools,
send_update=send_update,
recursion_level=recursion_level + 1,
recursion_level=recursion_level,
)
send_update(message)
return message
70 changes: 19 additions & 51 deletions src/draive/mistral/chat_tools.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import json
from asyncio import gather
from collections.abc import AsyncIterable, Awaitable

from mistralai.models.chat_completion import ChatCompletionStreamResponse, ChatMessage, ToolCall
from collections.abc import Awaitable
from typing import Any

from draive.mistral.models import ChatMessage, ToolCallResponse
from draive.tools import Toolbox

__all__ = [
"_execute_chat_tool_calls",
"_flush_chat_tool_calls",
]


async def _execute_chat_tool_calls(
*,
tool_calls: list[ToolCall],
tool_calls: list[ToolCallResponse],
tools: Toolbox,
) -> list[ChatMessage]:
tool_call_results: list[Awaitable[ChatMessage]] = []
Expand All @@ -26,12 +26,23 @@ async def _execute_chat_tool_calls(
tools=tools,
)
)

return [
ChatMessage(
role="assistant",
content="",
tool_calls=tool_calls,
tool_calls=[
{
"id": call.id,
"type": "function",
"function": {
"name": call.function.name,
"arguments": call.function.arguments
if isinstance(call.function.arguments, str)
else json.dumps(call.function.arguments),
},
}
for call in tool_calls
],
),
*await gather(
*tool_call_results,
Expand All @@ -44,7 +55,7 @@ async def _execute_chat_tool_call(
*,
call_id: str,
name: str,
arguments: str,
arguments: dict[str, Any] | str,
tools: Toolbox,
) -> ChatMessage:
try: # make sure that tool error won't blow up whole chain
Expand All @@ -66,46 +77,3 @@ async def _execute_chat_tool_call(
name=name,
content="Error",
)


async def _flush_chat_tool_calls( # noqa: PLR0912
*,
tool_calls: list[ToolCall],
completion_stream: AsyncIterable[ChatCompletionStreamResponse],
) -> list[ToolCall]:
# iterate over the stream to get full list of tool calls
async for chunk in completion_stream:
for call in chunk.choices[0].delta.tool_calls or []:
try:
tool_call: ToolCall = next(
tool_call for tool_call in tool_calls if tool_call.id == call.id
)

if call.id:
if tool_call.id != "null":
tool_call.id += call.id
else:
tool_call.id = call.id
else:
pass

if call.function.name:
if tool_call.function.name:
tool_call.function.name += call.function.name
else:
tool_call.function.name = call.function.name
else:
pass

if call.function.arguments:
if tool_call.function.arguments:
tool_call.function.arguments += call.function.arguments
else:
tool_call.function.arguments = call.function.arguments
else:
pass

except (StopIteration, StopAsyncIteration):
tool_calls.append(call)

return tool_calls
Loading

0 comments on commit c41e13e

Please sign in to comment.