diff --git a/docs/agents.md b/docs/agents.md
index adbbbefc98..d70f7b76e9 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -15,6 +15,7 @@ The [`Agent`][pydantic_ai.Agent] class has full API documentation, but conceptua
 | [Dependency type constraint](dependencies.md) | Dynamic instructions functions, tools, and output functions may all use dependencies when they're run. |
 | [LLM model](api/models/base.md) | Optional default LLM model associated with the agent. Can also be specified when running the agent. |
 | [Model Settings](#additional-configuration) | Optional default model settings to help fine tune requests. Can also be specified when running the agent. |
+| [Prompt Configuration](#prompt-configuration) | Optional configuration for customizing system-generated messages, tool descriptions, and retry prompts. |
 
 In typing terms, agents are generic in their dependency and output types, e.g., an agent which required dependencies of type `#!python Foobar` and produced outputs of type `#!python list[str]` would have type `Agent[Foobar, list[str]]`. In practice, you shouldn't need to care about this, it should just mean your IDE can tell you when you have the right type, and if you choose to use [static type checking](#static-type-checking) it should work well with Pydantic AI.
@@ -751,6 +752,125 @@ except UnexpectedModelBehavior as e:
 1. This error is raised because the safety thresholds were exceeded.
 
+### Prompt Configuration
+
+Pydantic AI provides [`PromptConfig`][pydantic_ai.PromptConfig] to customize the system-generated messages
+that are sent to models during agent runs. This includes retry prompts, tool return confirmations,
+validation error messages, and tool descriptions.
+
+#### Customizing System Messages with PromptTemplates
+
+[`PromptTemplates`][pydantic_ai.PromptTemplates] allows you to override the default messages that Pydantic AI
+sends to the model for retries, tool results, and other system-generated content.
+
+```python {title="prompt_templates_example.py"}
+from pydantic_ai import Agent, PromptConfig, PromptTemplates
+
+# Using static strings
+agent = Agent(
+    'openai:gpt-5',
+    prompt_config=PromptConfig(
+        templates=PromptTemplates(
+            validation_errors_retry='Please correct the validation errors and try again.',
+            final_result_processed='Result received successfully.',
+        ),
+    ),
+)
+```
+
+You can also use callables to build messages dynamically; they receive the message part
+and the [`RunContext`][pydantic_ai.RunContext]:
+
+```python {title="prompt_templates_dynamic.py"}
+from pydantic_ai import Agent, PromptConfig, PromptTemplates
+from pydantic_ai.messages import RetryPromptPart
+from pydantic_ai.tools import RunContext
+
+
+def custom_retry_message(part: RetryPromptPart, ctx: RunContext) -> str:
+    return f'Attempt #{ctx.retry + 1}: Please fix the errors and try again.'
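+
+
+# A hypothetical variant (a sketch, not part of the original example): it assumes
+# `RetryPromptPart.model_response()` renders the part's validation errors as text,
+# and that `ctx.retry` holds the retry count for the current tool. Pass it as
+# `validation_errors_retry` below instead of `custom_retry_message` to try it.
+def detailed_retry_message(part: RetryPromptPart, ctx: RunContext) -> str:
+    return f'{part.model_response()}\nThis is retry #{ctx.retry + 1}.'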
+
+
+agent = Agent(
+    'openai:gpt-5',
+    prompt_config=PromptConfig(
+        templates=PromptTemplates(
+            validation_errors_retry=custom_retry_message,
+        ),
+    ),
+)
+```
+
+The available template fields in [`PromptTemplates`][pydantic_ai.PromptTemplates] include:
+
+| Template Field | Description |
+|----------------|-------------|
+| `final_result_processed` | Confirmation message when a final result is successfully processed |
+| `output_tool_not_executed` | Message when an output tool call is skipped because a result was already found |
+| `function_tool_not_executed` | Message when a function tool call is skipped because a result was already found |
+| `tool_call_denied` | Message when a tool call is denied by an approval handler |
+| `validation_errors_retry` | Message appended to validation errors when asking the model to retry |
+| `model_retry_string_tool` | Message when a `ModelRetry` exception is raised from a tool |
+| `model_retry_string_no_tool` | Message when a `ModelRetry` exception is raised outside of a tool context |
+
+#### Customizing Tool Descriptions with ToolConfig
+
+[`ToolConfig`][pydantic_ai.ToolConfig] allows you to override tool descriptions at runtime without modifying
+the original tool definitions. This is useful when you want to provide different descriptions for the same
+tool in different contexts or agent runs.
+
+```python {title="tool_config_example.py"}
+from pydantic_ai import Agent, PromptConfig, ToolConfig
+
+agent = Agent(
+    'openai:gpt-5',
+    prompt_config=PromptConfig(
+        tool_config=ToolConfig(
+            tool_descriptions={
+                'search_database': 'Search the customer database for user records by name or email.',
+                'send_notification': 'Send an urgent notification to the user via their preferred channel.',
+            }
+        ),
+    ),
+)
+
+
+@agent.tool_plain
+def search_database(query: str) -> list[str]:
+    """Original description that will be overridden."""
+    return ['result1', 'result2']
+
+
+@agent.tool_plain
+def send_notification(user_id: str, message: str) -> bool:
+    """Original description that will be overridden."""
+    return True
+```
+
+You can also override `prompt_config` at runtime using the `prompt_config` parameter in the run methods,
+or temporarily using [`agent.override()`][pydantic_ai.Agent.override]:
+
+```python {title="prompt_config_override.py"}
+from pydantic_ai import Agent, PromptConfig, PromptTemplates
+
+agent = Agent('openai:gpt-5')
+
+# Override at runtime
+result = agent.run_sync(
+    'Hello',
+    prompt_config=PromptConfig(
+        templates=PromptTemplates(validation_errors_retry='Custom retry message for this run.')
+    ),
+)
+
+# Or use agent.override() context manager
+with agent.override(
+    prompt_config=PromptConfig(
+        templates=PromptTemplates(validation_errors_retry='Another custom message.')
+    )
+):
+    result = agent.run_sync('Hello')
+```
+
 ## Runs vs. Conversations
 
 An agent **run** might represent an entire conversation — there's no limit to how many messages can be exchanged in a single run. However, a **conversation** might also be composed of multiple runs, especially if you need to maintain state between separate interactions or API calls.
 
@@ -1072,6 +1192,7 @@ with capture_run_messages() as messages:  # (2)!
                 tool_name='calc_volume',
                 tool_call_id='pyd_ai_tool_call_id',
                 timestamp=datetime.datetime(...),
+                retry_message='Fix the errors and try again.',
             )
         ],
         run_id='...',
diff --git a/docs/api/prompt_config.md b/docs/api/prompt_config.md
new file mode 100644
index 0000000000..1b677292a7
--- /dev/null
+++ b/docs/api/prompt_config.md
@@ -0,0 +1,9 @@
+# `pydantic_ai.prompt_config`
+
+::: pydantic_ai.prompt_config
+    options:
+      inherited_members: true
+      members:
+        - PromptConfig
+        - PromptTemplates
+        - ToolConfig
diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md
index 31e14149c0..97c3b3e37d 100644
--- a/docs/deferred-tools.md
+++ b/docs/deferred-tools.md
@@ -150,6 +150,7 @@ print(result.all_messages())
                 content="File 'README.md' updated: 'Hello, world!'",
                 tool_call_id='update_file_readme',
                 timestamp=datetime.datetime(...),
+                return_kind='tool-executed',
             )
         ],
         run_id='...',
@@ -161,12 +162,14 @@ print(result.all_messages())
                 content="File '.env' updated: ''",
                 tool_call_id='update_file_dotenv',
                 timestamp=datetime.datetime(...),
+                return_kind='tool-executed',
             ),
             ToolReturnPart(
                 tool_name='delete_file',
                 content='Deleting files is not allowed',
                 tool_call_id='delete_file',
                 timestamp=datetime.datetime(...),
+                return_kind='tool-denied',
             ),
             UserPromptPart(
                 content='Now create a backup of README.md',
@@ -195,6 +198,7 @@ print(result.all_messages())
                 content="File 'README.md.bak' updated: 'Hello, world!'",
                 tool_call_id='update_file_backup',
                 timestamp=datetime.datetime(...),
+                return_kind='tool-executed',
             )
         ],
         run_id='...',
@@ -348,6 +352,7 @@ async def main():
                 content=42,
                 tool_call_id='pyd_ai_tool_call_id',
                 timestamp=datetime.datetime(...),
+                return_kind='tool-executed',
             )
         ],
         run_id='...',
diff --git a/docs/testing.md b/docs/testing.md
index 3089585ab0..99d2e01472 100644
--- a/docs/testing.md
+++ b/docs/testing.md
@@ -156,6 +156,7 @@ async def test_forecast():
                 content='Sunny with a chance of rain',
                 tool_call_id=IsStr(),
                 timestamp=IsNow(tz=timezone.utc),
+                return_kind='tool-executed',
             ),
         ],
         run_id=IsStr(),
diff --git a/docs/tools.md b/docs/tools.md
index 40dcf5c810..38574ff819 100644
--- a/docs/tools.md
+++ b/docs/tools.md
@@ -108,6 +108,7 @@ print(dice_result.all_messages())
                 content='4',
                 tool_call_id='pyd_ai_tool_call_id',
                 timestamp=datetime.datetime(...),
+                return_kind='tool-executed',
             )
         ],
         run_id='...',
@@ -130,6 +131,7 @@ print(dice_result.all_messages())
                 content='Anne',
                 tool_call_id='pyd_ai_tool_call_id',
                 timestamp=datetime.datetime(...),
+                return_kind='tool-executed',
             )
         ],
         run_id='...',
diff --git a/mkdocs.yml b/mkdocs.yml
index 1b46d5250c..fa9bb34840 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -163,6 +163,7 @@ nav:
     - api/models/test.md
     - api/models/wrapper.md
    - api/profiles.md
+    - api/prompt_config.md
     - api/providers.md
     - api/retries.md
     - api/run.md
diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py
index c860d20dd8..97191a7ac8 100644
--- a/pydantic_ai_slim/pydantic_ai/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/__init__.py
@@ -94,6 +94,7 @@
     ModelProfile,
     ModelProfileSpec,
 )
+from .prompt_config import PromptConfig, PromptTemplates, ToolConfig
 from .run import AgentRun, AgentRunResult, AgentRunResultEvent
 from .settings import ModelSettings
 from .tools import DeferredToolRequests, DeferredToolResults, RunContext, Tool, ToolApproved, ToolDefinition, ToolDenied
@@ -229,6 +230,10 @@
     'PromptedOutput',
     'TextOutput',
     'StructuredDict',
+    # prompt_config
+    'PromptConfig',
+    'PromptTemplates',
+    'ToolConfig',
     # format_prompt
     'format_as_xml',
     # settings
diff
--git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 92c45a0c52..ad417d575a 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -25,7 +25,16 @@ from pydantic_graph.beta import Graph, GraphBuilder from pydantic_graph.nodes import End, NodeRunEndT -from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage +from . import ( + _output, + _system_prompt, + exceptions, + messages as _messages, + models, + prompt_config as _prompt_config, + result, + usage as _usage, +) from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec from .settings import ModelSettings @@ -133,6 +142,9 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): model: models.Model model_settings: ModelSettings | None + prompt_config: _prompt_config.PromptConfig = dataclasses.field( + default_factory=lambda: _prompt_config.DEFAULT_PROMPT_CONFIG + ) usage_limits: _usage.UsageLimits max_result_retries: int end_strategy: EndStrategy @@ -379,9 +391,8 @@ async def _prepare_request_parameters( """Build tools and create an agent model.""" output_schema = ctx.deps.output_schema - prompted_output_template = ( - output_schema.template if isinstance(output_schema, _output.PromptedOutputSchema) else None - ) + prompt_config = ctx.deps.prompt_config + prompted_output_template = prompt_config.templates.get_prompted_output_template(output_schema) function_tools: list[ToolDefinition] = [] output_tools: list[ToolDefinition] = [] @@ -504,6 +515,14 @@ async def _prepare_request( # Update the new message index to ensure `result.new_messages()` returns the correct messages ctx.deps.new_message_index -= len(original_history) - len(message_history) + prompt_config = ctx.deps.prompt_config + + message_history = _apply_prompt_templates_to_message_history( + message_history, prompt_config.templates, run_context + ) + + ctx.state.message_history[:] = message_history + # Merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts, # but don't store it in the message history on state. This is just for the benefit of model classes that want clear user/assistant boundaries. 
# See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary @@ -780,6 +799,8 @@ def _handle_final_result( # For backwards compatibility, append a new ModelRequest using the tool returns and retries if tool_responses: + run_ctx = build_run_context(ctx) + tool_responses = [ctx.deps.prompt_config.templates.apply_template(part, run_ctx) for part in tool_responses] messages.append(_messages.ModelRequest(parts=tool_responses, run_id=ctx.state.run_id)) return End(final_result) @@ -865,8 +886,9 @@ async def process_tool_calls( # noqa: C901 if final_result and final_result.tool_call_id == call.tool_call_id: part = _messages.ToolReturnPart( tool_name=call.tool_name, - content='Final result processed.', + content=_prompt_config.DEFAULT_PROMPT_CONFIG.templates.final_result_processed, tool_call_id=call.tool_call_id, + return_kind='final-result-processed', ) output_parts.append(part) # Early strategy is chosen and final result is already set @@ -874,8 +896,9 @@ async def process_tool_calls( # noqa: C901 yield _messages.FunctionToolCallEvent(call) part = _messages.ToolReturnPart( tool_name=call.tool_name, - content='Output tool not used - a final result was already processed.', + content=_prompt_config.DEFAULT_PROMPT_CONFIG.templates.output_tool_not_executed, tool_call_id=call.tool_call_id, + return_kind='output-tool-not-executed', ) yield _messages.FunctionToolResultEvent(part) output_parts.append(part) @@ -916,8 +939,9 @@ async def process_tool_calls( # noqa: C901 else: part = _messages.ToolReturnPart( tool_name=call.tool_name, - content='Final result processed.', + content=_prompt_config.DEFAULT_PROMPT_CONFIG.templates.final_result_processed, tool_call_id=call.tool_call_id, + return_kind='final-result-processed', ) output_parts.append(part) @@ -932,8 +956,9 @@ async def process_tool_calls( # noqa: C901 output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', + content=_prompt_config.DEFAULT_PROMPT_CONFIG.templates.function_tool_not_executed, tool_call_id=call.tool_call_id, + return_kind='function-tool-not-executed', ) ) else: @@ -990,8 +1015,9 @@ async def process_tool_calls( # noqa: C901 output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', + content=_prompt_config.DEFAULT_PROMPT_CONFIG.templates.function_tool_not_executed, tool_call_id=call.tool_call_id, + return_kind='function-tool-not-executed', ) ) elif calls: @@ -1148,6 +1174,7 @@ async def _call_tool( tool_name=tool_call.tool_name, content=tool_call_result.message, tool_call_id=tool_call.tool_call_id, + return_kind='tool-denied', ), None elif isinstance(tool_call_result, exceptions.ModelRetry): m = _messages.RetryPromptPart( @@ -1210,6 +1237,7 @@ async def _call_tool( tool_call_id=tool_call.tool_call_id, content=tool_return.return_value, # type: ignore metadata=tool_return.metadata, + return_kind='tool-executed', ) return return_part, tool_return.content or None @@ -1380,3 +1408,18 @@ def _clean_message_history(messages: list[_messages.ModelMessage]) -> list[_mess else: clean_messages.append(message) return clean_messages + + +def _apply_prompt_templates_to_message_history( + messages: list[_messages.ModelMessage], prompt_templates: _prompt_config.PromptTemplates, ctx: RunContext[Any] +) -> list[_messages.ModelMessage]: + messages_with_templates_applied: list[_messages.ModelMessage] = [] + + for msg in messages: + if 
isinstance(msg, _messages.ModelRequest): + parts_template_applied = [prompt_templates.apply_template(part, ctx) for part in msg.parts] + messages_with_templates_applied.append(replace(msg, parts=parts_template_applied)) + else: + messages_with_templates_applied.append(msg) + + return messages_with_templates_applied diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 85ed332d0b..c0b9624096 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -24,6 +24,7 @@ exceptions, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from .._agent_graph import ( @@ -129,6 +130,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ + prompt_config: _prompt_config.PromptConfig | None + """Optional prompt configuration used to customize the system-injected messages for this agent.""" + _output_type: OutputSpec[OutputDataT] instrument: InstrumentationSettings | bool | None @@ -172,6 +176,7 @@ def __init__( deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, retries: int = 1, validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, output_retries: int | None = None, @@ -226,6 +231,7 @@ def __init__( deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, retries: int = 1, validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, output_retries: int | None = None, @@ -260,6 +266,8 @@ def __init__( name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. + prompt_config: Optional prompt configuration to customize how system-injected messages + (like retry prompts or tool return wrappers) are rendered for this agent. retries: The default number of retries to allow for tool calls and output validation, before raising an error. For model request retries, see the [HTTP Request Retries](../retries.md) documentation. validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate tool arguments and outputs. 
@@ -306,6 +314,7 @@ def __init__( self._name = name self.end_strategy = end_strategy self.model_settings = model_settings + self.prompt_config = prompt_config self._output_type = output_type self.instrument = instrument @@ -372,6 +381,9 @@ def __init__( self._override_instructions: ContextVar[ _utils.Option[list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]] ] = ContextVar('_override_instructions', default=None) + self._override_prompt_config: ContextVar[_utils.Option[_prompt_config.PromptConfig]] = ContextVar( + '_override_prompt_config', default=None + ) self._enter_lock = Lock() self._entered_count = 0 @@ -439,6 +451,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -458,6 +471,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -477,6 +491,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -553,6 +568,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are + phrased for this specific run, falling back to the agent's defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -577,6 +594,7 @@ async def main(): # may change the result type from the restricted type to something else. Therefore, we consider the following # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. 
output_validators = self._output_validators + prompt_config = self._get_prompt_config(prompt_config) output_toolset = self._output_toolset if output_schema != self._output_schema or output_validators: @@ -584,7 +602,9 @@ async def main(): if output_toolset: output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators - toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) + toolset = self._get_toolset( + output_toolset=output_toolset, additional_toolsets=toolsets, prompt_config=prompt_config + ) tool_manager = ToolManager[AgentDepsT](toolset) # Build the graph @@ -630,6 +650,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: new_message_index=len(message_history) if message_history else 0, model=model_used, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, max_result_retries=self._max_result_retries, end_strategy=self.end_strategy, @@ -764,6 +785,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -777,6 +799,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt config to use instead of the prompt config passed to the agent constructor and agent run. """ if _utils.is_set(name): name_token = self._override_name.set(_utils.Some(name)) @@ -809,6 +832,11 @@ def override( else: instructions_token = None + if _utils.is_set(prompt_config): + prompt_config_token = self._override_prompt_config.set(_utils.Some(prompt_config)) + else: + prompt_config_token = None + try: yield finally: @@ -824,6 +852,8 @@ def override( self._override_tools.reset(tools_token) if instructions_token is not None: self._override_instructions.reset(instructions_token) + if prompt_config_token is not None: + self._override_prompt_config.reset(prompt_config_token) @overload def instructions( @@ -1348,6 +1378,18 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps + def _get_prompt_config(self, prompt_config: _prompt_config.PromptConfig | None) -> _prompt_config.PromptConfig: + """Get prompt_config for a run. + + If we've overridden prompt_config via `_override_prompt_config`, use that, + otherwise use the prompt_config passed to the call, falling back to the agent default, + and finally falling back to the global default. 
+ """ + if some_prompt_config := self._override_prompt_config.get(): + return some_prompt_config.value + else: + return prompt_config or self.prompt_config or _prompt_config.DEFAULT_PROMPT_CONFIG + def _normalize_instructions( self, instructions: Instructions[AgentDepsT], @@ -1386,12 +1428,14 @@ def _get_toolset( self, output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, ) -> AbstractToolset[AgentDepsT]: """Get the complete toolset. Args: output_toolset: The output toolset to use instead of the one built at agent construction time. additional_toolsets: Additional toolsets to add, unless toolsets have been overridden. + prompt_config: The prompt config to use for tool descriptions. If None, uses agent-level or default. """ toolsets = self.toolsets # Don't add additional toolsets if the toolsets have been overridden @@ -1408,14 +1452,15 @@ def copy_dynamic_toolsets(toolset: AbstractToolset[AgentDepsT]) -> AbstractTools return toolset toolset = toolset.visit_and_replace(copy_dynamic_toolsets) + tool_config = self._get_prompt_config(prompt_config).tool_config - if self._prepare_tools: - toolset = PreparedToolset(toolset, self._prepare_tools) + if self._prepare_tools or tool_config: + toolset = PreparedToolset(toolset, self._prepare_tools, tool_config=tool_config) output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset if output_toolset is not None: - if self._prepare_output_tools: - output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + if self._prepare_output_tools or tool_config: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools, tool_config=tool_config) toolset = CombinedToolset([output_toolset, toolset]) return toolset diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index cc99f80e74..dfae29f20d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -21,6 +21,7 @@ exceptions, messages as _messages, models, + prompt_config as _prompt_config, result, usage as _usage, ) @@ -160,6 +161,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -180,6 +182,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -199,6 +202,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -233,6 +237,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. 
+ prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -257,6 +263,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, toolsets=toolsets, @@ -284,6 +291,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -304,6 +312,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -323,6 +332,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -356,6 +366,8 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -379,6 +391,7 @@ def run_sync( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=False, @@ -400,6 +413,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -420,6 +434,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -440,6 +455,7 @@ async def run_stream( # noqa: C901 instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -481,6 +497,8 @@ async def main(): instructions: Optional additional instructions to use for this run. 
deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -510,6 +528,7 @@ async def main(): deps=deps, instructions=instructions, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=False, @@ -632,6 +651,7 @@ def run_stream_sync( model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -651,6 +671,7 @@ def run_stream_sync( model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -669,6 +690,7 @@ def run_stream_sync( model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -712,6 +734,8 @@ def main(): model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. 
@@ -736,6 +760,7 @@ async def _consume_stream(): model=model, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -760,6 +785,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -779,6 +805,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -797,6 +824,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -847,6 +875,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. 
@@ -872,6 +902,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, toolsets=toolsets, @@ -889,6 +920,7 @@ async def _run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, @@ -915,6 +947,7 @@ async def run_agent() -> AgentRunResult[Any]: instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=False, @@ -944,6 +977,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -963,6 +997,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -983,6 +1018,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -1059,6 +1095,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -1082,6 +1120,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -1095,6 +1134,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent. 
""" raise NotImplementedError yield diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index f363b5d990..7dc0d63e92 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -8,6 +8,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from .._json_schema import JsonSchema @@ -84,6 +85,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -103,6 +105,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -122,6 +125,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -198,6 +202,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent's defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -216,6 +222,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -234,6 +241,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -247,6 +255,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config passed to the agent constructor and agent run. 
""" with self.wrapped.override( name=name, @@ -255,5 +264,6 @@ def override( toolsets=toolsets, tools=tools, instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py index c5adf5221d..4eca4619b8 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py @@ -13,6 +13,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent @@ -136,6 +137,7 @@ async def wrapped_run_workflow( deps: AgentDepsT, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, @@ -153,6 +155,7 @@ async def wrapped_run_workflow( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -177,6 +180,7 @@ def wrapped_run_sync_workflow( model_settings: ModelSettings | None = None, instructions: Instructions[AgentDepsT] = None, usage_limits: _usage.UsageLimits | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, @@ -195,6 +199,7 @@ def wrapped_run_sync_workflow( deps=deps, model_settings=model_settings, usage_limits=usage_limits, + prompt_config=prompt_config, usage=usage, infer_name=infer_name, toolsets=toolsets, @@ -268,6 +273,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -288,6 +294,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -307,6 +314,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -343,6 +351,7 @@ async def main(): deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. 
@@ -365,6 +374,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -386,6 +396,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -406,6 +417,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -425,6 +437,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -459,6 +472,7 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -482,6 +496,7 @@ def run_sync( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -503,6 +518,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -523,6 +539,7 @@ def run_stream( deps: AgentDepsT = None, instructions: Instructions[AgentDepsT] = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -622,6 +639,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -641,6 +659,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -659,6 +678,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, 
infer_name: bool = True, @@ -709,6 +729,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -736,6 +757,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -756,6 +778,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -776,6 +799,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -853,6 +877,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -877,6 +902,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -896,6 +922,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -909,6 +936,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent. 
""" if _utils.is_set(model) and not isinstance(model, (DBOSModel)): raise UserError( @@ -922,5 +950,6 @@ def override( toolsets=toolsets, tools=tools, instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py index 60c8122686..711d4ae247 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py @@ -16,6 +16,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent @@ -184,6 +185,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -204,6 +206,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -223,6 +226,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -258,6 +262,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. 
@@ -284,6 +289,7 @@ async def wrapped_run_flow() -> AgentRunResult[Any]: instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -308,6 +314,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -328,6 +335,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -347,6 +355,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -381,6 +390,7 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -409,6 +419,7 @@ def wrapped_run_sync_flow() -> AgentRunResult[Any]: instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -434,6 +445,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -454,6 +466,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -474,6 +487,7 @@ async def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -506,6 +520,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. 
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -531,6 +546,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -553,6 +569,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -572,6 +589,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -590,6 +608,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -640,6 +659,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -665,6 +685,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -684,6 +705,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -703,6 +725,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -722,6 +745,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -798,6 +822,7 @@ async def main(): deps: Optional dependencies to use for this run. instructions: Optional additional instructions to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. 
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -822,6 +847,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -839,6 +865,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. @@ -852,6 +879,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent. """ if _utils.is_set(model) and not isinstance(model, PrefectModel): raise UserError( @@ -859,6 +887,12 @@ def override( ) with super().override( - name=name, deps=deps, model=model, toolsets=toolsets, tools=tools, instructions=instructions + name=name, + deps=deps, + model=model, + toolsets=toolsets, + tools=tools, + instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 42fc2a872e..1f8b5aace0 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -21,6 +21,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent @@ -267,6 +268,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -287,6 +289,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -306,6 +309,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -341,6 +345,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -366,6 +371,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -387,6 +393,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -408,6 +415,7 @@ def run_sync( deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, @@ -426,6 +434,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -460,6 +469,7 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -484,6 +494,7 @@ def run_sync( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -505,6 +516,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -525,6 +537,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -545,6 +558,7 @@ async def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -577,6 +591,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -602,6 +617,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -624,6 +640,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -643,6 +660,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -661,6 +679,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -711,6 +730,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -736,6 +756,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -755,6 +776,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -775,6 +797,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -795,6 +818,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -872,6 +896,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -906,6 +931,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -925,6 +951,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -938,6 +965,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent. """ if workflow.in_workflow(): if _utils.is_set(model): @@ -960,5 +988,6 @@ def override( toolsets=toolsets, tools=tools, instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 826cf754b2..ab9249db3a 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -874,6 +874,26 @@ class ToolReturnPart(BaseToolReturnPart): part_kind: Literal['tool-return'] = 'tool-return' """Part type identifier, this is available on all parts as a discriminator.""" + return_kind: ( + Literal[ + 'final-result-processed', + 'output-tool-not-executed', + 'function-tool-not-executed', + 'tool-executed', + 'tool-denied', + ] + | None + ) = None + """How the tool call was resolved, used for disambiguating return parts. + + * `tool-executed`: the tool ran successfully and produced a return value + * `final-result-processed`: an output tool produced the run's final result + * `output-tool-not-executed`: an output tool was skipped because a final result already existed + * `function-tool-not-executed`: a function tool was skipped due to early termination after a final result + * `tool-denied`: the tool call was rejected by an approval handler + + """ + @dataclass(repr=False) class BuiltinToolReturnPart(BaseToolReturnPart): @@ -896,6 +916,12 @@ class BuiltinToolReturnPart(BaseToolReturnPart): error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True)) +def _get_default_model_retry_message() -> str: + from .prompt_config import DEFAULT_PROMPT_CONFIG + + return cast(str, DEFAULT_PROMPT_CONFIG.templates.default_model_retry) + + @dataclass(repr=False) class RetryPromptPart: """A message back to a model asking it to try again. @@ -936,6 +962,9 @@ class RetryPromptPart: part_kind: Literal['retry-prompt'] = 'retry-prompt' """Part type identifier, this is available on all parts as a discriminator.""" + retry_message: str | None = field(default_factory=_get_default_model_retry_message) + """The retry message rendered using the user's prompt template. 
It is populated after the retry conditions have been checked so that the correct template is used.""" + def model_response(self) -> str: """Return a string message describing why the retry is requested.""" if isinstance(self.content, str): @@ -949,7 +978,8 @@ def model_response(self) -> str: description = ( f'{len(self.content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```' ) - return f'{description}\n\nFix the errors and try again.' + + return f'{description}\n\n{self.retry_message}' def otel_event(self, settings: InstrumentationSettings) -> LogRecord: if self.tool_name is None:
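To make the `retry_message` change concrete, here is a small sketch (assuming the patched package is installed) of how `model_response()` now composes its output. The field's default factory resolves to the `default_model_retry` template, which is why the unchanged snapshots in the tests below all show `'Fix the errors and try again.'`:

```python
from pydantic_ai.messages import RetryPromptPart

# retry_message is filled by its default factory, so behaviour is
# unchanged when no custom PromptConfig is supplied.
part = RetryPromptPart(content='Wrong location, please try again', tool_name='get_location')
print(part.model_response())
# Wrong location, please try again
#
# Fix the errors and try again.
```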
+ """Default message sent when a `ModelRetry` exception is raised.""" + + validation_errors_retry: str | Callable[[RetryPromptPart, _RunContext[Any]], str] = 'Fix the errors and try again.' + """Message appended to validation errors when asking the model to retry.""" + + model_retry_string_tool: str | Callable[[RetryPromptPart, _RunContext[Any]], str] = 'Fix the errors and try again.' + """Message sent when a `ModelRetry` exception is raised from a tool.""" + + model_retry_string_no_tool: str | Callable[[RetryPromptPart, _RunContext[Any]], str] = ( + 'Fix the errors and try again.' + ) + """Message sent when a `ModelRetry` exception is raised outside of a tool context.""" + + prompted_output_template: str = dedent( + """ + Always respond with a JSON object that's compatible with this schema: + + {schema} + + Don't include any text or Markdown fencing before or after. + """ + ) + + def apply_template(self, message_part: ModelRequestPart, ctx: _RunContext[Any]) -> ModelRequestPart: + if isinstance(message_part, ToolReturnPart): + if message_part.return_kind == 'final-result-processed': + return self._apply_tool_template(message_part, ctx, self.final_result_processed) + elif message_part.return_kind == 'output-tool-not-executed': + return self._apply_tool_template(message_part, ctx, self.output_tool_not_executed) + elif message_part.return_kind == 'function-tool-not-executed': + return self._apply_tool_template(message_part, ctx, self.function_tool_not_executed) + elif message_part.return_kind == 'tool-denied': + # The content may already have a custom message from ToolDenied in which case we should not override it + if self.tool_call_denied != DEFAULT_PROMPT_CONFIG.templates.tool_call_denied: + return self._apply_tool_template(message_part, ctx, self.tool_call_denied) + return message_part + elif isinstance(message_part, RetryPromptPart): + template = self._get_template_for_retry(message_part) + return self._apply_retry_tempelate(message_part, ctx, template) + return message_part # Returns the original message if no template is applied + + def _get_template_for_retry( + self, message_part: RetryPromptPart + ) -> str | Callable[[RetryPromptPart, _RunContext[Any]], str]: + template: str | Callable[[RetryPromptPart, _RunContext[Any]], str] = self.default_model_retry + # This is based no RetryPromptPart.model_response() implementation + # We follow the same structure here to populate the correct template + if isinstance(message_part.content, str): + if message_part.tool_name is None: + template = self.model_retry_string_no_tool + else: + template = self.model_retry_string_tool + else: + template = self.validation_errors_retry + + return template + + def _apply_retry_tempelate( + self, + message_part: RetryPromptPart, + ctx: _RunContext[Any], + template: str | Callable[[RetryPromptPart, _RunContext[Any]], str], + ) -> RetryPromptPart: + if isinstance(template, str): + message_part = replace(message_part, retry_message=template) + else: + message_part = replace(message_part, retry_message=template(message_part, ctx)) + + return message_part + + def _apply_tool_template( + self, + message_part: ToolReturnPart, + ctx: _RunContext[Any], + template: str | Callable[[ToolReturnPart, _RunContext[Any]], str], + ) -> ToolReturnPart: + if isinstance(template, str): + message_part = replace(message_part, content=template) + + else: + message_part = replace(message_part, content=template(message_part, ctx)) + return message_part + + def get_prompted_output_template(self, output_schema: OutputSchema[Any]) -> 
str | None: + """Get the prompted output template for the given output schema.""" + from ._output import PromptedOutputSchema + + if not isinstance(output_schema, PromptedOutputSchema): + return None + + return self.prompted_output_template + + +@dataclass +class ToolConfig: + """Configuration for customizing the tool descriptions and arguments used by agents.""" + + tool_descriptions: dict[str, str] = field(default_factory=dict) + """Custom descriptions for tools used by the agent.""" + + +@dataclass +class PromptConfig: + """Configuration for customizing all strings and prompts sent to the model by Pydantic AI. + + `PromptConfig` provides a clean, extensible interface for overriding any text that + Pydantic AI sends to the model. This includes: + + - **Prompt Templates**: Messages for retry prompts, tool return confirmations, + validation errors, and other system-generated text via [`PromptTemplates`][pydantic_ai.PromptTemplates]. + - **Tool Configuration**: Tool descriptions now, with parameter descriptions and other + tool metadata planned - allowing you to override how tools are described at the agent level. + + This allows you to fully customize how your agent communicates with the model + without modifying the underlying tool or agent code. + + Example: + ```python + from pydantic_ai import Agent, PromptConfig, PromptTemplates + + agent = Agent( + 'openai:gpt-4o', + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry='Please correct the errors and try again.', + final_result_processed='Result received successfully.', + ), + ), + ) + ``` + + Attributes: + templates: Templates for customizing system-generated messages like retry prompts, + tool return confirmations, and validation error messages. + """ + + templates: PromptTemplates = field(default_factory=PromptTemplates) + """Templates for customizing system-generated messages sent to the model. + + See [`PromptTemplates`][pydantic_ai.PromptTemplates] for available template options. + """ + + tool_config: ToolConfig | None = None + """Configuration for customizing tool descriptions and metadata. + See [`ToolConfig`][pydantic_ai.ToolConfig] for available configuration options. + """ + + +# @dataclass +# class InstructionsConfig: +# """ +# Configuration options to override instructions sent to the model. +# """ + + # instructions: + # Runtime instructions passed to one of the run methods already cover almost the same need, + # so it's unclear this config is required; revisit before implementing. + + + +DEFAULT_PROMPT_CONFIG = PromptConfig() +"""The default prompt configuration used when no custom configuration is provided. + +This uses the default [`PromptTemplates`][pydantic_ai.PromptTemplates] with sensible +defaults for all system-generated messages. +"""
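As an illustration of the resolution order `apply_template` implements (a sketch, not part of the patch): each field accepts either a static string or a `(part, ctx)` callable, and any field left untouched resolves to the same text as `DEFAULT_PROMPT_CONFIG`:

```python
from pydantic_ai.prompt_config import DEFAULT_PROMPT_CONFIG, PromptTemplates

# Static string: every skipped output tool gets this exact text.
static = PromptTemplates(output_tool_not_executed='Skipped: a final result already exists.')

# Callable: phrased per part at run time from the ToolReturnPart and RunContext.
dynamic = PromptTemplates(
    output_tool_not_executed=lambda part, ctx: f'{part.tool_name} skipped; a result was already produced.'
)

# Fields that are not overridden keep the library defaults.
assert DEFAULT_PROMPT_CONFIG.templates.final_result_processed == 'Final result processed.'
```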
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 900278ce44..f4d55b3b83 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -14,6 +14,7 @@ from .builtin_tools import AbstractBuiltinTool from .exceptions import ModelRetry from .messages import RetryPromptPart, ToolCallPart, ToolReturn +from .prompt_config import DEFAULT_PROMPT_CONFIG __all__ = ( 'AgentDepsT', @@ -176,7 +177,7 @@ class ToolApproved: class ToolDenied: """Indicates that a tool call has been denied and that a denial message should be returned to the model.""" - message: str = 'The tool call was denied.' + message: str = cast(str, DEFAULT_PROMPT_CONFIG.templates.tool_call_denied) """The message to return to the model.""" _: KW_ONLY diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py index af604d4328..31951cf524 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py @@ -2,6 +2,8 @@ from dataclasses import dataclass, replace +from pydantic_ai.prompt_config import ToolConfig + from .._run_context import AgentDepsT, RunContext from ..exceptions import UserError from ..tools import ToolsPrepareFunc @@ -16,10 +18,41 @@ class PreparedToolset(WrapperToolset[AgentDepsT]): See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information. """ - prepare_func: ToolsPrepareFunc[AgentDepsT] + prepare_func: ToolsPrepareFunc[AgentDepsT] | None + tool_config: ToolConfig | None = None async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: original_tools = await super().get_tools(ctx) + + tools_after_tool_config = await self._get_tools_from_tool_config(original_tools) + # If no tool config is present, this returns the original tools, which are then passed to the prepare function + tools_after_prepare_func = await self._get_tools_from_prepare_func(tools_after_tool_config, ctx) + # If no prepare function is present, this returns tools_after_tool_config (the original tools when no tool config was set either) + + return tools_after_prepare_func + + async def _get_tools_from_tool_config( + self, original_tools: dict[str, ToolsetTool[AgentDepsT]] + ) -> dict[str, ToolsetTool[AgentDepsT]]: + if self.tool_config is None: + return original_tools + + tool_descriptions = self.tool_config.tool_descriptions + + for tool_name, description in tool_descriptions.items(): + if tool_name in original_tools: + original_tool = original_tools[tool_name] + updated_tool_def = replace(original_tool.tool_def, description=description) + original_tools[tool_name] = replace(original_tool, tool_def=updated_tool_def) + + return original_tools + + async def _get_tools_from_prepare_func( + self, original_tools: dict[str, ToolsetTool[AgentDepsT]], ctx: RunContext[AgentDepsT] + ) -> dict[str, ToolsetTool[AgentDepsT]]: + if self.prepare_func is None: + return original_tools + original_tool_defs = [tool.tool_def for tool in original_tools.values()] prepared_tool_defs_by_name = { tool_def.name: tool_def for tool_def in (await self.prepare_func(ctx, original_tool_defs) or []) diff --git a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py index 391cf06f2f..91a2980ef0 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py @@ -170,6 +170,7 @@ async def transform_stream( # noqa: C901 tool_call_id=tool_call_id, tool_name=tool_name, content='Final result processed.', + return_kind='final-result-processed', ) ) async for e in self.handle_function_tool_result(output_tool_result_event): diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index fe3513ae58..f03c8ee89f 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -142,6 +142,7 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: tool_name=tool_name, content=msg.content, tool_call_id=tool_call_id, +
return_kind='tool-executed', ) ) diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index fa82b9255b..0e10951769 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -203,7 +203,11 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # if part.state == 'output-available': builder.add( - ToolReturnPart(tool_name=tool_name, tool_call_id=tool_call_id, content=part.output) + ToolReturnPart( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=part.output, + ) ) elif part.state == 'output-error': builder.add( diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 0e1230a7a5..f213e6749c 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1084,6 +1084,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1155,6 +1156,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -1184,6 +1186,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1301,6 +1304,7 @@ async def retrieve_entity_info(name: str) -> str: content="alice is bob's wife", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ToolReturnPart( @@ -1308,6 +1312,7 @@ async def retrieve_entity_info(name: str) -> str: content="bob is alice's husband", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ToolReturnPart( @@ -1315,6 +1320,7 @@ async def retrieve_entity_info(name: str) -> str: content="charlie is alice's son", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ToolReturnPart( @@ -1322,6 +1328,7 @@ async def retrieve_entity_info(name: str) -> str: content="daisy is bob's daughter and charlie's younger sister", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ] @@ -1575,6 +1582,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='toolu_01WALUz3dC75yywrmL6dF3Bc', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -6417,6 +6425,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='toolu_01X9wcHKKAZD9tBC711xipPa', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6455,6 +6464,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id='toolu_01LZABsgreMefH2Go8D5PQbW', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -6526,6 +6536,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='toolu_01JJ8TequDsrEU2pv1QFRWAK', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6619,6 +6630,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='toolu_01ArHq5f2wxRpRF2PVQcKExM', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git 
a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 185468021a..e729ed57ee 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -337,6 +337,7 @@ async def temperature(city: str, date: datetime.date) -> str: content='30°C', tool_call_id='tooluse_5WEci1UmQ8ifMFkUcy2gHQ', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -368,6 +369,7 @@ async def temperature(city: str, date: datetime.date) -> str: content='Final result processed.', tool_call_id='tooluse_9AjloJSaQDKmpPFff-2Clg', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -475,6 +477,7 @@ async def get_capital(country: str) -> str: tool_name='get_capital', tool_call_id='tooluse_F8LnaCMtQ0-chKTnPhNH2g', timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -636,6 +639,7 @@ async def get_temperature(city: str) -> str: content='30°C', tool_call_id='tooluse_lAG_zP8QRHmSYOwZzzaCqA', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='The')), diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index a1b7785801..46f6d71aaa 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -237,6 +237,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -330,6 +331,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -358,6 +360,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ee1aa83b15..dc44538fdd 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -642,6 +642,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -709,6 +710,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -732,12 +734,14 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='get_location', content='{"lat": 41, "lng": -74}', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -916,10 +920,18 @@ async def bar(y: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='foo', + content='a', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='bar', + content='b', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -940,6 +952,7 @@ async def 
bar(y: str) -> str: content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1013,6 +1026,7 @@ def get_location(loc_name: str) -> str: content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='function-tool-not-executed', ) ], run_id=IsStr(), @@ -1222,6 +1236,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1736,6 +1751,7 @@ async def bar() -> str: content='hello', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1765,6 +1781,7 @@ async def bar() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1820,6 +1837,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1849,6 +1867,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2133,6 +2152,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_google.py b/tests/models/test_google.py index be6d4bd68a..4b20ba8fab 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -228,7 +228,11 @@ async def temperature(city: str, date: datetime.date) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='temperature', content='30°C', tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='temperature', + content='30°C', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -262,6 +266,7 @@ async def temperature(city: str, date: datetime.date) -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -602,6 +607,7 @@ async def get_capital(country: str) -> str: tool_name='get_capital', tool_call_id=IsStr(), timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -634,6 +640,7 @@ async def get_capital(country: str) -> str: content='Paris', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -759,7 +766,11 @@ async def get_temperature(city: str) -> str: IsInstance(FunctionToolCallEvent), FunctionToolResultEvent( result=ToolReturnPart( - tool_name='get_capital', content='Paris', tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='get_capital', + content='Paris', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -777,7 +788,11 @@ async def get_temperature(city: str) -> str: IsInstance(FunctionToolCallEvent), FunctionToolResultEvent( result=ToolReturnPart( - tool_name='get_temperature', content='30°C', tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='get_temperature', + content='30°C', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='The temperature in Paris')), @@ -2470,6 +2485,7 @@ async def bar() -> str: content='hello', 
tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2501,6 +2517,7 @@ async def bar() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2568,6 +2585,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2599,6 +2617,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2656,6 +2675,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2902,6 +2922,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3495,6 +3516,7 @@ class Animal(BaseModel): content='Please return text or call a tool.', tool_call_id=IsStr(), timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -4181,6 +4203,7 @@ def get_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4245,6 +4268,7 @@ def get_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -4544,6 +4568,7 @@ def get_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4596,6 +4621,7 @@ def get_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='The')), diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 0f10b4b7e4..6bc77303ac 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -264,6 +264,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -359,6 +360,7 @@ async def get_location(loc_name: str) -> str: content='Wrong location, please try again', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -388,6 +390,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -529,6 +532,7 @@ async def test_stream_structured(allow_model_requests: None): content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -636,6 +640,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='call_wkpd', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -5401,6 +5406,7 @@ async def get_something_by_name(name: str) -> str: content='Something with name: nonexistent', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions='Be concise. 
Never use pretty double quotes, just regular ones.', @@ -5530,6 +5536,7 @@ async def get_something_by_name(name: str) -> str: content='Something with name: test_name', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions='Be concise. Never use pretty double quotes, just regular ones.', diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index ed99de4e56..72483208ed 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -389,6 +389,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -417,6 +418,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -830,6 +832,7 @@ def response_validator(value: str) -> str: content='Response is invalid', tool_name=None, tool_call_id=IsStr(), + retry_message='Fix the errors and try again.', timestamp=IsNow(tz=timezone.utc), ) ], diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 424be8d39b..43b106ac56 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -468,6 +468,7 @@ class CityLocation(BaseModel): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -537,6 +538,7 @@ class CityLocation(BaseModel): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -603,6 +605,7 @@ async def test_request_output_type_with_arguments_str_response(allow_model_reque content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1159,6 +1162,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -1188,6 +1192,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1319,6 +1324,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -1348,6 +1354,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1377,6 +1384,7 @@ async def get_location(loc_name: str) -> str: content='Final result processed.', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1483,6 +1491,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1506,6 +1515,7 @@ async def get_location(loc_name: str) -> str: content='Final result processed.', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1599,6 +1609,7 @@ async def get_location(loc_name: str) -> str: 
content='{"lat": 51, "lng": 0}', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1717,6 +1728,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -1746,6 +1758,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1950,6 +1963,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='GJYBCIkcS', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 196b140454..83b79533a2 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -187,6 +187,7 @@ def test_weather(): content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -205,6 +206,7 @@ def test_weather(): content='Raining', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -257,6 +259,7 @@ def test_var_args(): 'metadata': None, 'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc), # type: ignore[reportUnknownMemberType] 'part_kind': 'tool-return', + 'return_kind': 'tool-executed', } ) @@ -389,19 +392,39 @@ def test_call_all(): ModelRequest( parts=[ ToolReturnPart( - tool_name='foo', content='1', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='foo', + content='1', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='bar', content='2', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='bar', + content='2', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='baz', content='3', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='baz', + content='3', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='qux', content='4', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='qux', + content='4', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='quz', content='a', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='quz', + content='a', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ], run_id=IsStr(), diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index f7a6809a71..07e8b88cf2 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -100,6 +100,7 @@ def test_custom_output_args(): content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -147,6 +148,7 @@ class Foo(BaseModel): content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -190,6 +192,7 @@ def test_output_type(): content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', 
timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -234,6 +237,7 @@ async def my_ret(x: int) -> str: tool_name='my_ret', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -248,7 +252,11 @@ async def my_ret(x: int) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='my_ret', content='1', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='my_ret', + content='1', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 74cb3c1414..b90d9ed774 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -266,6 +266,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -367,6 +368,7 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -400,6 +402,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -929,6 +932,7 @@ async def get_image() -> ImageUrl: content='See file bd38f5', tool_call_id='call_4hrT4QP9jfojtK69vGiFCFjG', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1018,6 +1022,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='call_Btn0GIzGr4ugNlLmkQghQUMY', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -2133,7 +2138,11 @@ async def get_temperature(city: str) -> float: ModelRequest( parts=[ ToolReturnPart( - tool_name='get_temperature', content=20.0, tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='get_temperature', + content=20.0, + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions='You are a helpful assistant.', @@ -2556,6 +2565,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2594,6 +2604,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2658,6 +2669,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2746,6 +2758,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2836,6 +2849,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_SIttSeiOistt33Htj4oiHOOX', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2926,6 +2940,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3016,6 +3031,7 @@ async def get_user_country() -> str: content='Mexico', 
tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index b03e99bb91..31d5d00d84 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -320,12 +320,14 @@ async def get_location(loc_name: str) -> str: tool_name='get_location', tool_call_id=IsStr(), timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ), ToolReturnPart( tool_name='get_location', content='{"lat": 51, "lng": 0}', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -404,6 +406,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='call_FLm3B1f8QAan0KpbUXhNY8bA', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1438,6 +1441,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_ZWkVhdUjupo528U9dqgFeRkH', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1468,6 +1472,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1529,6 +1534,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1610,6 +1616,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_tTAThu8l2S9hNky2krdwijGP', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1693,6 +1700,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_UaLahjOtaM2tTyYZLxTCbOaP', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1772,6 +1780,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_FrlL4M0CbAy8Dhv4VqF1Shom', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1855,6 +1864,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_my4OyoVXRT0m7bLWmsxcaCQI', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2425,6 +2435,7 @@ def update_plan(plan: str) -> str: content='plan updated', tool_call_id='call_gL7JE6GDeGGsFubqO2XGytyO', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions="You are a helpful assistant that uses planning. 
You MUST use the update_plan tool and continually update it as you make progress against the user's prompt", @@ -3699,6 +3710,7 @@ def get_meaning_of_life() -> int: content=42, tool_call_id='call_3WCunBU7lCG1HHaLmnnRJn8I', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5950,6 +5962,7 @@ async def test_openai_responses_image_generation_tool_without_image_output( content='Please return text or call a tool.', tool_call_id=IsStr(), timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -6113,6 +6126,7 @@ class Animal(BaseModel): content='Please return text or include your response in a tool call.', tool_call_id=IsStr(), timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -6149,6 +6163,7 @@ class Animal(BaseModel): content='Final result processed.', tool_call_id='call_eE7MHM5WMJnMt5srV69NmBJk', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -6364,6 +6379,7 @@ async def get_animal() -> str: content='axolotl', tool_call_id='call_t76xO1K2zqrJkawkU3tur8vj', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6676,6 +6692,7 @@ class CityLocation(BaseModel): content='Final result processed.', tool_call_id='call_LIXPi261Xx3dGYzlDsOoyHGk', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 4e5f74f476..edf9e6045e 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -638,6 +638,7 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ), UserPromptPart(content='Second message', timestamp=IsDatetime()), ], diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 5cbf85fc69..ff091752e1 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -1555,12 +1555,14 @@ async def test_messages() -> None: content='Tool message', tool_call_id='tool_call_1', timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='tool_call_2', content='Tool message', tool_call_id='tool_call_2', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content='User message', diff --git a/tests/test_agent.py b/tests/test_agent.py index 6ce2d91c54..49a917911b 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -68,6 +68,7 @@ from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.output import OutputObjectDefinition, StructuredDict, ToolOutput +from pydantic_ai.prompt_config import PromptConfig, PromptTemplates, ToolConfig from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition, ToolDenied @@ -226,6 +227,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -235,6 +237,374 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",') +def test_prompt_config_callable(): + """Test callable prompt templates: validation_errors_retry, 
final_result_processed, output_tool_not_executed, and function_tool_not_executed.""" + + def my_function_tool() -> str: # pragma: no cover + return 'function executed' + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, '{"a": "wrong", "b": "foo"}')]) + + else: + assert info.function_tools is not None + return ModelResponse( + parts=[ + ToolCallPart(info.output_tools[0].name, '{"a": 42, "b": "foo"}'), # Succeeds + ToolCallPart(info.output_tools[0].name, '{"a": 99, "b": "bar"}'), # Not executed + ToolCallPart(info.function_tools[0].name, '{}'), # Not executed + ] + ) + + agent = Agent( + FunctionModel(return_model), + output_type=Foo, + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry=lambda part, ctx: 'Please fix these validation errors and try again.', + final_result_processed=lambda part, ctx: f'Custom final result {part.content}', + output_tool_not_executed=lambda part, ctx: f'Custom output not executed: {part.tool_name}', + function_tool_not_executed=lambda part, ctx: f'Custom function not executed: {part.tool_name}', + ) + ), + ) + + agent.tool_plain(my_function_tool) + + result = agent.run_sync('Hello') + assert result.output.model_dump() == {'a': 42, 'b': 'foo'} + + retry_request = result.all_messages()[2] + assert isinstance(retry_request, ModelRequest) + retry_part = retry_request.parts[0] + assert isinstance(retry_part, RetryPromptPart) + # model_response() returns validation errors + retry_message appended + assert retry_part.model_response() == snapshot("""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "wrong" + } +] +``` + +Please fix these validation errors and try again.\ +""") + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=51, output_tokens=7), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + tool_name='final_result', + content=[ + { + 'type': 'int_parsing', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ], + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + retry_message='Please fix these validation errors and try again.', + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='final_result', args='{"a": 99, "b": "bar"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='my_function_tool', args='{}', tool_call_id=IsStr()), + ], + usage=RequestUsage(input_tokens=91, output_tokens=23), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Custom final result Final result processed.', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', + ), + ToolReturnPart( + tool_name='final_result', + content='Custom 
output not executed: final_result', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', + ), + ToolReturnPart( + tool_name='my_function_tool', + content='Custom function not executed: my_function_tool', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', + ), + ], + run_id=IsStr(), + ), + ] + ) + + +def test_prompt_config_string_and_override_prompt_config(): + """Test string prompt templates and overriding prompt_config at runtime via agent.override().""" + + def my_function_tool() -> str: # pragma: no cover + return 'function executed' + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, '{"a": "wrong", "b": "foo"}')]) + + else: + assert info.function_tools is not None + return ModelResponse( + parts=[ + ToolCallPart(info.output_tools[0].name, '{"a": 42, "b": "foo"}'), # Succeeds + ToolCallPart(info.output_tools[0].name, '{"a": 99, "b": "bar"}'), # Not executed + ToolCallPart(info.function_tools[0].name, '{}'), # Not executed + ] + ) + + agent = Agent( + FunctionModel(return_model), + output_type=Foo, + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry='Custom retry message', + final_result_processed='Custom final result', + output_tool_not_executed='Custom output not executed:', + function_tool_not_executed='Custom function not executed', + ) + ), + ) + + agent.tool_plain(my_function_tool) + + result = agent.run_sync('Hello') + assert result.output.model_dump() == {'a': 42, 'b': 'foo'} + + retry_request = result.all_messages()[2] + assert isinstance(retry_request, ModelRequest) + retry_part = retry_request.parts[0] + assert isinstance(retry_part, RetryPromptPart) + # model_response() returns validation errors + retry_message appended + assert retry_part.model_response() == snapshot("""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "wrong" + } +] +``` + +Custom retry message""") + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=51, output_tokens=7), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + tool_name='final_result', + content=[ + { + 'type': 'int_parsing', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ], + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + retry_message='Custom retry message', + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='final_result', args='{"a": 99, "b": "bar"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='my_function_tool', args='{}', tool_call_id=IsStr()), + ], + usage=RequestUsage(input_tokens=85, output_tokens=23), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), 
+ ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Custom final result', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', + ), + ToolReturnPart( + tool_name='final_result', + content='Custom output not executed:', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', + ), + ToolReturnPart( + tool_name='my_function_tool', + content='Custom function not executed', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', + ), + ], + run_id=IsStr(), + ), + ] + ) + + # Verify prompt_config can be overridden + with agent.override( + prompt_config=PromptConfig(templates=PromptTemplates(validation_errors_retry='Custom retry message override')) + ): + result = agent.run_sync('Hello') + assert result.output.model_dump() == {'a': 42, 'b': 'foo'} + retry_request = result.all_messages()[2] + assert isinstance(retry_request, ModelRequest) + retry_part = retry_request.parts[0] + assert isinstance(retry_part, RetryPromptPart) + # model_response() returns validation errors + retry_message appended + assert retry_part.model_response() == snapshot("""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "wrong" + } +] +``` + +Custom retry message override""") + + +def test_prompt_config_tool_config_descriptions(): + """Test that ToolConfig.tool_descriptions updates tool descriptions at the agent level.""" + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Verify the tool description was updated + assert info.function_tools is not None + my_tool = next(t for t in info.function_tools if t.name == 'my_tool') + assert my_tool.description == 'Custom tool description from ToolConfig' + return ModelResponse(parts=[TextPart('Done')]) + + agent = Agent( + FunctionModel(return_model), + prompt_config=PromptConfig( + tool_config=ToolConfig(tool_descriptions={'my_tool': 'Custom tool description from ToolConfig'}) + ), + ) + + @agent.tool_plain + def my_tool(x: int) -> int: # pragma: no cover + """Original description that should be overridden""" + return x * 2 + + result = agent.run_sync('Hello') + assert result.output == 'Done' + + +def test_prompt_config_tool_config_descriptions_at_runtime(): + """Test that ToolConfig.tool_descriptions passed to run_sync() overrides agent-level prompt_config.""" + observed_descriptions: list[str | None] = [] + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.function_tools is not None + basic_tool = next(tool for tool in info.function_tools if tool.name == 'basic_tool') + observed_descriptions.append(basic_tool.description) + return ModelResponse(parts=[TextPart('Done')]) + + # Agent with agent-level prompt_config + agent = Agent( + FunctionModel(return_model), + prompt_config=PromptConfig( + tool_config=ToolConfig( + tool_descriptions={ + 'basic_tool': 'Agent-level tool description', + 'not_present_basic_tool': 'Should not be used', + } + ) + ), + ) + + @agent.tool_plain + def basic_tool(x: int) -> int: # pragma: no cover + """Original description that should be overridden""" + return x * 2 + + # First run: no runtime prompt_config, should use agent-level description + result = agent.run_sync('Hello') + assert result.output == 'Done' + assert observed_descriptions[-1] == 'Agent-level tool 
description' + + # Second run: pass runtime prompt_config, should override agent-level description + result = agent.run_sync( + 'Hello', + prompt_config=PromptConfig( + tool_config=ToolConfig(tool_descriptions={'basic_tool': 'Runtime custom tool description'}) + ), + ) + assert result.output == 'Done' + assert observed_descriptions[-1] == 'Runtime custom tool description' + + def test_result_pydantic_model_validation_error(): def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None @@ -289,7 +659,8 @@ def check_b(cls, v: str) -> str: ] ``` -Fix the errors and try again.""") +Fix the errors and try again.\ +""") def test_output_validator(): @@ -352,6 +723,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -497,6 +869,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -508,7 +881,11 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest( parts=[ ToolReturnPart( - tool_name='final_result', content='foobar', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='final_result', + content='foobar', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -522,6 +899,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1096,6 +1474,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1149,9 +1528,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest( parts=[ RetryPromptPart( - content='City not found, I only know Mexico City', - tool_call_id=IsStr(), - timestamp=IsDatetime(), + content='City not found, I only know Mexico City', tool_call_id=IsStr(), timestamp=IsDatetime() ) ], run_id=IsStr(), @@ -1407,6 +1784,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1445,6 +1823,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2028,9 +2407,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest( parts=[ RetryPromptPart( - content='City not found, I only know Mexico City', - tool_call_id=IsStr(), - timestamp=IsDatetime(), + content='City not found, I only know Mexico City', tool_call_id=IsStr(), timestamp=IsDatetime() ) ], run_id=IsStr(), @@ -2075,7 +2452,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + 
timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2111,7 +2492,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2176,7 +2561,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2240,7 +2629,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2265,6 +2658,7 @@ async def ret_a(x: str) -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2292,7 +2686,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2311,6 +2709,7 @@ async def ret_a(x: str) -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -2336,6 +2735,7 @@ async def ret_a(x: str) -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -2461,6 +2861,7 @@ def test_tool() -> str: content='Test response', tool_call_id='call_123', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3023,24 +3424,28 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='deferred_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3113,12 +3518,14 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + 
return_kind='output-tool-not-executed', ), ], run_id=IsStr(), @@ -3169,12 +3576,14 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='final_result', content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', ), ], run_id=IsStr(), @@ -3259,18 +3668,21 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", @@ -3283,6 +3695,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3370,18 +3783,21 @@ def regular_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='external_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3458,6 +3874,7 @@ def regular_tool(x: int) -> int: content=1, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3513,6 +3930,7 @@ def regular_tool(x: int) -> int: content=0, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3537,6 +3955,7 @@ def regular_tool(x: int) -> int: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -3622,21 +4041,28 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content=42, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='another_tool', + content=2, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', 
), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", @@ -3649,6 +4075,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3721,12 +4148,14 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -3805,6 +4234,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -3878,6 +4308,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', @@ -3957,6 +4388,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), RetryPromptPart( content='Second output validation failed', @@ -4052,6 +4484,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id='second', + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -4136,6 +4569,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4443,6 +4877,7 @@ async def foobar(x: str) -> str: content='inner agent result', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4764,6 +5199,7 @@ def get_image() -> BinaryContent: content='See file image_id_1', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -4812,6 +5248,7 @@ def get_files(): content=['See file img_001', 'See file vid_002', 'See file aud_003', 'See file doc_004'], tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -5039,6 +5476,7 @@ class Output(BaseModel): content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -5088,7 +5526,11 @@ def my_tool(x: int) -> int: ModelRequest( parts=[ ToolReturnPart( - tool_name='my_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='my_tool', + content=2, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5106,7 +5548,11 @@ def my_tool(x: int) -> int: ModelRequest( parts=[ ToolReturnPart( - tool_name='my_tool', content=4, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='my_tool', + content=4, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + 
return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5178,6 +5624,7 @@ def foo_tool(foo: Foo) -> int: 'tool_call_id': IsStr(), 'timestamp': IsStr(), 'part_kind': 'retry-prompt', + 'retry_message': 'Fix the errors and try again.', } ], 'instructions': None, @@ -5276,6 +5723,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), metadata={'foo': 'bar'}, timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -5357,6 +5805,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), metadata={'foo': 'bar'}, timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -5653,6 +6102,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='foo tool added', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5671,6 +6121,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Hello from foo', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5750,6 +6201,7 @@ async def only_if_plan_presented( content='a', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5774,6 +6226,7 @@ async def only_if_plan_presented( content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -6061,9 +6514,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon ModelRequest( parts=[ RetryPromptPart( - content='Please return text or call a tool.', - tool_call_id=IsStr(), - timestamp=IsDatetime(), + content='Please return text or call a tool.', tool_call_id=IsStr(), timestamp=IsDatetime() ) ], run_id=IsStr(), @@ -6102,7 +6553,12 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon model = FunctionModel(model_function) - agent = Agent(model, output_type=[str, DeferredToolRequests]) + # Using prompt_config with a custom tool_call_denied template to cover line 78 in prompt_config.py + agent = Agent( + model, + output_type=[str, DeferredToolRequests], + prompt_config=PromptConfig(templates=PromptTemplates(tool_call_denied='Tool call denied custom message.')), + ) @agent.tool_plain(requires_approval=True) def delete_file(path: str) -> str: @@ -6151,6 +6607,7 @@ def create_file(path: str, content: str) -> str: content='File \'new_file.py\' created with content: print("Hello, world!")', tool_call_id='create_file', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6209,6 +6666,7 @@ def create_file(path: str, content: str) -> str: content='File \'new_file.py\' created with content: print("Hello, world!")', tool_call_id='create_file', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6220,19 +6678,21 @@ def create_file(path: str, content: str) -> str: content="File 'ok_to_delete.py' deleted", tool_call_id='ok_to_delete', timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='delete_file', - content='File cannot be deleted', + content='Tool call denied custom message.', tool_call_id='never_delete', timestamp=IsDatetime(), + return_kind='tool-denied', ), ], run_id=IsStr(), ), ModelResponse( parts=[TextPart(content='Done!')], - usage=RequestUsage(input_tokens=78, output_tokens=24), + usage=RequestUsage(input_tokens=80, output_tokens=24), model_name='function:model_function:', timestamp=IsDatetime(), run_id=IsStr(), @@ 
-6250,19 +6710,21 @@ def create_file(path: str, content: str) -> str: content="File 'ok_to_delete.py' deleted", tool_call_id='ok_to_delete', timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='delete_file', - content='File cannot be deleted', + content='Tool call denied custom message.', tool_call_id='never_delete', timestamp=IsDatetime(), + return_kind='tool-denied', ), ], run_id=IsStr(), ), ModelResponse( parts=[TextPart(content='Done!')], - usage=RequestUsage(input_tokens=78, output_tokens=24), + usage=RequestUsage(input_tokens=80, output_tokens=24), model_name='function:model_function:', timestamp=IsDatetime(), run_id=IsStr(), @@ -6440,6 +6902,7 @@ def update_file(ctx: RunContext, path: str, content: str) -> str: content="File '.env' updated", tool_call_id='update_file_1', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content='continue with the operation', timestamp=IsDatetime()), ], @@ -6852,6 +7315,7 @@ def roll_dice() -> int: content=4, tool_call_id='pyd_ai_tool_call_id__roll_dice', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6876,6 +7340,7 @@ def roll_dice() -> int: content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -6910,6 +7375,7 @@ def roll_dice() -> int: content=4, tool_call_id='pyd_ai_tool_call_id__roll_dice', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6934,6 +7400,7 @@ def roll_dice() -> int: content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), diff --git a/tests/test_agent_output_schemas.py b/tests/test_agent_output_schemas.py index 5c63343126..c187a9eb7f 100644 --- a/tests/test_agent_output_schemas.py +++ b/tests/test_agent_output_schemas.py @@ -433,6 +433,22 @@ async def test_deferred_output_json_schema(): } ) + +def test_build_instructions_appends_schema_placeholder(): + """Test that build_instructions appends the schema when the template doesn't contain a {schema} placeholder.""" + from pydantic_ai._output import OutputObjectDefinition, PromptedOutputSchema + + object_def = OutputObjectDefinition( + json_schema={'type': 'object', 'properties': {'name': {'type': 'string'}}}, + name='TestOutput', + description='A test output', + ) + template_without_schema = 'Please respond with JSON.' 
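+ # The template contains no {schema} placeholder, so build_instructions should append the schema JSON after a blank line: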
+ + result = PromptedOutputSchema.build_instructions(template_without_schema, object_def) + assert result == snapshot( + 'Please respond with JSON.\n\n{"type": "object", "properties": {"name": {"type": "string"}}, "title": "TestOutput", "description": "A test output"}' + ) # special case of only BinaryImage and DeferredToolRequests agent = Agent('test', output_type=[BinaryImage, DeferredToolRequests]) assert agent.output_json_schema() == snapshot( diff --git a/tests/test_dbos.py b/tests/test_dbos.py index de99f1b1d0..fbec58a658 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -371,7 +371,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return","return_kind":"tool-executed"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -386,7 +386,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return","return_kind":"tool-executed"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -450,7 +450,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return","return_kind":"tool-executed"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -672,6 +672,7 @@ async def event_stream_handler( content='Mexico', tool_call_id='call_q2UyBRP7eXNTzAoR8lEhjc9Z', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), FunctionToolResultEvent( @@ -680,6 +681,7 @@ async def event_stream_handler( content='Pydantic AI', tool_call_id='call_b51ijcpFkDiTQG1bQzsrmtW5', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -721,6 +723,7 @@ async def event_stream_handler( content='sunny', tool_call_id='call_LwxJUB9KppVyogRRLQsamRJv', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -1446,12 +1449,14 @@ async def hitl_main_loop(prompt: str) -> AgentRunResult[str | DeferredToolReques content=True, tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='create_file', content='Success', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], instructions='Just call tools without asking 
for confirmation.', @@ -1579,12 +1584,14 @@ def hitl_main_loop_sync(prompt: str) -> AgentRunResult[str | DeferredToolRequest content=True, tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='create_file', content='Success', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], instructions='Just call tools without asking for confirmation.', @@ -1720,6 +1727,7 @@ async def test_dbos_agent_with_model_retry(allow_model_requests: None, dbos: DBO content='sunny', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/test_examples.py b/tests/test_examples.py index 8ed0828250..b60f3d0ee2 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -324,6 +324,7 @@ async def call_tool( text_responses: dict[str, str | ToolCallPart | Sequence[ToolCallPart]] = { + 'Hello': 'Hello! How can I help you today?', 'Use the web to get the current time.': "In San Francisco, it's 8:21:41 pm PDT on Wednesday, August 6, 2025.", 'Give me a sentence with the biggest news in AI this week.': 'Scientists have developed a universal AI detector that can identify deepfake videos.', 'How many days between 2000-01-01 and 2025-03-18?': 'There are 9,208 days between January 1, 2000, and March 18, 2025.', diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 02bab17cc3..977d61fa6e 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -262,6 +262,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) content=32.0, tool_call_id='call_QssdxTGkPblTYHmyVES1tKBj', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -433,6 +434,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): content='The weather in Mexico City is sunny and 26 degrees Celsius.', tool_call_id='call_m9goNwaHBbU926w47V7RtWPt', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -515,6 +517,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A content='Pydantic AI', tool_call_id='call_LaiWltzI39sdquflqeuF0EyE', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -593,6 +596,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age content='Pydantic AI\n', tool_call_id='call_qi5GtBeIEyT7Y3yJvVFIi062', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -673,6 +677,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: content='See file 1c8566', tool_call_id='call_nFsDHYDZigO0rOHqmChZ3pmt', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content=['This is file 1c8566:', image_content], timestamp=IsDatetime()), ], @@ -760,6 +765,7 @@ async def test_tool_returning_image_resource_link( content='See file 1c8566', tool_call_id='call_eVFgn54V9Nuh8Y4zvuzkYjUp', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content=['This is file 1c8566:', image_content], timestamp=IsDatetime()), ], @@ -828,6 +834,7 @@ async def test_tool_returning_audio_resource( content='See file 2d36ae', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content=['This is file 2d36ae:', audio_content], timestamp=IsDatetime()), ], @@ -900,6 +907,7 @@ async def test_tool_returning_audio_resource_link( content='See file 2d36ae', tool_call_id=IsStr(), timestamp=IsDatetime(), + 
return_kind='tool-executed', ), UserPromptPart( content=[ @@ -981,6 +989,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im content='See file 1c8566', tool_call_id='call_Q7xG8CCG0dyevVfUS0ubsDdN', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1060,6 +1069,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): content={'foo': 'bar', 'baz': 123}, tool_call_id='call_oqKviITBj8PwpQjGyUu4Zu5x', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1136,6 +1146,7 @@ async def test_tool_returning_unstructured_dict(allow_model_requests: None, agen content={'foo': 'bar', 'baz': 123}, tool_call_id='call_R0n2R7S9vL2aZOX25T9jahTd', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1216,6 +1227,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): tool_name='get_error', tool_call_id='call_rETXZWddAGZSHyVHAxptPGgc', timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -1254,6 +1266,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): content='This is not an error', tool_call_id='call_4xGyvdghYKHN8x19KWkRtA5N', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1330,6 +1343,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): content=[], tool_call_id='call_mJTuQ2Cl5SaHPTJbIILEUhJC', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1415,6 +1429,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ], tool_call_id='call_kL0TvjEVQBDGZrn1Zv7iNYOW', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ diff --git a/tests/test_prefect.py b/tests/test_prefect.py index b1c18b9803..715c80f251 100644 --- a/tests/test_prefect.py +++ b/tests/test_prefect.py @@ -305,7 +305,7 @@ async def run_complex_agent() -> Response: BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'\{"result":\{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_rI3WKPYvVwlOgCGRjsPP2hEx","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return"\},"content":null,"event_kind":"function_tool_result"\}' + regex=r'\{"result":\{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_rI3WKPYvVwlOgCGRjsPP2hEx","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return","return_kind":"tool-executed"\},"content":null,"event_kind":"function_tool_result"\}' ) ), ], @@ -389,7 +389,7 @@ async def run_complex_agent() -> Response: BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'\{"result":\{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_NS4iQj14cDFwc0BnrKqDHavt","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return"\},"content":null,"event_kind":"function_tool_result"\}' + regex=r'\{"result":\{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_NS4iQj14cDFwc0BnrKqDHavt","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return","return_kind":"tool-executed"\},"content":null,"event_kind":"function_tool_result"\}' ) ), ], @@ -406,7 +406,7 @@ async def run_complex_agent() -> Response: BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'\{"result":\{"tool_name":"get_product_name","content":"Pydantic 
AI","tool_call_id":"call_SkGkkGDvHQEEk0CGbnAh2AQw","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return"\},"content":null,"event_kind":"function_tool_result"\}' + regex=r'\{"result":\{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_SkGkkGDvHQEEk0CGbnAh2AQw","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return","return_kind":"tool-executed"\},"content":null,"event_kind":"function_tool_result"\}' ) ), ], diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 9149d19d1b..b6f944ea9e 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -87,7 +87,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='ret_a', + content='a-apple', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -123,7 +127,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='ret_a', + content='a-apple', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -179,7 +187,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='ret_a', + content='a-apple', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -225,7 +237,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='ret_a', + content='a-apple', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -606,6 +622,7 @@ async def ret_a(x: str) -> str: content='hello world', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -633,6 +650,7 @@ async def ret_a(x: str) -> str: content='hello world', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -657,6 +675,7 @@ async def ret_a(x: str) -> str: content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -826,24 +845,28 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='deferred_tool', content='Tool not executed - a final result was already processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -918,12 +941,14 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat content='Final result processed.', 
tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', ), ], run_id=IsStr(), @@ -970,12 +995,14 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='final_result', content='Output tool not used - a final result was already processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='output-tool-not-executed', ), ], run_id=IsStr(), @@ -1076,18 +1103,21 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='function-tool-not-executed', ), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", @@ -1100,6 +1130,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -1199,12 +1230,14 @@ def regular_tool(x: int) -> int: # pragma: no cover content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='output-tool-not-executed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -1279,6 +1312,7 @@ def regular_tool(x: int) -> int: content=1, tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1338,6 +1372,7 @@ def regular_tool(x: int) -> int: content=0, tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1363,6 +1398,7 @@ def regular_tool(x: int) -> int: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1443,21 +1479,28 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content=42, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='another_tool', 
content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='another_tool', + content=2, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", @@ -1470,6 +1513,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -1544,12 +1588,14 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -1705,6 +1751,7 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', @@ -1789,6 +1836,7 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), + return_kind='final-result-processed', ), RetryPromptPart( content='Second output validation failed', @@ -2100,6 +2148,7 @@ def known_tool(x: int) -> int: content=10, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), ), ] @@ -2346,6 +2395,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: content=84, tool_call_id='my_tool', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2529,6 +2579,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='')), @@ -2575,6 +2626,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='')), @@ -2624,6 +2676,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='')), @@ -2667,6 +2720,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen content='See file bd38f5', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), content=[ 'This is file bd38f5:', @@ -2716,6 +2770,7 @@ async def ret_a(x: str) -> str: content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='')), diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 24ccefb83e..af88a8f355 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -455,7 +455,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - 
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"(,"return_kind":"tool-executed")?},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -484,7 +484,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"(,"return_kind":"tool-executed")?},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -578,7 +578,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"(,"return_kind":"tool-executed")?},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -811,6 +811,7 @@ async def event_stream_handler( content='Mexico', tool_call_id='call_q2UyBRP7eXNTzAoR8lEhjc9Z', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), FunctionToolResultEvent( @@ -819,6 +820,7 @@ async def event_stream_handler( content='Pydantic AI', tool_call_id='call_b51ijcpFkDiTQG1bQzsrmtW5', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -860,6 +862,7 @@ async def event_stream_handler( content='sunny', tool_call_id='call_LwxJUB9KppVyogRRLQsamRJv', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -1875,12 +1878,14 @@ async def test_temporal_agent_with_hitl_tool(allow_model_requests: None, client: content=True, tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='create_file', content='Success', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], instructions='Just call tools without asking for confirmation.', @@ -1996,6 +2001,7 @@ async def test_temporal_agent_with_model_retry(allow_model_requests: None, clien tool_name='get_weather_in_city', tool_call_id=IsStr(), timestamp=IsDatetime(), + retry_message='Fix the errors and try again.', ) ], run_id=IsStr(), @@ -2034,6 +2040,7 @@ async def test_temporal_agent_with_model_retry(allow_model_requests: None, clien content='sunny', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/test_tools.py b/tests/test_tools.py index 0031f702cd..fe58040b2a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1397,6 +1397,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: content=84, tool_call_id='my_tool', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), 
@@ -1797,12 +1798,14 @@ def buy(fruit: str):
                    tool_call_id='get_price_apple',
                    metadata={'fruit': 'apple', 'price': 10.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='Unknown fruit: banana',
                    tool_name='get_price',
                    tool_call_id='get_price_banana',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                ToolReturnPart(
                    tool_name='get_price',
@@ -1810,12 +1813,14 @@ def buy(fruit: str):
                    tool_call_id='get_price_pear',
                    metadata={'fruit': 'pear', 'price': 10.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='Unknown fruit: grape',
                    tool_name='get_price',
                    tool_call_id='get_price_grape',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                UserPromptPart(
                    content='The price of apple is 10.0.',
@@ -1890,12 +1895,14 @@ def buy(fruit: str):
                    tool_call_id='get_price_apple',
                    metadata={'fruit': 'apple', 'price': 10.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='Unknown fruit: banana',
                    tool_name='get_price',
                    tool_call_id='get_price_banana',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                ToolReturnPart(
                    tool_name='get_price',
@@ -1903,12 +1910,14 @@ def buy(fruit: str):
                    tool_call_id='get_price_pear',
                    metadata={'fruit': 'pear', 'price': 10.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='Unknown fruit: grape',
                    tool_name='get_price',
                    tool_call_id='get_price_grape',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                UserPromptPart(
                    content='The price of apple is 10.0.',
@@ -1928,6 +1937,7 @@ def buy(fruit: str):
                    tool_name='buy',
                    tool_call_id='buy_apple',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                ToolReturnPart(
                    tool_name='buy',
@@ -1935,12 +1945,14 @@ def buy(fruit: str):
                    tool_call_id='buy_banana',
                    metadata={'fruit': 'banana', 'price': 100.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='The purchase of pears was denied.',
                    tool_name='buy',
                    tool_call_id='buy_pear',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                UserPromptPart(
                    content='I bought a banana',
@@ -1968,6 +1980,7 @@ def buy(fruit: str):
                    tool_name='buy',
                    tool_call_id='buy_apple',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                ToolReturnPart(
                    tool_name='buy',
@@ -1975,12 +1988,14 @@ def buy(fruit: str):
                    tool_call_id='buy_banana',
                    metadata={'fruit': 'banana', 'price': 100.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='The purchase of pears was denied.',
                    tool_name='buy',
                    tool_call_id='buy_pear',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                UserPromptPart(
                    content='I bought a banana',
@@ -2033,12 +2048,14 @@ def buy(fruit: str):
                    tool_call_id='get_price_apple',
                    metadata={'fruit': 'apple', 'price': 10.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='Unknown fruit: banana',
                    tool_name='get_price',
                    tool_call_id='get_price_banana',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                ToolReturnPart(
                    tool_name='get_price',
@@ -2046,18 +2063,21 @@ def buy(fruit: str):
                    tool_call_id='get_price_pear',
                    metadata={'fruit': 'pear', 'price': 10.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='Unknown fruit: grape',
                    tool_name='get_price',
                    tool_call_id='get_price_grape',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                RetryPromptPart(
                    content='Apples are not available',
                    tool_name='buy',
                    tool_call_id='buy_apple',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                ToolReturnPart(
                    tool_name='buy',
@@ -2065,12 +2085,14 @@ def buy(fruit: str):
                    tool_call_id='buy_banana',
                    metadata={'fruit': 'banana', 'price': 100.0},
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content='The purchase of pears was denied.',
                    tool_name='buy',
                    tool_call_id='buy_pear',
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                ),
                UserPromptPart(
                    content='The price of apple is 10.0.',
@@ -2185,6 +2207,7 @@ def bar(x: int) -> int:
                    content=9,
                    tool_call_id='bar',
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -2238,6 +2261,7 @@ def bar(x: int) -> int:
                    content=9,
                    tool_call_id='bar',
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -2249,12 +2273,14 @@ def bar(x: int) -> int:
                    content=2,
                    tool_call_id='foo1',
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                ),
                ToolReturnPart(
                    tool_name='foo',
                    content='The tool call was denied.',
                    tool_call_id='foo2',
                    timestamp=IsDatetime(),
+                    return_kind='tool-denied',
                ),
            ],
            run_id=IsStr(),
@@ -2312,6 +2338,7 @@ def test_deferred_tool_results_serializable():
                    'tool_call_id': 'foo',
                    'timestamp': IsDatetime(),
                    'part_kind': 'retry-prompt',
+                    'retry_message': 'Fix the errors and try again.',
                },
                'any': {'foo': 'bar'},
            },
@@ -2420,6 +2447,7 @@ def always_fail(ctx: RunContext[None]) -> str:
                    tool_name='always_fail',
                    tool_call_id=IsStr(),
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                )
            ],
            run_id=IsStr(),
@@ -2438,6 +2466,7 @@ def always_fail(ctx: RunContext[None]) -> str:
                    tool_name='always_fail',
                    tool_call_id=IsStr(),
                    timestamp=IsDatetime(),
+                    retry_message='Fix the errors and try again.',
                )
            ],
            run_id=IsStr(),
@@ -2456,6 +2485,7 @@ def always_fail(ctx: RunContext[None]) -> str:
                    content='I guess you never learn',
                    tool_call_id=IsStr(),
                    timestamp=IsDatetime(),
+                    return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py
index ac17fd0be5..c49cce3c97 100644
--- a/tests/test_usage_limits.py
+++ b/tests/test_usage_limits.py
@@ -123,6 +123,7 @@ async def ret_a(x: str) -> str:
                    content='a-apple',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                    return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),