diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
index 96d4d23766..cd1650a777 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -720,8 +720,7 @@ async def _consume_stream():
             ) as stream_result:
                 yield stream_result
 
-        async_result = _utils.get_event_loop().run_until_complete(anext(_consume_stream()))
-        return result.StreamedRunResultSync(async_result)
+        return result.StreamedRunResultSync(_consume_stream())
 
     @overload
     def run_stream_events(
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index 88bfe407fa..d0b1b5ed97 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -1,6 +1,8 @@
 from __future__ import annotations as _annotations
 
+import inspect
 from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
+from contextlib import asynccontextmanager, suppress
 from copy import deepcopy
 from dataclasses import dataclass, field, replace
 from datetime import datetime
@@ -583,10 +585,17 @@ async def _marked_completed(self, message: _messages.ModelResponse | None = None
 class StreamedRunResultSync(Generic[AgentDepsT, OutputDataT]):
     """Synchronous wrapper for [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] that only exposes sync methods."""
 
-    _streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
+    _streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT] | None = None
 
-    def __init__(self, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]) -> None:
-        self._streamed_run_result = streamed_run_result
+    def __init__(
+        self,
+        streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
+        | AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]],
+    ) -> None:
+        if isinstance(streamed_run_result, StreamedRunResult):
+            self._streamed_run_result = streamed_run_result
+        else:
+            self._stream = streamed_run_result
 
     def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return the history of messages.
@@ -600,7 +609,9 @@ def all_messages(self, *, output_tool_return_content: str | None = None) -> list
         Returns:
             List of messages.
         """
-        return self._streamed_run_result.all_messages(output_tool_return_content=output_tool_return_content)
+        return self._async_to_sync(
+            lambda result: result.all_messages(output_tool_return_content=output_tool_return_content)
+        )
 
     def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:  # pragma: no cover
         """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResultSync.all_messages] as JSON bytes.
@@ -614,7 +625,9 @@ def all_messages_json(self, *, output_tool_return_content: str | None = None) ->
         Returns:
             JSON bytes representing the messages.
         """
-        return self._streamed_run_result.all_messages_json(output_tool_return_content=output_tool_return_content)
+        return self._async_to_sync(
+            lambda result: result.all_messages_json(output_tool_return_content=output_tool_return_content)
+        )
 
     def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.
@@ -630,7 +643,9 @@ def new_messages(self, *, output_tool_return_content: str | None = None) -> list
         Returns:
             List of new messages.
         """
-        return self._streamed_run_result.new_messages(output_tool_return_content=output_tool_return_content)
+        return self._async_to_sync(
+            lambda result: result.new_messages(output_tool_return_content=output_tool_return_content)
+        )
 
     def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes:  # pragma: no cover
         """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResultSync.new_messages] as JSON bytes.
@@ -644,7 +659,9 @@ def new_messages_json(self, *, output_tool_return_content: str | None = None) ->
         Returns:
             JSON bytes representing the new messages.
         """
-        return self._streamed_run_result.new_messages_json(output_tool_return_content=output_tool_return_content)
+        return self._async_to_sync(
+            lambda result: result.new_messages_json(output_tool_return_content=output_tool_return_content)
+        )
 
     def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
         """Stream the output as an iterable.
@@ -661,7 +678,7 @@ def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDa
         Returns:
             An iterable of the response data.
         """
-        return _utils.sync_async_iterator(self._streamed_run_result.stream_output(debounce_by=debounce_by))
+        return self._async_iterator_to_sync(lambda result: result.stream_output(debounce_by=debounce_by))
 
     def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
         """Stream the text result as an iterable.
@@ -676,7 +693,7 @@ def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -
                 Debouncing is particularly important for long structured responses to reduce the overhead of
                 performing validation as each token is received.
         """
-        return _utils.sync_async_iterator(self._streamed_run_result.stream_text(delta=delta, debounce_by=debounce_by))
+        return self._async_iterator_to_sync(lambda result: result.stream_text(delta=delta, debounce_by=debounce_by))
 
     def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]:
         """Stream the response as an iterable of Structured LLM Messages.
@@ -689,16 +706,80 @@ def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple
         Returns:
             An iterable of the structured response message and whether that is the last message.
         """
