Skip to content

Commit

Permalink
Add agents abstraction
Browse files Browse the repository at this point in the history
Rework agents abstraction
  • Loading branch information
KaQuMiQ authored Apr 24, 2024
1 parent ff2dee2 commit 36e868f
Show file tree
Hide file tree
Showing 18 changed files with 461 additions and 43 deletions.
20 changes: 20 additions & 0 deletions src/draive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from draive.agents import (
Agent,
AgentException,
AgentFlow,
AgentState,
BaseAgent,
agent,
)
from draive.conversation import (
Conversation,
ConversationCompletion,
Expand Down Expand Up @@ -105,6 +113,7 @@
VideoBase64Content,
VideoContent,
VideoURLContent,
merge_multimodal_content,
)
from draive.utils import (
AsyncStream,
Expand All @@ -119,12 +128,22 @@
__all__ = [
"AsyncStream",
"AsyncStreamTask",
"agent",
"agent",
"Agent",
"Agent",
"AgentException",
"AgentFlow",
"AgentState",
"allowing_early_exit",
"Argument",
"AsyncStream",
"AsyncStreamTask",
"AudioBase64Content",
"AudioContent",
"AudioURLContent",
"auto_retry",
"BaseAgent",
"cache",
"conversation_completion",
"conversation_completion",
Expand Down Expand Up @@ -165,6 +184,7 @@
"LMMCompletionStreamingUpdate",
"load_env",
"Memory",
"merge_multimodal_content",
"Metric",
"metrics_log_reporter",
"MetricsTrace",
Expand Down
17 changes: 17 additions & 0 deletions src/draive/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# from draive.agents.pool import AgentPool, AgentPoolCoordinator
from draive.agents.abc import BaseAgent
from draive.agents.agent import Agent, agent
from draive.agents.errors import AgentException
from draive.agents.flow import AgentFlow
from draive.agents.state import AgentState

__all__ = [
"agent",
"agent",
"Agent",
"Agent",
"AgentException",
"AgentFlow",
"AgentState",
"BaseAgent",
]
38 changes: 38 additions & 0 deletions src/draive/agents/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod

from draive.agents.state import AgentState
from draive.parameters import ParametrizedData

__all__ = [
"BaseAgent",
]


class BaseAgent[State: ParametrizedData](ABC):
def __init__(
self,
agent_id: str,
name: str,
description: str,
) -> None:
self.agent_id: str = agent_id
self.name: str = name
self.description: str = description

def __eq__(
self,
other: object,
) -> bool:
if isinstance(other, BaseAgent):
return self.agent_id == other.agent_id
else:
return False

def __hash__(self) -> int:
return hash(self.agent_id)

@abstractmethod
async def __call__(
self,
state: AgentState[State],
) -> None: ...
106 changes: 106 additions & 0 deletions src/draive/agents/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from collections.abc import Callable
from inspect import isfunction
from typing import final, overload
from uuid import uuid4

from draive.agents.abc import BaseAgent
from draive.agents.errors import AgentException
from draive.agents.state import AgentState
from draive.agents.types import AgentInvocation
from draive.helpers import freeze
from draive.metrics import ArgumentsTrace
from draive.parameters import ParametrizedData
from draive.scope import ctx

__all__ = [
"agent",
"Agent",
]


@final
class Agent[State: ParametrizedData](BaseAgent[State]):
def __init__(
self,
name: str,
description: str,
invoke: AgentInvocation[State],
) -> None:
self.invoke: AgentInvocation[State] = invoke
super().__init__(
agent_id=uuid4().hex,
name=name,
description=description,
)

freeze(self)

async def __call__(
self,
state: AgentState[State],
) -> None:
invocation_id: str = uuid4().hex
with ctx.nested(
f"Agent|{self.name}",
metrics=[
ArgumentsTrace.of(
agent_id=self.agent_id,
invocation_id=invocation_id,
state=state,
)
],
):
try:
return await self.invoke(state)

except Exception as exc:
raise AgentException(
"Agent invocation %s of %s failed due to an error: %s",
invocation_id,
self.agent_id,
exc,
) from exc


@overload
def agent[State: ParametrizedData](
invoke: AgentInvocation[State],
/,
) -> Agent[State]: ...


@overload
def agent[State: ParametrizedData](
*,
name: str,
description: str | None = None,
) -> Callable[[AgentInvocation[State]], Agent[State]]: ...


@overload
def agent[State: ParametrizedData](
*,
description: str,
) -> Callable[[AgentInvocation[State]], Agent[State]]: ...


def agent[State: ParametrizedData](
invoke: AgentInvocation[State] | None = None,
*,
name: str | None = None,
description: str | None = None,
) -> Callable[[AgentInvocation[State]], Agent[State]] | Agent[State]:
def wrap(
invoke: AgentInvocation[State],
) -> Agent[State]:
assert isfunction(invoke), "Agent has to be defined from function" # nosec: B101
return Agent[State](
name=name or invoke.__qualname__,
description=description or "",
invoke=invoke,
)

if invoke := invoke:
return wrap(invoke=invoke)
else:
return wrap
7 changes: 7 additions & 0 deletions src/draive/agents/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
__all__ = [
"AgentException",
]


class AgentException(Exception):
pass
47 changes: 47 additions & 0 deletions src/draive/agents/flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from asyncio import gather
from typing import final
from uuid import uuid4

from draive.agents.abc import BaseAgent
from draive.agents.state import AgentState
from draive.helpers import freeze
from draive.parameters import ParametrizedData

__all__ = [
"AgentFlow",
]


@final
class AgentFlow[State: ParametrizedData](BaseAgent[State]):
def __init__(
self,
*agents: tuple[BaseAgent[State], ...] | BaseAgent[State],
name: str,
description: str,
) -> None:
super().__init__(
agent_id=uuid4().hex,
name=name,
description=description,
)
self.agents: tuple[tuple[BaseAgent[State], ...] | BaseAgent[State], ...] = agents

freeze(self)

async def __call__(
self,
state: AgentState[State],
) -> None:
for agent in self.agents:
match agent:
case [*agents]:
await gather(
*[agent(state) for agent in agents],
)

# case [agent]:
# await agent(state)

case agent:
await agent(state)
63 changes: 63 additions & 0 deletions src/draive/agents/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from asyncio import Lock
from collections.abc import Callable
from typing import Any

from draive.parameters import ParametrizedData
from draive.types import MultimodalContent, MultimodalContentItem, merge_multimodal_content

__all__ = [
"AgentState",
]


class AgentState[State: ParametrizedData]:
def __init__(
self,
initial: State,
scratchpad: MultimodalContent | None = None,
) -> None:
self._lock: Lock = Lock()
self._current: State = initial
self._scratchpad: tuple[MultimodalContentItem, ...]
match scratchpad:
case None:
self._scratchpad = ()
case [*items]:
self._scratchpad = tuple(items)
case item:
self._scratchpad = (item,)

@property
async def current(self) -> State:
async with self._lock:
return self._current

@property
async def scratchpad(self) -> MultimodalContent:
async with self._lock:
return self._scratchpad

async def extend_scratchpad(
self,
content: MultimodalContent,
) -> MultimodalContent:
async with self._lock:
self._scratchpad = merge_multimodal_content(self._scratchpad, content)
return self._scratchpad

async def apply(
self,
patch: Callable[[State], State],
) -> State:
async with self._lock:
self._current = patch(self._current)
return self._current

# TODO: find a way to generate signature Based on ParametrizedData
async def update(
self,
**kwargs: Any,
) -> State:
async with self._lock:
self._current = self._current.updated(**kwargs)
return self._current
15 changes: 15 additions & 0 deletions src/draive/agents/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Protocol

from draive.agents.state import AgentState
from draive.parameters import ParametrizedData

__all__ = [
"AgentInvocation",
]


class AgentInvocation[State: ParametrizedData](Protocol):
async def __call__(
self,
state: AgentState[State],
) -> None: ...
2 changes: 1 addition & 1 deletion src/draive/metrics/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __add__(self, other: Self) -> Self:
name="ExceptionGroup",
exception=BaseExceptionGroup(
"Multiple errors",
(*self.exception.exceptions, other.exception), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
(*self.exception.exceptions, other.exception), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType]
),
)
else:
Expand Down
Loading

0 comments on commit 36e868f

Please sign in to comment.