Skip to content

Commit

Permalink
Changes for async mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
BryanFauble committed Jan 14, 2025
1 parent 15db364 commit fa888ff
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 162 deletions.
130 changes: 86 additions & 44 deletions synapseclient/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from synapseclient.core.async_utils import async_to_sync, otel_trace_method
from synapseclient.core.constants.concrete_types import AGENT_CHAT_REQUEST
from synapseclient.models.mixins.asynchronous_job import AsynchronousJob
from synapseclient.models.mixins.asynchronous_job import AsynchronousCommunicator
from synapseclient.models.protocols.agent_protocol import (
AgentSessionSynchronousProtocol,
AgentSynchronousProtocol,
Expand Down Expand Up @@ -50,13 +50,15 @@ class AgentSessionAccessLevel(str, Enum):


@dataclass
class AgentPrompt:
class AgentPrompt(AsynchronousCommunicator):
"""Represents a prompt, response, and metadata within an AgentSession.
Attributes:
id: The unique ID of the agent prompt.
session_id: The ID of the session that the prompt is associated with.
prompt: The prompt to send to the agent.
response: The response from the agent.
enable_trace: Whether tracing is enabled for the prompt.
trace: The trace of the agent session.
"""

Expand All @@ -65,20 +67,67 @@ class AgentPrompt:
id: Optional[str] = None
"""The unique ID of the agent prompt."""

session_id: Optional[str] = None
"""The ID of the session that the prompt is associated with."""

prompt: Optional[str] = None
"""The prompt sent to the agent."""

response: Optional[str] = None
"""The response from the agent."""

enable_trace: Optional[bool] = False
"""Whether tracing is enabled for the prompt."""

trace: Optional[str] = None
"""The trace or "thought process" of the agent when responding to the prompt."""

def to_synapse_request(self):
"""Converts the request to a request expected of the Synapse REST API."""
return {
"concreteType": self.concrete_type,
"sessionId": self.session_id,
"chatText": self.prompt,
"enableTrace": self.enable_trace,
}

def fill_from_dict(self, synapse_response: Dict[str, str]) -> "AgentPrompt":
"""
Converts a response from the REST API into this dataclass.
Arguments:
agent_prompt: The response from the REST API.
Returns:
The AgentPrompt object.
"""
self.id = synapse_response.get("sessionId", None)
self.response = synapse_response.get("responseText", None)
return self

async def _post_exchange_async(
self, *, synapse_client: Optional[Synapse] = None, **kwargs
) -> None:
"""Retrieves information about the trace of this prompt with the agent.
Arguments:
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
if self.enable_trace:
trace_response = await get_trace(
prompt_id=self.id,
newer_than=kwargs.get("newer_than", None),
synapse_client=synapse_client,
)
self.trace = trace_response["page"][0]["message"]


# TODO Add example usage to the docstring
@dataclass
@async_to_sync
class AgentSession(AgentSessionSynchronousProtocol, AsynchronousJob):
class AgentSession(AgentSessionSynchronousProtocol):
"""Represents a [Synapse Agent Session](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentSession.html)
Attributes:
Expand Down Expand Up @@ -150,7 +199,9 @@ async def start_async(
"""Starts an agent session.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The new AgentSession object.
Expand All @@ -171,7 +222,9 @@ async def get_async(
"""Gets an agent session.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The retrieved AgentSession object.
Expand All @@ -194,7 +247,9 @@ async def update_async(
Only updates to the access level are currently supported.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The updated AgentSession object.
Expand Down Expand Up @@ -223,45 +278,24 @@ async def prompt_async(
enable_trace: Whether to enable trace for the prompt.
print_response: Whether to print the response to the console.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
prompt_id = await self.send_job_async(
request_type=AGENT_CHAT_REQUEST,
session_id=self.id,
prompt=prompt,
enable_trace=enable_trace,
synapse_client=synapse_client,
)

answer_response = await self.get_job_async(
job_id=prompt_id,
request_type=AGENT_CHAT_REQUEST,
synapse_client=synapse_client,
agent_prompt = AgentPrompt(
prompt=prompt, session_id=self.id, enable_trace=enable_trace
)
response = answer_response["responseText"]

if enable_trace:
trace_response = await get_trace(
prompt_id=prompt_id,
newer_than=newer_than,
synapse_client=synapse_client,
)
trace = trace_response["page"][0]["message"]

self.chat_history.append(
AgentPrompt(
id=prompt_id,
prompt=prompt,
response=response,
trace=trace,
)
await agent_prompt.send_job_and_wait_async(
synapse_client=synapse_client, post_exchange_args={"newer_than": newer_than}
)
self.chat_history.append(agent_prompt)

if print_response:
print(f"PROMPT:\n{prompt}\n")
print(f"RESPONSE:\n{response}\n")
print(f"RESPONSE:\n{agent_prompt.response}\n")
if enable_trace:
print(f"TRACE:\n{trace}")
print(f"TRACE:\n{agent_prompt.trace}")


# TODO Add example usage to the docstring
Expand Down Expand Up @@ -328,7 +362,9 @@ async def register_async(
"""Registers an agent with the Synapse API. If agent exists, it will be retrieved.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The registered or existing Agent object.
Expand All @@ -348,7 +384,9 @@ async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> "Agent
"""Gets an existing agent.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The existing Agent object.
Expand Down Expand Up @@ -378,8 +416,9 @@ async def start_session_async(
access_level: The access level of the agent session.
Must be one of PUBLICLY_ACCESSIBLE, READ_YOUR_PRIVATE_DATA, or WRITE_YOUR_PRIVATE_DATA.
Defaults to PUBLICLY_ACCESSIBLE.
synapse_client: The Synapse client to use for the request.
If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The new AgentSession object.
Expand All @@ -403,8 +442,9 @@ async def get_session_async(
Arguments:
session_id: The ID of the session to get.
synapse_client: The Synapse client to use for the request.
If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
Returns:
The existing AgentSession object.
Expand Down Expand Up @@ -439,7 +479,9 @@ async def prompt_async(
print_response: Whether to print the response to the console.
session_id: The ID of the session to send the prompt to. If None, the current session will be used.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
# TODO: Iron this out. Make sure we cover all cases.
if session:
Expand Down
Loading

0 comments on commit fa888ff

Please sign in to comment.