-        return _utils.sync_async_iterator(self._streamed_run_result.stream_responses(debounce_by=debounce_by))
+        return self._async_iterator_to_sync(lambda result: result.stream_responses(debounce_by=debounce_by))
 
     def get_output(self) -> OutputDataT:
         """Stream the whole response, validate and return it."""
-        return _utils.get_event_loop().run_until_complete(self._streamed_run_result.get_output())
+        return self._async_to_sync(lambda result: result.get_output())
+
+    @asynccontextmanager
+    async def _with_streamed_run_result(self) -> AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]]:
+        """Yield the wrapped `StreamedRunResult`, starting the underlying stream on first use.
+
+        The call that starts the stream also owns its clean-up: once its body is done, `self._stream`
+        is advanced again so the generator can finish. Later calls reuse the cached result and do
+        no clean-up.
+
+        NOTE(review): the first call may be a cheap accessor such as `run_id`; confirm that advancing
+        the generator right after it is intended.
+        """
+        clean_up = False
+        if self._streamed_run_result is None:
+            clean_up = True
+            self._streamed_run_result = await anext(self._stream)
+
+        try:
+            yield self._streamed_run_result
+        except BaseException:
+            if clean_up:
+                # Still let the underlying generator finish so it is not left suspended, but never
+                # let a clean-up failure replace the error the caller needs to see.
+                with suppress(Exception):
+                    await anext(self._stream)
+            raise
+
+        if clean_up:
+            with suppress(StopAsyncIteration):
+                await anext(self._stream)
+
+    def _async_iterator_to_sync(
+        self,
+        func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], AsyncIterator[T]],
+    ) -> Iterator[T]:
+        """Expose the async iterator that `func` builds from the run result as a sync iterator."""
+
+        async def my_task():
+            try:
+                async with self._with_streamed_run_result() as result:
+                    async for item in func(result):
+                        yield item
+            except RuntimeError as e:
+                # NOTE(review): matching on the message text is brittle; presumably this is the
+                # error for a cancel scope exited from a task other than the one that entered it.
+                if str(e) != 'Attempted to exit cancel scope in a different task than it was entered in':
+                    raise
+
+        return _utils.sync_async_iterator(my_task())
+
+    def _async_to_sync(
+        self,
+        func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], Awaitable[T] | T],
+    ) -> T:
+        """Call `func` with the run result on the event loop, awaiting its return value if it is awaitable."""
+
+        async def my_task():
+            async with self._with_streamed_run_result() as result:
+                res = func(result)
+                if inspect.isawaitable(res):
+                    res = cast(T, await res)
+                return res
+
+        return _utils.get_event_loop().run_until_complete(my_task())
 
     @property
     def response(self) -> _messages.ModelResponse:
         """Return the current state of the response."""
-        return self._streamed_run_result.response
+        return self._async_to_sync(lambda result: result.response)
 
     def usage(self) -> RunUsage:
         """Return the usage of the whole run.
@@ -706,22 +787,20 @@ def usage(self) -> RunUsage:
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self._streamed_run_result.usage()
+        return self._async_to_sync(lambda result: result.usage())
 
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
-        return self._streamed_run_result.timestamp()
+        return self._async_to_sync(lambda result: result.timestamp())
 
     @property
     def run_id(self) -> str:
         """The unique identifier for the agent run."""
-        return self._streamed_run_result.run_id
+        return self._async_to_sync(lambda result: result.run_id)
 
     def validate_response_output(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
         """Validate a structured result message."""
-        return _utils.get_event_loop().run_until_complete(
-            self._streamed_run_result.validate_response_output(message, allow_partial=allow_partial)
-        )
+        return self._async_to_sync(lambda result: result.validate_response_output(message, allow_partial=allow_partial))
 
     @property
     def is_complete(self) -> bool:
@@ -733,7 +812,7 @@ def is_complete(self) -> bool:
         [`stream_responses`][pydantic_ai.result.StreamedRunResultSync.stream_responses] or
         [`get_output`][pydantic_ai.result.StreamedRunResultSync.get_output] completes.
         """
-        return self._streamed_run_result.is_complete
+        return self._async_to_sync(lambda result: result.is_complete)
 
 
 @dataclass(repr=False)
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 0c6a46f3c0..663625a407 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -190,7 +190,7 @@ async def ret_a(x: str) -> str:
         RunUsage(
             requests=2,
             input_tokens=103,
-            output_tokens=5,
+            output_tokens=11,
             tool_calls=1,
         )
     )