Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 135 additions & 66 deletions pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from base64 import b64decode
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import (
Expand All @@ -12,28 +13,37 @@

from ... import ExternalToolset, ToolDefinition
from ...messages import (
AudioUrl,
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
DocumentUrl,
ImageUrl,
ModelMessage,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
VideoUrl,
)
from ...output import OutputDataT
from ...tools import AgentDepsT
from ...toolsets import AbstractToolset

try:
from ag_ui.core import (
ActivityMessage,
AssistantMessage,
BaseEvent,
BinaryInputContent,
DeveloperMessage,
Message,
RunAgentInput,
SystemMessage,
TextInputContent,
Tool as AGUITool,
ToolCall,
ToolMessage,
UserMessage,
)
Expand Down Expand Up @@ -124,72 +134,131 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]:
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.

for msg in messages:
if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or (
isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
):
if isinstance(msg, UserMessage):
builder.add(UserPromptPart(content=msg.content))
elif isinstance(msg, SystemMessage | DeveloperMessage):
builder.add(SystemPromptPart(content=msg.content))
else:
tool_call_id = msg.tool_call_id
tool_name = tool_calls.get(tool_call_id)
if tool_name is None: # pragma: no cover
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')

builder.add(
ToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=tool_call_id,
)
)
match msg:
case UserMessage(content=content):
if isinstance(content, str):
builder.add(UserPromptPart(content=content))
else:
user_prompt_content: list[Any] = []
for part in content:
match part:
case TextInputContent(text=text):
user_prompt_content.append(text)
case BinaryInputContent():
user_prompt_content.append(cls.load_binary_part(part))
case _:
raise ValueError(f'Unsupported user message part type: {type(part)}')

elif isinstance(msg, AssistantMessage) or ( # pragma: no branch
isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX)
):
if isinstance(msg, AssistantMessage):
if msg.content:
builder.add(TextPart(content=msg.content))

if msg.tool_calls:
for tool_call in msg.tool_calls:
tool_call_id = tool_call.id
tool_name = tool_call.function.name
tool_calls[tool_call_id] = tool_name

if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
_, provider_name, tool_call_id = tool_call_id.split('|', 2)
builder.add(
BuiltinToolCallPart(
tool_name=tool_name,
args=tool_call.function.arguments,
tool_call_id=tool_call_id,
provider_name=provider_name,
)
)
else:
builder.add(
ToolCallPart(
tool_name=tool_name,
tool_call_id=tool_call_id,
args=tool_call.function.arguments,
)
)
else:
tool_call_id = msg.tool_call_id
tool_name = tool_calls.get(tool_call_id)
if tool_name is None: # pragma: no cover
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')
_, provider_name, tool_call_id = tool_call_id.split('|', 2)

builder.add(
BuiltinToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=tool_call_id,
provider_name=provider_name,
)
)
if user_prompt_content:
content_to_add = (
user_prompt_content[0]
if len(user_prompt_content) == 1 and isinstance(user_prompt_content[0], str)
else user_prompt_content
)
builder.add(UserPromptPart(content=content_to_add))

case SystemMessage(content=content) | DeveloperMessage(content=content):
builder.add(SystemPromptPart(content=content))

case AssistantMessage(content=content, tool_calls=tool_calls_list):
if content:
builder.add(TextPart(content=content))
if tool_calls_list:
cls.add_assistant_tool_parts(builder, tool_calls_list, tool_calls)

case ToolMessage() as tool_msg:
cls.add_tool_return_part(builder, tool_msg, tool_calls)

case ActivityMessage(content=content):
# No matching on the Pydantic AI side.
pass

return builder.messages

@classmethod
def load_binary_part(cls, part: BinaryInputContent) -> BinaryContent | ImageUrl | VideoUrl | AudioUrl | DocumentUrl:
"""Transforms an AG-UI BinaryInputContent part into a Pydantic AI content part."""
if part.url:
try:
return BinaryContent.from_data_uri(part.url)
except ValueError:
media_type_constructors = {
'image': ImageUrl,
'video': VideoUrl,
'audio': AudioUrl,
}
media_type_prefix = part.mime_type.split('/', 1)[0]
constructor = media_type_constructors.get(media_type_prefix, DocumentUrl)
return constructor(
url=part.url,
media_type=part.mime_type,
identifier=part.id,
)
if part.data:
return BinaryContent(data=b64decode(part.data), kind='binary', media_type=part.mime_type)

raise ValueError('BinaryInputContent must have either a `url` or `data` field.')

@classmethod
def add_assistant_tool_parts(
cls,
builder: MessagesBuilder,
tool_calls_list: list[ToolCall],
tool_calls_map: dict[str, str],
) -> None:
"""Adds tool call parts from an AssistantMessage to the builder."""
for tool_call in tool_calls_list:
tool_call_id = tool_call.id
tool_name = tool_call.function.name
tool_calls_map[tool_call_id] = tool_name

if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
_, provider_name, original_id = tool_call_id.split('|', 2)
builder.add(
BuiltinToolCallPart(
tool_name=tool_name,
args=tool_call.function.arguments,
tool_call_id=original_id,
provider_name=provider_name,
)
)
else:
builder.add(
ToolCallPart(
tool_name=tool_name,
tool_call_id=tool_call_id,
args=tool_call.function.arguments,
)
)

@classmethod
def add_tool_return_part(
cls,
builder: MessagesBuilder,
msg: ToolMessage,
tool_calls_map: dict[str, str],
) -> None:
"""Adds a tool return part from a ToolMessage to the builder."""
tool_call_id = msg.tool_call_id
tool_name = tool_calls_map.get(tool_call_id)
if tool_name is None: # pragma: no cover
raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.')

if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX):
_, provider_name, original_id = tool_call_id.split('|', 2)
builder.add(
BuiltinToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=original_id,
provider_name=provider_name,
)
)
else:
builder.add(
ToolReturnPart(
tool_name=tool_name,
content=msg.content,
tool_call_id=tool_call_id,
)
)
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ ui = ["starlette>=0.45.3"]
# A2A
a2a = ["fasta2a>=0.4.1"]
# AG-UI
ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"]
ag-ui = ["ag-ui-protocol>=0.1.10", "starlette>=0.45.3"]
# Retries
retries = ["tenacity>=8.2.3"]
# Temporal
Expand Down
Loading
Loading