Skip to content

Commit

Permalink
Add agents abstraction
Browse files Browse the repository at this point in the history
Rework agents abstraction
  • Loading branch information
KaQuMiQ committed Apr 23, 2024
1 parent 82e72b4 commit 9a04cc7
Show file tree
Hide file tree
Showing 21 changed files with 470 additions and 50 deletions.
10 changes: 5 additions & 5 deletions constraints
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o constraints --all-extras
# uv --no-cache pip compile pyproject.toml -o constraints --all-extras
annotated-types==0.6.0
# via pydantic
anyio==4.3.0
Expand Down Expand Up @@ -41,7 +41,7 @@ mistralai==0.1.8
nodeenv==1.8.0
# via pyright
numpy==1.26.4
openai==1.19.0
openai==1.23.1
orjson==3.10.1
# via mistralai
packaging==24.0
Expand All @@ -58,7 +58,7 @@ pydantic-core==2.18.1
# via pydantic
pygments==2.17.2
# via rich
pyright==1.1.358
pyright==1.1.359
pytest==7.4.4
# via
# pytest-asyncio
Expand All @@ -67,13 +67,13 @@ pytest-asyncio==0.23.6
pytest-cov==4.1.0
pyyaml==6.0.1
# via bandit
regex==2023.12.25
regex==2024.4.16
# via tiktoken
requests==2.31.0
# via tiktoken
rich==13.7.1
# via bandit
ruff==0.3.7
ruff==0.4.0
setuptools==69.5.1
# via nodeenv
sniffio==1.3.1
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Repository = "https://github.com/miquido/draive.git"

