Skip to content

Commit

Permalink
Add tool direct result
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ committed Apr 26, 2024
1 parent ec982c0 commit 56c5d44
Show file tree
Hide file tree
Showing 11 changed files with 370 additions and 143 deletions.
2 changes: 0 additions & 2 deletions src/draive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@
"AsyncStream",
"AsyncStreamTask",
"agent",
"agent",
"Agent",
"Agent",
"AgentException",
"AgentFlow",
Expand Down
81 changes: 52 additions & 29 deletions src/draive/helpers/mimic.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,67 @@
from collections.abc import Callable
from typing import Any, cast
from typing import Any, cast, overload

__all__ = [
"mimic_function",
]


@overload
def mimic_function[**Args, Result](
function: Callable[Args, Result],
/,
within: Callable[..., Any],
) -> Callable[Args, Result]:
# mimic function attributes if able
for attribute in [
"__module__",
"__name__",
"__qualname__",
"__annotations__",
"__defaults__",
"__kwdefaults__",
"__doc__",
]:
try:
setattr(
within,
attribute,
getattr(
function,
) -> Callable[Args, Result]: ...


@overload
def mimic_function[**Args, Result](
function: Callable[Args, Result],
/,
) -> Callable[[Callable[..., Any]], Callable[Args, Result]]: ...


def mimic_function[**Args, Result](
function: Callable[Args, Result],
/,
within: Callable[..., Result] | None = None,
) -> Callable[[Callable[..., Result]], Callable[Args, Result]] | Callable[Args, Result]:
def mimic(
target: Callable[..., Result],
) -> Callable[Args, Result]:
# mimic function attributes if able
for attribute in [
"__module__",
"__name__",
"__qualname__",
"__annotations__",
"__defaults__",
"__kwdefaults__",
"__doc__",
]:
try:
setattr(
target,
attribute,
),
)
getattr(
function,
attribute,
),
)

except AttributeError:
pass
try:
target.__dict__.update(function.__dict__)
except AttributeError:
pass
try:
within.__dict__.update(function.__dict__)
except AttributeError:
pass

return cast(
Callable[Args, Result],
within,
)

return cast(
Callable[Args, Result],
target,
)

if target := within:
return mimic(target)
else:
return mimic
16 changes: 10 additions & 6 deletions src/draive/mistral/chat_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,18 @@ async def _chat_response(
completion_message: ChatMessageResponse = completion.choices[0].message

if (tool_calls := completion_message.tool_calls) and (tools := tools):
messages.extend(
await _execute_chat_tool_calls(
tool_calls=tool_calls,
tools=tools,
)
)
ctx.record(ResultTrace.of(tool_calls))

tools_result: list[ChatMessage] | str = await _execute_chat_tool_calls(
tool_calls=tool_calls,
tools=tools,
)

if isinstance(tools_result, str):
return tools_result
else:
messages.extend(tools_result)

elif message := completion_message.content:
ctx.record(ResultTrace.of(message))
match message:
Expand Down
137 changes: 95 additions & 42 deletions src/draive/mistral/chat_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from asyncio import gather
from collections.abc import Awaitable
from typing import Any
from typing import Any, Literal, overload

from draive.mistral.models import ChatMessage, ToolCallResponse
from draive.tools import Toolbox
Expand All @@ -15,65 +15,118 @@ async def _execute_chat_tool_calls(
*,
tool_calls: list[ToolCallResponse],
tools: Toolbox,
) -> list[ChatMessage]:
) -> list[ChatMessage] | str:
direct_result: Awaitable[str] | None = None
tool_call_results: list[Awaitable[ChatMessage]] = []
for call in tool_calls:
tool_call_results.append(
_execute_chat_tool_call(
# use only the first "direct result tool" requested, can't return more than one anyways
# despite of that all tools will be called to ensure that all desired actions were executed
if direct_result is None and tools.requires_direct_result(tool_name=call.function.name):
direct_result = _execute_chat_tool_call(
call_id=call.id,
name=call.function.name,
arguments=call.function.arguments,
tools=tools,
message_result=False,
)
)
return [
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": call.id,
"type": "function",
"function": {
"name": call.function.name,
"arguments": call.function.arguments
if isinstance(call.function.arguments, str)
else json.dumps(call.function.arguments),
},
}
for call in tool_calls
],
),
*await gather(
else:
tool_call_results.append(
_execute_chat_tool_call(
call_id=call.id,
name=call.function.name,
arguments=call.function.arguments,
tools=tools,
message_result=True,
)
)
if direct_result is not None:
results: tuple[str, ...] = await gather(
direct_result,
*tool_call_results,
return_exceptions=False,
),
]
)
return results[0] # return only the requested direct result
else:
return [
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": call.id,
"type": "function",
"function": {
"name": call.function.name,
"arguments": call.function.arguments
if isinstance(call.function.arguments, str)
else json.dumps(call.function.arguments),
},
}
for call in tool_calls
],
),
*await gather(
*tool_call_results,
return_exceptions=False,
),
]


