-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rework agents abstraction
- Loading branch information
Showing
21 changed files
with
470 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# from draive.agents.pool import AgentPool, AgentPoolCoordinator | ||
from draive.agents.abc import BaseAgent | ||
from draive.agents.agent import Agent, agent | ||
from draive.agents.errors import AgentException | ||
from draive.agents.flow import AgentFlow | ||
from draive.agents.state import AgentState | ||
|
||
__all__ = [ | ||
"agent", | ||
"agent", | ||
"Agent", | ||
"Agent", | ||
"AgentException", | ||
"AgentFlow", | ||
"AgentState", | ||
"BaseAgent", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from draive.agents.state import AgentState | ||
from draive.parameters import ParametrizedData | ||
|
||
__all__ = [ | ||
"BaseAgent", | ||
] | ||
|
||
|
||
class BaseAgent[State: ParametrizedData](ABC): | ||
def __init__( | ||
self, | ||
agent_id: str, | ||
name: str, | ||
description: str, | ||
) -> None: | ||
self.agent_id: str = agent_id | ||
self.name: str = name | ||
self.description: str = description | ||
|
||
def __eq__( | ||
self, | ||
other: object, | ||
) -> bool: | ||
if isinstance(other, BaseAgent): | ||
return self.agent_id == other.agent_id | ||
else: | ||
return False | ||
|
||
def __hash__(self) -> int: | ||
return hash(self.agent_id) | ||
|
||
@abstractmethod | ||
async def __call__( | ||
self, | ||
state: AgentState[State], | ||
) -> None: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from collections.abc import Callable | ||
from inspect import isfunction | ||
from typing import final, overload | ||
from uuid import uuid4 | ||
|
||
from draive.agents.abc import BaseAgent | ||
from draive.agents.errors import AgentException | ||
from draive.agents.state import AgentState | ||
from draive.agents.types import AgentInvocation | ||
from draive.helpers import freeze | ||
from draive.metrics import ArgumentsTrace | ||
from draive.parameters import ParametrizedData | ||
from draive.scope import ctx | ||
|
||
__all__ = [ | ||
"agent", | ||
"Agent", | ||
] | ||
|
||
|
||
@final | ||
class Agent[State: ParametrizedData](BaseAgent[State]): | ||
def __init__( | ||
self, | ||
name: str, | ||
description: str, | ||
invoke: AgentInvocation[State], | ||
) -> None: | ||
self.invoke: AgentInvocation[State] = invoke | ||
super().__init__( | ||
agent_id=uuid4().hex, | ||
name=name, | ||
description=description, | ||
) | ||
|
||
freeze(self) | ||
|
||
async def __call__( | ||
self, | ||
state: AgentState[State], | ||
) -> None: | ||
invocation_id: str = uuid4().hex | ||
with ctx.nested( | ||
f"Agent|{self.name}", | ||
metrics=[ | ||
ArgumentsTrace.of( | ||
agent_id=self.agent_id, | ||
invocation_id=invocation_id, | ||
state=state, | ||
) | ||
], | ||
): | ||
try: | ||
return await self.invoke(state) | ||
|
||
except Exception as exc: | ||
raise AgentException( | ||
"Agent invocation %s of %s failed due to an error: %s", | ||
invocation_id, | ||
self.agent_id, | ||
exc, | ||
) from exc | ||
|
||
|
||
@overload | ||
def agent[State: ParametrizedData]( | ||
invoke: AgentInvocation[State], | ||
/, | ||
) -> Agent[State]: ... | ||
|
||
|
||
@overload | ||
def agent[State: ParametrizedData]( | ||
*, | ||
name: str, | ||
description: str | None = None, | ||
) -> Callable[[AgentInvocation[State]], Agent[State]]: ... | ||
|
||
|
||
@overload | ||
def agent[State: ParametrizedData]( | ||
*, | ||
description: str, | ||
) -> Callable[[AgentInvocation[State]], Agent[State]]: ... | ||
|
||
|
||
def agent[State: ParametrizedData]( | ||
invoke: AgentInvocation[State] | None = None, | ||
*, | ||
name: str | None = None, | ||
description: str | None = None, | ||
) -> Callable[[AgentInvocation[State]], Agent[State]] | Agent[State]: | ||
def wrap( | ||
invoke: AgentInvocation[State], | ||
) -> Agent[State]: | ||
assert isfunction(invoke), "Agent has to be defined from function" # nosec: B101 | ||
return Agent[State]( | ||
name=name or invoke.__qualname__, | ||
description=description or "", | ||
invoke=invoke, | ||
) | ||
|
||
if invoke := invoke: | ||
return wrap(invoke=invoke) | ||
else: | ||
return wrap |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
__all__ = [ | ||
"AgentException", | ||
] | ||
|
||
|
||
class AgentException(Exception): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from asyncio import gather | ||
from typing import final | ||
from uuid import uuid4 | ||
|
||
from draive.agents.abc import BaseAgent | ||
from draive.agents.state import AgentState | ||
from draive.helpers import freeze | ||
from draive.parameters import ParametrizedData | ||
|
||
__all__ = [ | ||
"AgentFlow", | ||
] | ||
|
||
|
||
@final | ||
class AgentFlow[State: ParametrizedData](BaseAgent[State]): | ||
def __init__( | ||
self, | ||
*agents: tuple[BaseAgent[State], ...] | BaseAgent[State], | ||
name: str, | ||
description: str, | ||
) -> None: | ||
super().__init__( | ||
agent_id=uuid4().hex, | ||
name=name, | ||
description=description, | ||
) | ||
self.agents: tuple[tuple[BaseAgent[State], ...] | BaseAgent[State], ...] = agents | ||
|
||
freeze(self) | ||
|
||
async def __call__( | ||
self, | ||
state: AgentState[State], | ||
) -> None: | ||
for agent in self.agents: | ||
match agent: | ||
case [*agents]: | ||
await gather( | ||
*[agent(state) for agent in agents], | ||
) | ||
|
||
# case [agent]: | ||
# await agent(state) | ||
|
||
case agent: | ||
await agent(state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from asyncio import Lock | ||
from collections.abc import Callable | ||
from typing import Any | ||
|
||
from draive.parameters import ParametrizedData | ||
from draive.types import MultimodalContent, MultimodalContentItem, merge_multimodal_content | ||
|
||
__all__ = [ | ||
"AgentState", | ||
] | ||
|
||
|
||
class AgentState[State: ParametrizedData]: | ||
def __init__( | ||
self, | ||
initial: State, | ||
scratchpad: MultimodalContent | None = None, | ||
) -> None: | ||
self._lock: Lock = Lock() | ||
self._current: State = initial | ||
self._scratchpad: tuple[MultimodalContentItem, ...] | ||
match scratchpad: | ||
case None: | ||
self._scratchpad = () | ||
case [*items]: | ||
self._scratchpad = tuple(items) | ||
case item: | ||
self._scratchpad = (item,) | ||
|
||
@property | ||
async def current(self) -> State: | ||
async with self._lock: | ||
return self._current | ||
|
||
@property | ||
async def scratchpad(self) -> MultimodalContent: | ||
async with self._lock: | ||
return self._scratchpad | ||
|
||
async def extend_scratchpad( | ||
self, | ||
content: MultimodalContent, | ||
) -> MultimodalContent: | ||
async with self._lock: | ||
self._scratchpad = merge_multimodal_content(self._scratchpad, content) | ||
return self._scratchpad | ||
|
||
async def apply( | ||
self, | ||
patch: Callable[[State], State], | ||
) -> State: | ||
async with self._lock: | ||
self._current = patch(self._current) | ||
return self._current | ||
|
||
# TODO: find a way to generate signature Based on ParametrizedData | ||
async def update( | ||
self, | ||
**kwargs: Any, | ||
) -> State: | ||
async with self._lock: | ||
self._current = self._current.updated(**kwargs) | ||
return self._current |
Oops, something went wrong.