[project.optional-dependencies]
dev = [
"ruff~=0.3.0",
"ruff~=0.4.0",
"pyright~=1.1",
"bandit~=1.7",
"pytest~=7.4",
Expand Down
22 changes: 22 additions & 0 deletions src/draive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from draive.agents import (
Agent,
AgentException,
AgentFlow,
AgentState,
BaseAgent,
agent,
)
from draive.conversation import (
Conversation,
ConversationCompletion,
Expand Down Expand Up @@ -106,8 +114,11 @@
VideoBase64Content,
VideoContent,
VideoURLContent,
merge_multimodal_content,
)
from draive.utils import (
AsyncStream,
AsyncStreamTask,
allowing_early_exit,
auto_retry,
cache,
Expand All @@ -116,12 +127,22 @@
)

__all__ = [
"agent",
"agent",
"Agent",
"Agent",
"AgentException",
"AgentFlow",
"AgentState",
"allowing_early_exit",
"Argument",
"AsyncStream",
"AsyncStreamTask",
"AudioBase64Content",
"AudioContent",
"AudioURLContent",
"auto_retry",
"BaseAgent",
"cache",
"conversation_completion",
"conversation_completion",
Expand Down Expand Up @@ -162,6 +183,7 @@
"LMMCompletionStreamingUpdate",
"load_env",
"Memory",
"merge_multimodal_content",
"Metric",
"metrics_log_reporter",
"MetricsTrace",
Expand Down
17 changes: 17 additions & 0 deletions src/draive/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# from draive.agents.pool import AgentPool, AgentPoolCoordinator
from draive.agents.abc import BaseAgent
from draive.agents.agent import Agent, agent
from draive.agents.errors import AgentException
from draive.agents.flow import AgentFlow
from draive.agents.state import AgentState

__all__ = [
"agent",
"agent",
"Agent",
"Agent",
"AgentException",
"AgentFlow",
"AgentState",
"BaseAgent",
]
38 changes: 38 additions & 0 deletions src/draive/agents/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod

from draive.agents.state import AgentState
from draive.parameters import ParametrizedData

__all__ = [
"BaseAgent",
]


class BaseAgent[State: ParametrizedData](ABC):
def __init__(
self,
agent_id: str,
name: str,
description: str,
) -> None:
self.agent_id: str = agent_id
self.name: str = name
self.description: str = description

def __eq__(
self,
other: object,
) -> bool:
if isinstance(other, BaseAgent):
return self.agent_id == other.agent_id
else:
return False

def __hash__(self) -> int:
return hash(self.agent_id)

@abstractmethod
async def __call__(
self,
state: AgentState[State],
) -> None: ...
106 changes: 106 additions & 0 deletions src/draive/agents/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from collections.abc import Callable
from inspect import isfunction
from typing import final, overload
from uuid import uuid4

from draive.agents.abc import BaseAgent
from draive.agents.errors import AgentException
from draive.agents.state import AgentState
from draive.agents.types import AgentInvocation
from draive.helpers import freeze
from draive.metrics import ArgumentsTrace
from draive.parameters import ParametrizedData
from draive.scope import ctx

__all__ = [
"agent",
"Agent",
]


@final
class Agent[State: ParametrizedData](BaseAgent[State]):
def __init__(
self,
name: str,
description: str,
invoke: AgentInvocation[State],
) -> None:
self.invoke: AgentInvocation[State] = invoke
super().__init__(
agent_id=uuid4().hex,
name=name,
description=description,
)

freeze(self)

async def __call__(
self,
state: AgentState[State],
) -> None:
invocation_id: str = uuid4().hex
with ctx.nested(
f"Agent|{self.name}",
metrics=[
ArgumentsTrace.of(
agent_id=self.agent_id,
invocation_id=invocation_id,
state=state,
)
],
):
try:
return await self.invoke(state)

except Exception as exc:
raise AgentException(
"Agent invocation %s of %s failed due to an error: %s",
invocation_id,
self.agent_id,
exc,
) from exc


@overload
def agent[State: ParametrizedData](
invoke: AgentInvocation[State],
/,
) -> Agent[State]: ...


@overload
def agent[State: ParametrizedData](
*,
name: str,
description: str | None = None,
) -> Callable[[AgentInvocation[State]], Agent[State]]: ...


@overload
def agent[State: ParametrizedData](
*,
description: str,
) -> Callable[[AgentInvocation[State]], Agent[State]]: ...


def agent[State: ParametrizedData](
invoke: AgentInvocation[State] | None = None,
*,
name: str | None = None,
description: str | None = None,
) -> Callable[[AgentInvocation[State]], Agent[State]] | Agent[State]:
def wrap(
invoke: AgentInvocation[State],
) -> Agent[State]:
assert isfunction(invoke), "Agent has to be defined from function" # nosec: B101
return Agent[State](
name=name or invoke.__qualname__,
description=description or "",
invoke=invoke,
)

if invoke := invoke:
return wrap(invoke=invoke)
else:
return wrap
7 changes: 7 additions & 0 deletions src/draive/agents/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
__all__ = [
"AgentException",
]


class AgentException(Exception):
pass
47 changes: 47 additions & 0 deletions src/draive/agents/flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from asyncio import gather
from typing import final
from uuid import uuid4

from draive.agents.abc import BaseAgent
from draive.agents.state import AgentState
from draive.helpers import freeze
from draive.parameters import ParametrizedData

__all__ = [
"AgentFlow",
]


@final
class AgentFlow[State: ParametrizedData](BaseAgent[State]):
def __init__(
self,
*agents: tuple[BaseAgent[State], ...] | BaseAgent[State],
name: str,
description: str,
) -> None:
super().__init__(
agent_id=uuid4().hex,
name=name,
description=description,
)
self.agents: tuple[tuple[BaseAgent[State], ...] | BaseAgent[State], ...] = agents

freeze(self)

async def __call__(
self,
state: AgentState[State],
) -> None:
for agent in self.agents:
match agent:
case [*agents]:
await gather(
*[agent(state) for agent in agents],
)

# case [agent]:
# await agent(state)

case agent:
await agent(state)
63 changes: 63 additions & 0 deletions src/draive/agents/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from asyncio import Lock
from collections.abc import Callable
from typing import Any

from draive.parameters import ParametrizedData
from draive.types import MultimodalContent, MultimodalContentItem, merge_multimodal_content

__all__ = [
"AgentState",
]


class AgentState[State: ParametrizedData]:
def __init__(
self,
initial: State,
scratchpad: MultimodalContent | None = None,
) -> None:
self._lock: Lock = Lock()
self._current: State = initial
self._scratchpad: tuple[MultimodalContentItem, ...]
match scratchpad:
case None:
self._scratchpad = ()
case [*items]:
self._scratchpad = tuple(items)
case item:
self._scratchpad = (item,)

@property
async def current(self) -> State:
async with self._lock:
return self._current

@property
async def scratchpad(self) -> MultimodalContent:
async with self._lock:
return self._scratchpad

async def extend_scratchpad(
self,
content: MultimodalContent,
) -> MultimodalContent:
async with self._lock:
self._scratchpad = merge_multimodal_content(self._scratchpad, content)
return self._scratchpad

async def apply(
self,
patch: Callable[[State], State],
) -> State:
async with self._lock:
self._current = patch(self._current)
return self._current

# TODO: find a way to generate signature Based on ParametrizedData
async def update(
self,
**kwargs: Any,
) -> State:
async with self._lock:
self._current = self._current.updated(**kwargs)
return self._current
Loading

0 comments on commit 9a04cc7

Please sign in to comment.