@overload
async def _execute_chat_tool_call(
*,
call_id: str,
name: str,
arguments: dict[str, Any] | str,
tools: Toolbox,
) -> ChatMessage:
message_result: Literal[True],
) -> ChatMessage: ...


@overload
async def _execute_chat_tool_call(
*,
call_id: str,
name: str,
arguments: dict[str, Any] | str,
tools: Toolbox,
message_result: Literal[False],
) -> str: ...


async def _execute_chat_tool_call(
*,
call_id: str,
name: str,
arguments: dict[str, Any] | str,
tools: Toolbox,
message_result: bool,
) -> ChatMessage | str:
try: # make sure that tool error won't blow up whole chain
result = await tools.call_tool(
name,
call_id=call_id,
arguments=arguments,
)
return ChatMessage(
role="tool",
name=name,
content=str(result),
result: str = str(
await tools.call_tool(
name,
call_id=call_id,
arguments=arguments,
)
)
if message_result:
return ChatMessage(
role="tool",
name=name,
content=str(result),
)
else:
return result

# error should be already logged by ScopeContext
except BaseException:
return ChatMessage(
role="tool",
name=name,
content="Error",
)
except BaseException as exc:
if message_result:
return ChatMessage(
role="tool",
name=name,
content="Error",
)

else: # TODO: think about allowing the error chat message
raise exc
16 changes: 10 additions & 6 deletions src/draive/openai/chat_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,18 @@ async def _chat_response(
completion_message: ChatCompletionMessage = completion.choices[0].message

if (tool_calls := completion_message.tool_calls) and (tools := tools):
messages.extend(
await _execute_chat_tool_calls(
tool_calls=tool_calls,
tools=tools,
)
)
ctx.record(ResultTrace.of(tool_calls))

tools_result: list[ChatCompletionMessageParam] | str = await _execute_chat_tool_calls(
tool_calls=tool_calls,
tools=tools,
)

if isinstance(tools_result, str):
return tools_result
else:
messages.extend(tools_result)

elif message := completion_message.content:
ctx.record(ResultTrace.of(message))
return message
Expand Down
22 changes: 15 additions & 7 deletions src/draive/openai/chat_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
]


async def _chat_stream( # noqa: PLR0913
async def _chat_stream( # noqa: PLR0913, C901
*,
client: OpenAIClient,
config: OpenAIChatConfig,
Expand Down Expand Up @@ -82,13 +82,21 @@ async def _chat_stream( # noqa: PLR0913
tool_calls=completion_head.tool_calls,
completion_stream=completion_stream,
)
messages.extend(
await _execute_chat_tool_calls(
tool_calls=tool_calls,
tools=tools,
)
)
ctx.record(ResultTrace.of(tool_calls))

tools_result: (
list[ChatCompletionMessageParam] | str
) = await _execute_chat_tool_calls(
tool_calls=tool_calls,
tools=tools,
)

if isinstance(tools_result, str):
send_update(tools_result)
return tools_result
else:
messages.extend(tools_result)

break # after processing tool calls continue with recursion in outer context

elif completion_head.content is not None:
Expand Down
Loading

0 comments on commit 56c5d44

Please sign in to comment.