-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
420 additions
and
279 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,15 @@ | ||
# from draive.agents.pool import AgentPool, AgentPoolCoordinator | ||
from draive.agents.abc import BaseAgent | ||
from draive.agents.agent import Agent, agent | ||
from draive.agents.errors import AgentException | ||
from draive.agents.flow import AgentFlow | ||
from draive.agents.state import AgentScratchpad, AgentState | ||
from draive.agents.state import AgentsChat, AgentsData, AgentsDataAccess | ||
from draive.agents.workflow import AgentsWorkflow | ||
|
||
__all__ = [ | ||
"agent", | ||
"agent", | ||
"Agent", | ||
"Agent", | ||
"AgentException", | ||
"AgentFlow", | ||
"AgentState", | ||
"AgentScratchpad", | ||
"BaseAgent", | ||
"AgentsChat", | ||
"AgentsData", | ||
"AgentsDataAccess", | ||
"AgentsWorkflow", | ||
] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,107 +1,95 @@ | ||
from collections.abc import Callable | ||
from inspect import isfunction | ||
from typing import final, overload | ||
from uuid import uuid4 | ||
from typing import Protocol, final | ||
from uuid import UUID, uuid4 | ||
|
||
from draive.agents.abc import BaseAgent | ||
from draive.agents.errors import AgentException | ||
from draive.agents.state import AgentState | ||
from draive.agents.types import AgentInvocation | ||
from draive.helpers import freeze | ||
from draive.metrics import ArgumentsTrace | ||
from draive.agents.state import AgentsChat, AgentsChatMessage, AgentsData | ||
from draive.parameters import ParametrizedData | ||
from draive.scope import ctx | ||
from draive.types import MultimodalContent | ||
|
||
__all__ = [ | ||
"agent", | ||
"Agent", | ||
"AgentInvocation", | ||
] | ||
|
||
|
||
class AgentInvocation[Data: ParametrizedData](Protocol): | ||
async def __call__( | ||
self, | ||
chat: AgentsChat, | ||
data: AgentsData[Data], | ||
) -> MultimodalContent: ... | ||
|
||
|
||
@final | ||
class Agent[State: ParametrizedData](BaseAgent[State]): | ||
class Agent[Data: ParametrizedData]: | ||
def __init__( | ||
self, | ||
name: str, | ||
description: str, | ||
invoke: AgentInvocation[State], | ||
role: str, | ||
capabilities: str, | ||
invocation: AgentInvocation[Data], | ||
) -> None: | ||
self.invoke: AgentInvocation[State] = invoke | ||
super().__init__( | ||
agent_id=uuid4().hex, | ||
name=name, | ||
description=description, | ||
self.role: str = role | ||
self.identifier: UUID = uuid4() | ||
self.capabilities: str = capabilities | ||
self._invocation: AgentInvocation[Data] = invocation | ||
self.description: str = ( | ||
f"{self.role}:\n| ID: {self.identifier}\n| Capabilities: {self.capabilities}" | ||
) | ||
|
||
freeze(self) | ||
def __eq__( | ||
self, | ||
other: object, | ||
) -> bool: | ||
if isinstance(other, Agent): | ||
return self.identifier == other.identifier | ||
else: | ||
return False | ||
|
||
def __hash__(self) -> int: | ||
return hash(self.identifier) | ||
|
||
def __str__(self) -> str: | ||
return self.description | ||
|
||
async def __call__( | ||
self, | ||
state: AgentState[State], | ||
) -> MultimodalContent | None: | ||
invocation_id: str = uuid4().hex | ||
chat: AgentsChat, | ||
data: AgentsData[Data], | ||
) -> AgentsChatMessage: | ||
with ctx.nested( | ||
f"Agent|{self.name}", | ||
metrics=[ | ||
ArgumentsTrace.of( | ||
agent_id=self.agent_id, | ||
invocation_id=invocation_id, | ||
state=state, | ||
) | ||
], | ||
f"Agent|{self.role}|{self.identifier}", | ||
): | ||
try: | ||
return await self.invoke(state) | ||
return AgentsChatMessage( | ||
author=self.role, | ||
content=await self._invocation(chat, data), | ||
) | ||
|
||
except Exception as exc: | ||
raise AgentException( | ||
"Agent invocation %s of %s failed due to an error: %s", | ||
invocation_id, | ||
self.agent_id, | ||
"Agent invocation of %s failed due to an error: %s", | ||
self.identifier, | ||
exc, | ||
) from exc | ||
|
||
|
||
@overload | ||
def agent[State: ParametrizedData]( | ||
invoke: AgentInvocation[State], | ||
/, | ||
) -> Agent[State]: ... | ||
|
||
|
||
@overload | ||
def agent[State: ParametrizedData]( | ||
*, | ||
name: str, | ||
description: str | None = None, | ||
) -> Callable[[AgentInvocation[State]], Agent[State]]: ... | ||
|
||
|
||
@overload | ||
def agent[State: ParametrizedData]( | ||
*, | ||
description: str, | ||
) -> Callable[[AgentInvocation[State]], Agent[State]]: ... | ||
|
||
|
||
def agent[State: ParametrizedData]( | ||
invoke: AgentInvocation[State] | None = None, | ||
def agent[Data: ParametrizedData]( | ||
*, | ||
name: str | None = None, | ||
description: str | None = None, | ||
) -> Callable[[AgentInvocation[State]], Agent[State]] | Agent[State]: | ||
role: str, | ||
capabilities: str, | ||
) -> Callable[[AgentInvocation[Data]], Agent[Data]]: | ||
def wrap( | ||
invoke: AgentInvocation[State], | ||
) -> Agent[State]: | ||
assert isfunction(invoke), "Agent has to be defined from function" # nosec: B101 | ||
return Agent[State]( | ||
name=name or invoke.__qualname__, | ||
description=description or "", | ||
invoke=invoke, | ||
invoke: AgentInvocation[Data], | ||
) -> Agent[Data]: | ||
assert isfunction(invoke), "Agent has to be defined from a function" # nosec: B101 | ||
return Agent[Data]( | ||
role=role, | ||
capabilities=capabilities, | ||
invocation=invoke, | ||
) | ||
|
||
if invoke := invoke: | ||
return wrap(invoke=invoke) | ||
else: | ||
return wrap | ||
return wrap |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from typing import Protocol | ||
from uuid import UUID | ||
|
||
from draive.agents.agent import Agent | ||
from draive.agents.errors import AgentException | ||
from draive.agents.state import AgentsChat, AgentsData | ||
from draive.generation import generate_model | ||
from draive.parameters import Field, ParametrizedData | ||
from draive.scope import ctx | ||
from draive.tools import Toolbox | ||
from draive.types import Model, MultimodalContent, State | ||
|
||
__all__ = [ | ||
"AgentDisposition", | ||
"AgentsCoordinator", | ||
"basic_agent_coordinator", | ||
] | ||
|
||
|
||
class AgentDisposition[Data: ParametrizedData](State): | ||
recipient: Agent[Data] | ||
message: MultimodalContent | ||
|
||
|
||
class AgentsCoordinator[Data: ParametrizedData](Protocol): | ||
async def __call__( | ||
self, | ||
chat: AgentsChat, | ||
data: AgentsData[Data], | ||
agents: frozenset[Agent[Data]], | ||
) -> AgentDisposition[Data]: ... | ||
|
||
|
||
class Disposition(Model): | ||
recipient: UUID = Field( | ||
description="Exact ID of the employee which should take over the task", | ||
) | ||
message: str = Field( | ||
description="Observation about current progress and" | ||
" proposal of task for the chosen employee", | ||
) | ||
|
||
|
||
INSTRUCTION: str = """\ | ||
You are a Coordinator managing work of a group of employees trying to achieve a common goal. | ||
Your task is to verify current progress and propose a next step in order \ | ||
to complete the goal using available employees. Examine the conversation and current progress \ | ||
to prepare the most suitable next step. You have to choose ID associated with given employee \ | ||
to ask that particular employee for continuation but you are fully responsible for the final result. \ | ||
The instructions you propose should be step by step, small and easily achievable tasks to avoid \ | ||
misunderstanding and to allow continuous tracking of the progress until it is fully done. | ||
Available employees: | ||
--- | ||
{agents} | ||
--- | ||
Provide only a single disposition without any additional comments or elements. | ||
""" # noqa: E501 | ||
|
||
|
||
async def basic_agent_coordinator[Data: ParametrizedData]( | ||
chat: AgentsChat, | ||
data: AgentsData[Data], | ||
agents: frozenset[Agent[Data]], | ||
) -> AgentDisposition[Data]: | ||
disposition: Disposition = await generate_model( | ||
Disposition, | ||
instruction=INSTRUCTION.format( | ||
agents="\n---\n".join(agent.description for agent in agents) | ||
), | ||
input=( | ||
f"CONVERSATION:\n{chat.as_str()}", | ||
"PROGRESS:\n---\n" | ||
+ "\n---\n".join( | ||
f"{key}:\n{value}" for key, value in (await data.current_contents).items() | ||
), | ||
), | ||
tools=Toolbox(data.read_tool()), | ||
# TODO: add examples | ||
) | ||
ctx.log_debug("Agent %s disposition: %s", disposition.recipient, disposition.message) | ||
selected_agent: Agent[Data] | ||
for agent in agents: | ||
if agent.identifier == disposition.recipient: | ||
selected_agent = agent | ||
break | ||
else: | ||
continue | ||
else: | ||
raise AgentException("Selected invalid agent") | ||
|
||
return AgentDisposition( | ||
recipient=selected_agent, | ||
message=disposition.message, | ||
) |
Oops, something went wrong.