Skip to content

Commit

Permalink
Rework agents workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ committed May 6, 2024
1 parent 202ac28 commit 883f49f
Show file tree
Hide file tree
Showing 9 changed files with 420 additions and 279 deletions.
19 changes: 11 additions & 8 deletions src/draive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from draive.agents import (
Agent,
AgentException,
AgentFlow,
AgentScratchpad,
AgentState,
BaseAgent,
AgentsChat,
AgentsData,
AgentsDataAccess,
AgentsWorkflow,
agent,
)
from draive.conversation import (
Expand Down Expand Up @@ -133,9 +133,13 @@
"agent",
"Agent",
"AgentException",
"AgentFlow",
"AgentScratchpad",
"AgentState",
"AgentsChat",
"AgentsData",
"AgentsDataAccess",
"AgentsWorkflow",
"agent",
"Agent",
"AgentException",
"allowing_early_exit",
"Argument",
"AsyncStream",
Expand All @@ -146,7 +150,6 @@
"AudioContent",
"AudioURLContent",
"auto_retry",
"BaseAgent",
"cache",
"conversation_completion",
"conversation_completion",
Expand Down
15 changes: 6 additions & 9 deletions src/draive/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
# from draive.agents.pool import AgentPool, AgentPoolCoordinator
from draive.agents.abc import BaseAgent
from draive.agents.agent import Agent, agent
from draive.agents.errors import AgentException
from draive.agents.flow import AgentFlow
from draive.agents.state import AgentScratchpad, AgentState
from draive.agents.state import AgentsChat, AgentsData, AgentsDataAccess
from draive.agents.workflow import AgentsWorkflow

__all__ = [
"agent",
"agent",
"Agent",
"Agent",
"AgentException",
"AgentFlow",
"AgentState",
"AgentScratchpad",
"BaseAgent",
"AgentsChat",
"AgentsData",
"AgentsDataAccess",
"AgentsWorkflow",
]
52 changes: 0 additions & 52 deletions src/draive/agents/abc.py

This file was deleted.

128 changes: 58 additions & 70 deletions src/draive/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,95 @@
from collections.abc import Callable
from inspect import isfunction
from typing import final, overload
from uuid import uuid4
from typing import Protocol, final
from uuid import UUID, uuid4

from draive.agents.abc import BaseAgent
from draive.agents.errors import AgentException
from draive.agents.state import AgentState
from draive.agents.types import AgentInvocation
from draive.helpers import freeze
from draive.metrics import ArgumentsTrace
from draive.agents.state import AgentsChat, AgentsChatMessage, AgentsData
from draive.parameters import ParametrizedData
from draive.scope import ctx
from draive.types import MultimodalContent

__all__ = [
"agent",
"Agent",
"AgentInvocation",
]


class AgentInvocation[Data: ParametrizedData](Protocol):
async def __call__(
self,
chat: AgentsChat,
data: AgentsData[Data],
) -> MultimodalContent: ...


@final
class Agent[State: ParametrizedData](BaseAgent[State]):
class Agent[Data: ParametrizedData]:
def __init__(
self,
name: str,
description: str,
invoke: AgentInvocation[State],
role: str,
capabilities: str,
invocation: AgentInvocation[Data],
) -> None:
self.invoke: AgentInvocation[State] = invoke
super().__init__(
agent_id=uuid4().hex,
name=name,
description=description,
self.role: str = role
self.identifier: UUID = uuid4()
self.capabilities: str = capabilities
self._invocation: AgentInvocation[Data] = invocation
self.description: str = (
f"{self.role}:\n| ID: {self.identifier}\n| Capabilities: {self.capabilities}"
)

freeze(self)
def __eq__(
self,
other: object,
) -> bool:
if isinstance(other, Agent):
return self.identifier == other.identifier
else:
return False

def __hash__(self) -> int:
return hash(self.identifier)

def __str__(self) -> str:
return self.description

async def __call__(
self,
state: AgentState[State],
) -> MultimodalContent | None:
invocation_id: str = uuid4().hex
chat: AgentsChat,
data: AgentsData[Data],
) -> AgentsChatMessage:
with ctx.nested(
f"Agent|{self.name}",
metrics=[
ArgumentsTrace.of(
agent_id=self.agent_id,
invocation_id=invocation_id,
state=state,
)
],
f"Agent|{self.role}|{self.identifier}",
):
try:
return await self.invoke(state)
return AgentsChatMessage(
author=self.role,
content=await self._invocation(chat, data),
)

except Exception as exc:
raise AgentException(
"Agent invocation %s of %s failed due to an error: %s",
invocation_id,
self.agent_id,
"Agent invocation of %s failed due to an error: %s",
self.identifier,
exc,
) from exc


@overload
def agent[State: ParametrizedData](
invoke: AgentInvocation[State],
/,
) -> Agent[State]: ...


@overload
def agent[State: ParametrizedData](
*,
name: str,
description: str | None = None,
) -> Callable[[AgentInvocation[State]], Agent[State]]: ...


@overload
def agent[State: ParametrizedData](
*,
description: str,
) -> Callable[[AgentInvocation[State]], Agent[State]]: ...


def agent[State: ParametrizedData](
invoke: AgentInvocation[State] | None = None,
def agent[Data: ParametrizedData](
*,
name: str | None = None,
description: str | None = None,
) -> Callable[[AgentInvocation[State]], Agent[State]] | Agent[State]:
role: str,
capabilities: str,
) -> Callable[[AgentInvocation[Data]], Agent[Data]]:
def wrap(
invoke: AgentInvocation[State],
) -> Agent[State]:
assert isfunction(invoke), "Agent has to be defined from function" # nosec: B101
return Agent[State](
name=name or invoke.__qualname__,
description=description or "",
invoke=invoke,
invoke: AgentInvocation[Data],
) -> Agent[Data]:
assert isfunction(invoke), "Agent has to be defined from a function" # nosec: B101
return Agent[Data](
role=role,
capabilities=capabilities,
invocation=invoke,
)

if invoke := invoke:
return wrap(invoke=invoke)
else:
return wrap
return wrap
96 changes: 96 additions & 0 deletions src/draive/agents/coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Protocol
from uuid import UUID

from draive.agents.agent import Agent
from draive.agents.errors import AgentException
from draive.agents.state import AgentsChat, AgentsData
from draive.generation import generate_model
from draive.parameters import Field, ParametrizedData
from draive.scope import ctx
from draive.tools import Toolbox
from draive.types import Model, MultimodalContent, State

__all__ = [
"AgentDisposition",
"AgentsCoordinator",
"basic_agent_coordinator",
]


class AgentDisposition[Data: ParametrizedData](State):
recipient: Agent[Data]
message: MultimodalContent


class AgentsCoordinator[Data: ParametrizedData](Protocol):
async def __call__(
self,
chat: AgentsChat,
data: AgentsData[Data],
agents: frozenset[Agent[Data]],
) -> AgentDisposition[Data]: ...


class Disposition(Model):
recipient: UUID = Field(
description="Exact ID of the employee which should take over the task",
)
message: str = Field(
description="Observation about current progress and"
" proposal of task for the chosen employee",
)


INSTRUCTION: str = """\
You are a Coordinator managing work of a group of employees trying to achieve a common goal.
Your task is to verify current progress and propose a next step in order \
to complete the goal using available employees. Examine the conversation and current progress \
to prepare the most suitable next step. You have to choose ID associated with given employee \
to ask that particular employee for continuation but you are fully responsible for the final result. \
The instructions you propose should be step by step, small and easily achievable tasks to avoid \
misunderstanding and to allow continuous tracking of the progress until it is fully done.
Available employees:
---
{agents}
---
Provide only a single disposition without any additional comments or elements.
""" # noqa: E501


async def basic_agent_coordinator[Data: ParametrizedData](
chat: AgentsChat,
data: AgentsData[Data],
agents: frozenset[Agent[Data]],
) -> AgentDisposition[Data]:
disposition: Disposition = await generate_model(
Disposition,
instruction=INSTRUCTION.format(
agents="\n---\n".join(agent.description for agent in agents)
),
input=(
f"CONVERSATION:\n{chat.as_str()}",
"PROGRESS:\n---\n"
+ "\n---\n".join(
f"{key}:\n{value}" for key, value in (await data.current_contents).items()
),
),
tools=Toolbox(data.read_tool()),
# TODO: add examples
)
ctx.log_debug("Agent %s disposition: %s", disposition.recipient, disposition.message)
selected_agent: Agent[Data]
for agent in agents:
if agent.identifier == disposition.recipient:
selected_agent = agent
break
else:
continue
else:
raise AgentException("Selected invalid agent")

return AgentDisposition(
recipient=selected_agent,
message=disposition.message,
)
Loading

0 comments on commit 883f49f

Please sign in to comment.