Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,7 @@ async def _consume_stream():
) as stream_result:
yield stream_result

async_result = _utils.get_event_loop().run_until_complete(anext(_consume_stream()))
return result.StreamedRunResultSync(async_result)
return result.StreamedRunResultSync(_consume_stream())

@overload
def run_stream_events(
Expand Down
96 changes: 77 additions & 19 deletions pydantic_ai_slim/pydantic_ai/result.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations as _annotations

import inspect
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
from contextlib import asynccontextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from datetime import datetime
Expand Down Expand Up @@ -583,10 +585,17 @@ async def _marked_completed(self, message: _messages.ModelResponse | None = None
class StreamedRunResultSync(Generic[AgentDepsT, OutputDataT]):
"""Synchronous wrapper for [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] that only exposes sync methods."""

_streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
_streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT] | None = None

def __init__(self, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]) -> None:
self._streamed_run_result = streamed_run_result
def __init__(
self,
streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
| AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]],
) -> None:
if isinstance(streamed_run_result, StreamedRunResult):
self._streamed_run_result = streamed_run_result
else:
self._stream = streamed_run_result

def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
"""Return the history of messages.
Expand All @@ -600,7 +609,9 @@ def all_messages(self, *, output_tool_return_content: str | None = None) -> list
Returns:
List of messages.
"""
return self._streamed_run_result.all_messages(output_tool_return_content=output_tool_return_content)
return self._async_to_sync(
lambda result: result.all_messages(output_tool_return_content=output_tool_return_content)
)

def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: # pragma: no cover
"""Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResultSync.all_messages] as JSON bytes.
Expand All @@ -614,7 +625,9 @@ def all_messages_json(self, *, output_tool_return_content: str | None = None) ->
Returns:
JSON bytes representing the messages.
"""
return self._streamed_run_result.all_messages_json(output_tool_return_content=output_tool_return_content)
return self._async_to_sync(
lambda result: result.all_messages_json(output_tool_return_content=output_tool_return_content)
)

def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
"""Return new messages associated with this run.
Expand All @@ -630,7 +643,9 @@ def new_messages(self, *, output_tool_return_content: str | None = None) -> list
Returns:
List of new messages.
"""
return self._streamed_run_result.new_messages(output_tool_return_content=output_tool_return_content)
return self._async_to_sync(
lambda result: result.new_messages(output_tool_return_content=output_tool_return_content)
)

def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: # pragma: no cover
"""Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResultSync.new_messages] as JSON bytes.
Expand All @@ -644,7 +659,9 @@ def new_messages_json(self, *, output_tool_return_content: str | None = None) ->
Returns:
JSON bytes representing the new messages.
"""
return self._streamed_run_result.new_messages_json(output_tool_return_content=output_tool_return_content)
return self._async_to_sync(
lambda result: result.new_messages_json(output_tool_return_content=output_tool_return_content)
)

def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
"""Stream the output as an iterable.
Expand All @@ -661,7 +678,7 @@ def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDa
Returns:
An iterable of the response data.
"""
return _utils.sync_async_iterator(self._streamed_run_result.stream_output(debounce_by=debounce_by))
return self._async_iterator_to_sync(lambda result: result.stream_output(debounce_by=debounce_by))

def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
"""Stream the text result as an iterable.
Expand All @@ -676,7 +693,7 @@ def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -
Debouncing is particularly important for long structured responses to reduce the overhead of
performing validation as each token is received.
"""
return _utils.sync_async_iterator(self._streamed_run_result.stream_text(delta=delta, debounce_by=debounce_by))
return self._async_iterator_to_sync(lambda result: result.stream_text(delta=delta, debounce_by=debounce_by))

def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]:
"""Stream the response as an iterable of Structured LLM Messages.
Expand All @@ -689,39 +706,80 @@ def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple
Returns:
An iterable of the structured response message and whether that is the last message.
"""
return _utils.sync_async_iterator(self._streamed_run_result.stream_responses(debounce_by=debounce_by))
return self._async_iterator_to_sync(lambda result: result.stream_responses(debounce_by=debounce_by))

def get_output(self) -> OutputDataT:
"""Stream the whole response, validate and return it."""
return _utils.get_event_loop().run_until_complete(self._streamed_run_result.get_output())
return self._async_to_sync(lambda result: result.get_output())

@asynccontextmanager
async def _with_streamed_run_result(self) -> AsyncIterator[StreamedRunResult[AgentDepsT, OutputDataT]]:
clean_up = False
if self._streamed_run_result is None:
clean_up = True
self._streamed_run_result = await anext(self._stream)

yield self._streamed_run_result

if clean_up:
try:
await anext(self._stream)
except StopAsyncIteration:
pass

def _async_iterator_to_sync(
self,
func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], AsyncIterator[T]],
) -> Iterator[T]:
async def my_task():
try:
async with self._with_streamed_run_result() as result:
async for item in func(result):
yield item
except RuntimeError as e:
if str(e) != 'Attempted to exit cancel scope in a different task than it was entered in':
raise

return _utils.sync_async_iterator(my_task())

def _async_to_sync(
self,
func: Callable[[StreamedRunResult[AgentDepsT, OutputDataT]], Awaitable[T] | T],
) -> T:
async def my_task():
async with self._with_streamed_run_result() as result:
res = func(result)
if inspect.isawaitable(res):
res = cast(T, await res)
return res

return _utils.get_event_loop().run_until_complete(my_task())

@property
def response(self) -> _messages.ModelResponse:
"""Return the current state of the response."""
return self._streamed_run_result.response
return self._async_to_sync(lambda result: result.response)

def usage(self) -> RunUsage:
"""Return the usage of the whole run.

!!! note
This won't return the full usage until the stream is finished.
"""
return self._streamed_run_result.usage()
return self._async_to_sync(lambda result: result.usage())

def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
return self._streamed_run_result.timestamp()
return self._async_to_sync(lambda result: result.timestamp())

@property
def run_id(self) -> str:
"""The unique identifier for the agent run."""
return self._streamed_run_result.run_id
return self._async_to_sync(lambda result: result.run_id)

def validate_response_output(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
"""Validate a structured result message."""
return _utils.get_event_loop().run_until_complete(
self._streamed_run_result.validate_response_output(message, allow_partial=allow_partial)
)
return self._async_to_sync(lambda result: result.validate_response_output(message, allow_partial=allow_partial))

@property
def is_complete(self) -> bool:
Expand All @@ -733,7 +791,7 @@ def is_complete(self) -> bool:
[`stream_responses`][pydantic_ai.result.StreamedRunResultSync.stream_responses] or
[`get_output`][pydantic_ai.result.StreamedRunResultSync.get_output] completes.
"""
return self._streamed_run_result.is_complete
return self._async_to_sync(lambda result: result.is_complete)


@dataclass(repr=False)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ async def ret_a(x: str) -> str:
RunUsage(
requests=2,
input_tokens=103,
output_tokens=5,
output_tokens=11,
tool_calls=1,
)
)
Expand Down
Loading