diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index fe3513ae58..2d5d6bd097 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -2,6 +2,7 @@ from __future__ import annotations +from base64 import b64decode from collections.abc import Mapping, Sequence from functools import cached_property from typing import ( @@ -12,14 +13,19 @@ from ... import ExternalToolset, ToolDefinition from ...messages import ( + AudioUrl, + BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + DocumentUrl, + ImageUrl, ModelMessage, SystemPromptPart, TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, + VideoUrl, ) from ...output import OutputDataT from ...tools import AgentDepsT @@ -27,13 +33,17 @@ try: from ag_ui.core import ( + ActivityMessage, AssistantMessage, BaseEvent, + BinaryInputContent, DeveloperMessage, Message, RunAgentInput, SystemMessage, + TextInputContent, Tool as AGUITool, + ToolCall, ToolMessage, UserMessage, ) @@ -124,72 +134,130 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping. 
for msg in messages: - if isinstance(msg, UserMessage | SystemMessage | DeveloperMessage) or ( - isinstance(msg, ToolMessage) and not msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX) - ): - if isinstance(msg, UserMessage): - builder.add(UserPromptPart(content=msg.content)) - elif isinstance(msg, SystemMessage | DeveloperMessage): - builder.add(SystemPromptPart(content=msg.content)) - else: - tool_call_id = msg.tool_call_id - tool_name = tool_calls.get(tool_call_id) - if tool_name is None: # pragma: no cover - raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.') - - builder.add( - ToolReturnPart( - tool_name=tool_name, - content=msg.content, - tool_call_id=tool_call_id, - ) - ) + match msg: + case UserMessage(content=content): + if isinstance(content, str): + builder.add(UserPromptPart(content=content)) + else: + user_prompt_content: list[Any] = [] + for part in content: + match part: + case TextInputContent(text=text): + user_prompt_content.append(text) + case BinaryInputContent(): + user_prompt_content.append(cls.load_binary_part(part)) + case _: + raise ValueError(f'Unsupported user message part type: {type(part)}') - elif isinstance(msg, AssistantMessage) or ( # pragma: no branch - isinstance(msg, ToolMessage) and msg.tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX) - ): - if isinstance(msg, AssistantMessage): - if msg.content: - builder.add(TextPart(content=msg.content)) - - if msg.tool_calls: - for tool_call in msg.tool_calls: - tool_call_id = tool_call.id - tool_name = tool_call.function.name - tool_calls[tool_call_id] = tool_name - - if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX): - _, provider_name, tool_call_id = tool_call_id.split('|', 2) - builder.add( - BuiltinToolCallPart( - tool_name=tool_name, - args=tool_call.function.arguments, - tool_call_id=tool_call_id, - provider_name=provider_name, - ) - ) - else: - builder.add( - ToolCallPart( - tool_name=tool_name, - tool_call_id=tool_call_id, - 
args=tool_call.function.arguments, - ) - ) - else: - tool_call_id = msg.tool_call_id - tool_name = tool_calls.get(tool_call_id) - if tool_name is None: # pragma: no cover - raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.') - _, provider_name, tool_call_id = tool_call_id.split('|', 2) - - builder.add( - BuiltinToolReturnPart( - tool_name=tool_name, - content=msg.content, - tool_call_id=tool_call_id, - provider_name=provider_name, - ) - ) + if user_prompt_content: + content_to_add = ( + user_prompt_content[0] + if len(user_prompt_content) == 1 and isinstance(user_prompt_content[0], str) + else user_prompt_content + ) + builder.add(UserPromptPart(content=content_to_add)) + + case SystemMessage(content=content) | DeveloperMessage(content=content): + builder.add(SystemPromptPart(content=content)) + + case AssistantMessage(content=content, tool_calls=tool_calls_list): + if content: + builder.add(TextPart(content=content)) + if tool_calls_list: + cls.add_assistant_tool_parts(builder, tool_calls_list, tool_calls) + + case ToolMessage() as tool_msg: + cls.add_tool_return_part(builder, tool_msg, tool_calls) + + case ActivityMessage(): + raise ValueError(f'Unsupported message type: {type(msg)}') return builder.messages + + @classmethod + def load_binary_part(cls, part: BinaryInputContent) -> BinaryContent | ImageUrl | VideoUrl | AudioUrl | DocumentUrl: + """Transforms an AG-UI BinaryInputContent part into a Pydantic AI content part.""" + if part.url: + try: + return BinaryContent.from_data_uri(part.url) + except ValueError: + media_type_constructors = { + 'image': ImageUrl, + 'video': VideoUrl, + 'audio': AudioUrl, + } + media_type_prefix = part.mime_type.split('/', 1)[0] + constructor = media_type_constructors.get(media_type_prefix, DocumentUrl) + return constructor( + url=part.url, + media_type=part.mime_type, + identifier=part.id, + ) + if part.data: + return BinaryContent(data=b64decode(part.data), kind='binary', media_type=part.mime_type) + 
+ raise ValueError('BinaryInputContent must have either a `url` or `data` field.') + + @classmethod + def add_assistant_tool_parts( + cls, + builder: MessagesBuilder, + tool_calls_list: list[ToolCall], + tool_calls_map: dict[str, str], + ) -> None: + """Adds tool call parts from an AssistantMessage to the builder.""" + for tool_call in tool_calls_list: + tool_call_id = tool_call.id + tool_name = tool_call.function.name + tool_calls_map[tool_call_id] = tool_name + + if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX): + _, provider_name, original_id = tool_call_id.split('|', 2) + builder.add( + BuiltinToolCallPart( + tool_name=tool_name, + args=tool_call.function.arguments, + tool_call_id=original_id, + provider_name=provider_name, + ) + ) + else: + builder.add( + ToolCallPart( + tool_name=tool_name, + tool_call_id=tool_call_id, + args=tool_call.function.arguments, + ) + ) + + @classmethod + def add_tool_return_part( + cls, + builder: MessagesBuilder, + msg: ToolMessage, + tool_calls_map: dict[str, str], + ) -> None: + """Adds a tool return part from a ToolMessage to the builder.""" + tool_call_id = msg.tool_call_id + tool_name = tool_calls_map.get(tool_call_id) + if tool_name is None: # pragma: no cover + raise ValueError(f'Tool call with ID {tool_call_id} not found in the history.') + + if tool_call_id.startswith(BUILTIN_TOOL_CALL_ID_PREFIX): + _, provider_name, original_id = tool_call_id.split('|', 2) + builder.add( + BuiltinToolReturnPart( + tool_name=tool_name, + content=msg.content, + tool_call_id=original_id, + provider_name=provider_name, + ) + ) + else: + builder.add( + ToolReturnPart( + tool_name=tool_name, + content=msg.content, + tool_call_id=tool_call_id, + ) + ) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index c59f6a8275..98e6aaa96f 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -103,7 +103,7 @@ ui = ["starlette>=0.45.3"] # A2A a2a = ["fasta2a>=0.4.1"] # AG-UI -ag-ui = 
["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] +ag-ui = ["ag-ui-protocol>=0.1.10", "starlette>=0.45.3"] # Retries retries = ["tenacity>=8.2.3"] # Temporal diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 5cbf85fc69..1672a2429f 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -4,6 +4,7 @@ import json import uuid +from base64 import b64decode from collections.abc import AsyncIterator, MutableMapping from dataclasses import dataclass from http import HTTPStatus @@ -17,10 +18,14 @@ from pydantic import BaseModel from pydantic_ai import ( + AudioUrl, + BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + DocumentUrl, FunctionToolCallEvent, FunctionToolResultEvent, + ImageUrl, ModelMessage, ModelRequest, ModelResponse, @@ -35,6 +40,7 @@ ToolReturn, ToolReturnPart, UserPromptPart, + VideoUrl, ) from pydantic_ai._run_context import RunContext from pydantic_ai.agent import Agent, AgentRunResult @@ -51,13 +57,17 @@ from pydantic_ai.models.test import TestModel from pydantic_ai.output import OutputDataT from pydantic_ai.tools import AgentDepsT, ToolDefinition +from pydantic_ai.ui import MessagesBuilder +from pydantic_ai.ui.ag_ui._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX from .conftest import IsDatetime, IsSameStr, try_import with try_import() as imports_successful: from ag_ui.core import ( + ActivityMessage, AssistantMessage, BaseEvent, + BinaryInputContent, CustomEvent, DeveloperMessage, EventType, @@ -66,6 +76,7 @@ RunAgentInput, StateSnapshotEvent, SystemMessage, + TextInputContent, Tool, ToolCall, ToolMessage, @@ -264,6 +275,24 @@ async def test_basic_user_message() -> None: assert events == simple_result() +async def test_complex_user_message() -> None: + """Test basic user message with text response. 
But using TextInputContent instead of str""" + agent = Agent( + model=FunctionModel(stream_function=simple_stream), + ) + + run_input = create_input( + UserMessage( + id='msg_1', + content=[TextInputContent(text='Hello, how are you?')], + ) + ) + + events = await run_and_collect_events(agent, run_input) + + assert events == simple_result() + + async def test_empty_messages() -> None: """Test handling of empty messages.""" @@ -348,6 +377,80 @@ async def test_messages_with_history() -> None: assert events == simple_result() +async def test_img_message() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + run_input = create_input( + UserMessage( + id='msg_1', + content=[BinaryInputContent(url='https://example.com/img.png', mime_type='image/png', filename='img.png')], + ) + ) + events = await run_and_collect_events(agent, run_input) + + assert events == simple_result() + + +async def test_video_message() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + run_input = create_input( + UserMessage(id='msg_1', content=[BinaryInputContent(url='https://example.com/vid.mp4', mime_type='video/mp4')]) + ) + events = await run_and_collect_events(agent, run_input) + + assert events == simple_result() + + +async def test_audio_message() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + run_input = create_input( + UserMessage( + id='msg_1', content=[BinaryInputContent(url='https://example.com/audio.oga', mime_type='audio/ogg')] + ) + ) + events = await run_and_collect_events(agent, run_input) + + assert events == simple_result() + + +async def test_document_message() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + run_input = create_input( + UserMessage( + id='msg_1', + content=[BinaryInputContent(url='https://example.com/document.pdf', mime_type='application/pdf')], + ) + ) + events = await run_and_collect_events(agent, run_input) + + assert events == simple_result() + 
+ +async def test_binary_file_message() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + run_input = create_input( + UserMessage(id='msg_1', content=[BinaryInputContent(data='VGVzdCBEb2M=', mime_type='text/plain')]) + ) + events = await run_and_collect_events(agent, run_input) + + assert events == simple_result() + + +async def test_test_and_binary_file_message() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + run_input = create_input( + UserMessage( + id='msg_1', + content=[ + TextInputContent(text='Write a summary of this file:', type='text'), + BinaryInputContent(data='VGVzdCBEb2M=', mime_type='text/plain'), + ], + ) + ) + events = await run_and_collect_events(agent, run_input) + + assert events == simple_result() + + async def test_tool_ag_ui() -> None: """Test AG-UI tool call.""" @@ -1576,6 +1679,221 @@ async def test_messages() -> None: ) +def test_load_messages_unsupported_user_content_part() -> None: + """Test load_messages with an unsupported content part in a UserMessage.""" + + class UnsupportedContentPart(BaseModel): + pass + + messages = [ + UserMessage.model_construct( + id='msg_1', + content=[UnsupportedContentPart()], + ) + ] + with pytest.raises(ValueError, match="Unsupported user message part type: "): + AGUIAdapter.load_messages(messages) + + +def test_load_messages_with_activity_message() -> None: + """Test that ActivityMessage is not supported.""" + messages = [ + ActivityMessage( + id='activity_1', + role='activity', + activity_type='PLAN', + content={'steps': ['Step 1', 'Step 2'], 'status': 'in_progress'}, + ) + ] + with pytest.raises(ValueError, match=r"Unsupported message type: "): + AGUIAdapter.load_messages(messages) + + +def test_load_messages_empty_user_content() -> None: + """Test load_messages with a UserMessage that has an empty content list.""" + messages = [ + UserMessage( + id='msg_1', + content=[], + ) + ] + loaded_messages = AGUIAdapter.load_messages(messages) + assert 
not loaded_messages + + +def test_load_binary_part() -> None: + """Test the load_binary_part method of the AGUIAdapter.""" + # Test data URI + data_uri_part = BinaryInputContent( + url='', + mime_type='image/png', + ) + result = AGUIAdapter.load_binary_part(data_uri_part) + assert isinstance(result, BinaryContent) + assert result.data == b64decode( + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=' + ) + + # Test ImageUrl + image_url_part = BinaryInputContent(url='http://example.com/image.png', mime_type='image/png', id='img1') + result = AGUIAdapter.load_binary_part(image_url_part) + assert isinstance(result, ImageUrl) + assert result.url == 'http://example.com/image.png' + assert result.media_type == 'image/png' + assert result.identifier == 'img1' + + # Test VideoUrl + video_url_part = BinaryInputContent(url='http://example.com/video.mp4', mime_type='video/mp4', id='vid1') + result = AGUIAdapter.load_binary_part(video_url_part) + assert isinstance(result, VideoUrl) + assert result.url == 'http://example.com/video.mp4' + assert result.media_type == 'video/mp4' + assert result.identifier == 'vid1' + + # Test AudioUrl + audio_url_part = BinaryInputContent(url='http://example.com/audio.mp3', mime_type='audio/mpeg', id='aud1') + result = AGUIAdapter.load_binary_part(audio_url_part) + assert isinstance(result, AudioUrl) + assert result.url == 'http://example.com/audio.mp3' + assert result.media_type == 'audio/mpeg' + assert result.identifier == 'aud1' + + # Test DocumentUrl + doc_url_part = BinaryInputContent(url='http://example.com/doc.pdf', mime_type='application/pdf', id='doc1') + result = AGUIAdapter.load_binary_part(doc_url_part) + assert isinstance(result, DocumentUrl) + assert result.url == 'http://example.com/doc.pdf' + assert result.media_type == 'application/pdf' + assert result.identifier == 'doc1' + + # Test data field + data_part = BinaryInputContent(data='SGVsbG8gd29ybGQ=', mime_type='text/plain') + result = 
AGUIAdapter.load_binary_part(data_part) + assert isinstance(result, BinaryContent) + assert result.data == b'Hello world' + assert result.media_type == 'text/plain' + + # Test ValueError + with pytest.raises(ValueError, match='BinaryInputContent must have either a `url` or `data` field.'): + AGUIAdapter.load_binary_part(BinaryInputContent(id='some_id', mime_type='text/plain')) + + +def test_add_assistant_tool_parts() -> None: + """Test the add_assistant_tool_parts method of the AGUIAdapter.""" + # Case 1: Regular tool call + builder = MessagesBuilder() + tool_calls_map: dict[str, str] = {} + regular_tool_call = ToolCall( + id='regular_call_1', + type='function', + function=FunctionCall(name='my_tool', arguments='{"arg": "value"}'), + ) + tool_calls_list = [regular_tool_call] + + AGUIAdapter.add_assistant_tool_parts(builder, tool_calls_list, tool_calls_map) + + assert len(builder.messages) == 1 + assert len(builder.messages[0].parts) == 1 + part = builder.messages[0].parts[0] + assert isinstance(part, ToolCallPart) + assert part.tool_name == 'my_tool' + assert part.tool_call_id == 'regular_call_1' + assert part.args == '{"arg": "value"}' + assert tool_calls_map == {'regular_call_1': 'my_tool'} + + # Case 2: Built-in tool call + builder = MessagesBuilder() + tool_calls_map = {} + builtin_id = f'{BUILTIN_TOOL_CALL_ID_PREFIX}|function|search_1' + builtin_tool_call = ToolCall( + id=builtin_id, + type='function', + function=FunctionCall(name='web_search', arguments='{"query": "test"}'), + ) + tool_calls_list = [builtin_tool_call] + + AGUIAdapter.add_assistant_tool_parts(builder, tool_calls_list, tool_calls_map) + + assert len(builder.messages) == 1 + assert len(builder.messages[0].parts) == 1 + part = builder.messages[0].parts[0] + assert isinstance(part, BuiltinToolCallPart) + assert part.tool_name == 'web_search' + assert part.tool_call_id == 'search_1' + assert part.provider_name == 'function' + assert part.args == '{"query": "test"}' + assert tool_calls_map == 
{builtin_id: 'web_search'} + + # Case 3: Mixed tool calls + builder = MessagesBuilder() + tool_calls_map = {} + tool_calls_list = [regular_tool_call, builtin_tool_call] + + AGUIAdapter.add_assistant_tool_parts(builder, tool_calls_list, tool_calls_map) + + assert len(builder.messages) == 1 + assert len(builder.messages[0].parts) == 2 + regular_part, builtin_part = builder.messages[0].parts + + assert isinstance(regular_part, ToolCallPart) + assert regular_part.tool_call_id == 'regular_call_1' + + assert isinstance(builtin_part, BuiltinToolCallPart) + assert builtin_part.tool_call_id == 'search_1' + + assert tool_calls_map == {'regular_call_1': 'my_tool', builtin_id: 'web_search'} + + # Case 4: Empty list + builder = MessagesBuilder() + tool_calls_map = {} + AGUIAdapter.add_assistant_tool_parts(builder, [], tool_calls_map) + + assert not builder.messages + assert not tool_calls_map + + +def test_add_tool_return_part() -> None: + """Test the add_tool_return_part method of the AGUIAdapter.""" + # Case 1: Regular tool return + builder = MessagesBuilder() + tool_calls_map = {'call_1': 'my_tool'} + msg = ToolMessage(id='msg_1', tool_call_id='call_1', content='result content') + + AGUIAdapter.add_tool_return_part(builder, msg, tool_calls_map) + + assert len(builder.messages) == 1 + assert len(builder.messages[0].parts) == 1 + part = builder.messages[0].parts[0] + assert isinstance(part, ToolReturnPart) + assert part.tool_name == 'my_tool' + assert part.content == 'result content' + assert part.tool_call_id == 'call_1' + + # Case 2: Built-in tool return + builder = MessagesBuilder() + builtin_id = f'{BUILTIN_TOOL_CALL_ID_PREFIX}|function|search_1' + tool_calls_map = {builtin_id: 'web_search'} + msg = ToolMessage(id='msg_2', tool_call_id=builtin_id, content='search results') + + AGUIAdapter.add_tool_return_part(builder, msg, tool_calls_map) + + assert len(builder.messages) == 1 + assert len(builder.messages[0].parts) == 1 + part = builder.messages[0].parts[0] + assert 
isinstance(part, BuiltinToolReturnPart) + assert part.tool_name == 'web_search' + assert part.content == 'search results' + assert part.tool_call_id == 'search_1' + assert part.provider_name == 'function' + + # Case 3: ValueError for missing tool call ID + builder = MessagesBuilder() + tool_calls_map: dict[str, str] = {} + msg = ToolMessage(id='msg_3', tool_call_id='non_existent_call', content='some content') + with pytest.raises(ValueError, match='Tool call with ID non_existent_call not found in the history.'): + AGUIAdapter.add_tool_return_part(builder, msg, tool_calls_map) + + async def test_builtin_tool_call() -> None: async def stream_function( messages: list[ModelMessage], agent_info: AgentInfo diff --git a/uv.lock b/uv.lock index 0c1e48a65f..c0ba666566 100644 --- a/uv.lock +++ b/uv.lock @@ -44,14 +44,14 @@ wheels = [ [[package]] name = "ag-ui-protocol" -version = "0.1.8" +version = "0.1.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/27/de/0bddf7f26d5f38274c99401735c82ad59df9cead6de42f4bb2ad837286fe/ag_ui_protocol-0.1.8.tar.gz", hash = "sha256:eb745855e9fc30964c77e953890092f8bd7d4bbe6550d6413845428dd0faac0b", size = 5323, upload-time = "2025-07-15T10:55:36.389Z" } +sdist = { url = "https://files.pythonhosted.org/packages/67/bb/5a5ec893eea5805fb9a3db76a9888c3429710dfb6f24bbb37568f2cf7320/ag_ui_protocol-0.1.10.tar.gz", hash = "sha256:3213991c6b2eb24bb1a8c362ee270c16705a07a4c5962267a083d0959ed894f4", size = 6945, upload-time = "2025-11-06T15:17:17.068Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/00/40c6b0313c25d1ab6fac2ecba1cd5b15b1cd3c3a71b3d267ad890e405889/ag_ui_protocol-0.1.8-py3-none-any.whl", hash = "sha256:1567ccb067b7b8158035b941a985e7bb185172d660d4542f3f9c6fff77b55c6e", size = 7066, upload-time = "2025-07-15T10:55:35.075Z" }, + { url = 
"https://files.pythonhosted.org/packages/8f/78/eb55fabaab41abc53f52c0918a9a8c0f747807e5306273f51120fd695957/ag_ui_protocol-0.1.10-py3-none-any.whl", hash = "sha256:c81e6981f30aabdf97a7ee312bfd4df0cd38e718d9fc10019c7d438128b93ab5", size = 7889, upload-time = "2025-11-06T15:17:15.325Z" }, ] [[package]] @@ -5647,7 +5647,7 @@ vertexai = [ [package.metadata] requires-dist = [ - { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.8" }, + { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.10" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.75.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.40.14" },