From 9a04cc795d376384314c6ff45b28a439d4ad1f04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Kali=C5=84ski?= Date: Thu, 18 Apr 2024 10:44:45 +0200 Subject: [PATCH] Add agents abstraction Rework agents abstraction --- constraints | 10 +-- pyproject.toml | 2 +- src/draive/__init__.py | 22 +++++ src/draive/agents/__init__.py | 17 ++++ src/draive/agents/abc.py | 38 +++++++++ src/draive/agents/agent.py | 106 +++++++++++++++++++++++++ src/draive/agents/errors.py | 7 ++ src/draive/agents/flow.py | 47 +++++++++++ src/draive/agents/state.py | 63 +++++++++++++++ src/draive/agents/types.py | 15 ++++ src/draive/metrics/function.py | 2 +- src/draive/parameters/specification.py | 26 +++++- src/draive/parameters/validation.py | 98 +++++++++++++++++------ src/draive/types/__init__.py | 8 +- src/draive/types/audio.py | 3 + src/draive/types/images.py | 3 + src/draive/types/model.py | 12 +-- src/draive/types/multimodal.py | 18 ++++- src/draive/types/video.py | 3 + src/draive/utils/stream.py | 2 +- tests/test_model.py | 18 +++-- 21 files changed, 470 insertions(+), 50 deletions(-) create mode 100644 src/draive/agents/__init__.py create mode 100644 src/draive/agents/abc.py create mode 100644 src/draive/agents/agent.py create mode 100644 src/draive/agents/errors.py create mode 100644 src/draive/agents/flow.py create mode 100644 src/draive/agents/state.py create mode 100644 src/draive/agents/types.py diff --git a/constraints b/constraints index 7ace38a..846e569 100644 --- a/constraints +++ b/constraints @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile pyproject.toml -o constraints --all-extras +# uv --no-cache pip compile pyproject.toml -o constraints --all-extras annotated-types==0.6.0 # via pydantic anyio==4.3.0 @@ -41,7 +41,7 @@ mistralai==0.1.8 nodeenv==1.8.0 # via pyright numpy==1.26.4 -openai==1.19.0 +openai==1.23.1 orjson==3.10.1 # via mistralai packaging==24.0 @@ -58,7 +58,7 @@ pydantic-core==2.18.1 # via pydantic pygments==2.17.2 # via rich -pyright==1.1.358 +pyright==1.1.359 pytest==7.4.4 # via # pytest-asyncio @@ -67,13 +67,13 @@ pytest-asyncio==0.23.6 pytest-cov==4.1.0 pyyaml==6.0.1 # via bandit -regex==2023.12.25 +regex==2024.4.16 # via tiktoken requests==2.31.0 # via tiktoken rich==13.7.1 # via bandit -ruff==0.3.7 +ruff==0.4.0 setuptools==69.5.1 # via nodeenv sniffio==1.3.1 diff --git a/pyproject.toml b/pyproject.toml index d4afd3c..d2cb5cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ Repository = "https://github.com/miquido/draive.git" [project.optional-dependencies] dev = [ - "ruff~=0.3.0", + "ruff~=0.4.0", "pyright~=1.1", "bandit~=1.7", "pytest~=7.4", diff --git a/src/draive/__init__.py b/src/draive/__init__.py index 6237dea..f717c90 100644 --- a/src/draive/__init__.py +++ b/src/draive/__init__.py @@ -1,3 +1,11 @@ +from draive.agents import ( + Agent, + AgentException, + AgentFlow, + AgentState, + BaseAgent, + agent, +) from draive.conversation import ( Conversation, ConversationCompletion, @@ -106,8 +114,11 @@ VideoBase64Content, VideoContent, VideoURLContent, + merge_multimodal_content, ) from draive.utils import ( + AsyncStream, + AsyncStreamTask, allowing_early_exit, auto_retry, cache, @@ -116,12 +127,22 @@ ) __all__ = [ + "agent", + "agent", + "Agent", + "Agent", + "AgentException", + "AgentFlow", + "AgentState", "allowing_early_exit", "Argument", + "AsyncStream", + "AsyncStreamTask", "AudioBase64Content", "AudioContent", "AudioURLContent", "auto_retry", + "BaseAgent", "cache", "conversation_completion", "conversation_completion", @@ -162,6 +183,7 @@ "LMMCompletionStreamingUpdate", "load_env", "Memory", + "merge_multimodal_content", "Metric", "metrics_log_reporter", "MetricsTrace", diff --git a/src/draive/agents/__init__.py b/src/draive/agents/__init__.py new file mode 100644 index 0000000..86e5842 --- /dev/null +++ b/src/draive/agents/__init__.py @@ -0,0 +1,17 @@ +# from draive.agents.pool import AgentPool, AgentPoolCoordinator +from draive.agents.abc import BaseAgent +from draive.agents.agent import Agent, agent +from draive.agents.errors import AgentException +from draive.agents.flow import AgentFlow +from draive.agents.state import AgentState + +__all__ = [ + "agent", + "agent", + "Agent", + "Agent", + "AgentException", + "AgentFlow", + "AgentState", + "BaseAgent", +] diff --git a/src/draive/agents/abc.py b/src/draive/agents/abc.py new file mode 100644 index 0000000..30b9062 --- /dev/null +++ b/src/draive/agents/abc.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod + +from draive.agents.state import AgentState +from draive.parameters import ParametrizedData + +__all__ = [ + "BaseAgent", +] + + +class BaseAgent[State: ParametrizedData](ABC): + def __init__( + self, + agent_id: str, + name: str, + description: str, + ) -> None: + self.agent_id: str = agent_id + self.name: str = name + self.description: str = description + + def __eq__( + self, + other: object, + ) -> bool: + if isinstance(other, BaseAgent): + return self.agent_id == other.agent_id + else: + return False + + def __hash__(self) -> int: + return hash(self.agent_id) + + @abstractmethod + async def __call__( + self, + state: AgentState[State], + ) -> None: ... diff --git a/src/draive/agents/agent.py b/src/draive/agents/agent.py new file mode 100644 index 0000000..4a28fe9 --- /dev/null +++ b/src/draive/agents/agent.py @@ -0,0 +1,106 @@ +from collections.abc import Callable +from inspect import isfunction +from typing import final, overload +from uuid import uuid4 + +from draive.agents.abc import BaseAgent +from draive.agents.errors import AgentException +from draive.agents.state import AgentState +from draive.agents.types import AgentInvocation +from draive.helpers import freeze +from draive.metrics import ArgumentsTrace +from draive.parameters import ParametrizedData +from draive.scope import ctx + +__all__ = [ + "agent", + "Agent", +] + + +@final +class Agent[State: ParametrizedData](BaseAgent[State]): + def __init__( + self, + name: str, + description: str, + invoke: AgentInvocation[State], + ) -> None: + self.invoke: AgentInvocation[State] = invoke + super().__init__( + agent_id=uuid4().hex, + name=name, + description=description, + ) + + freeze(self) + + async def __call__( + self, + state: AgentState[State], + ) -> None: + invocation_id: str = uuid4().hex + with ctx.nested( + f"Agent|{self.name}", + metrics=[ + ArgumentsTrace.of( + agent_id=self.agent_id, + invocation_id=invocation_id, + state=state, + ) + ], + ): + try: + return await self.invoke(state) + + except Exception as exc: + raise AgentException( + "Agent invocation %s of %s failed due to an error: %s", + invocation_id, + self.agent_id, + exc, + ) from exc + + +@overload +def agent[State: ParametrizedData]( + invoke: AgentInvocation[State], + /, +) -> Agent[State]: ... + + +@overload +def agent[State: ParametrizedData]( + *, + name: str, + description: str | None = None, +) -> Callable[[AgentInvocation[State]], Agent[State]]: ... + + +@overload +def agent[State: ParametrizedData]( + *, + description: str, +) -> Callable[[AgentInvocation[State]], Agent[State]]: ... + + +def agent[State: ParametrizedData]( + invoke: AgentInvocation[State] | None = None, + *, + name: str | None = None, + description: str | None = None, +) -> Callable[[AgentInvocation[State]], Agent[State]] | Agent[State]: + def wrap( + invoke: AgentInvocation[State], + ) -> Agent[State]: + assert isfunction(invoke), "Agent has to be defined from function" # nosec: B101 + return Agent[State]( + name=name or invoke.__qualname__, + description=description or "", + invoke=invoke, + ) + + if invoke := invoke: + return wrap(invoke=invoke) + else: + return wrap diff --git a/src/draive/agents/errors.py b/src/draive/agents/errors.py new file mode 100644 index 0000000..bf78bf6 --- /dev/null +++ b/src/draive/agents/errors.py @@ -0,0 +1,7 @@ +__all__ = [ + "AgentException", +] + + +class AgentException(Exception): + pass diff --git a/src/draive/agents/flow.py b/src/draive/agents/flow.py new file mode 100644 index 0000000..2415cc0 --- /dev/null +++ b/src/draive/agents/flow.py @@ -0,0 +1,47 @@ +from asyncio import gather +from typing import final +from uuid import uuid4 + +from draive.agents.abc import BaseAgent +from draive.agents.state import AgentState +from draive.helpers import freeze +from draive.parameters import ParametrizedData + +__all__ = [ + "AgentFlow", +] + + +@final +class AgentFlow[State: ParametrizedData](BaseAgent[State]): + def __init__( + self, + *agents: tuple[BaseAgent[State], ...] | BaseAgent[State], + name: str, + description: str, + ) -> None: + super().__init__( + agent_id=uuid4().hex, + name=name, + description=description, + ) + self.agents: tuple[tuple[BaseAgent[State], ...] | BaseAgent[State], ...] = agents + + freeze(self) + + async def __call__( + self, + state: AgentState[State], + ) -> None: + for agent in self.agents: + match agent: + case [*agents]: + await gather( + *[agent(state) for agent in agents], + ) + + # case [agent]: + # await agent(state) + + case agent: + await agent(state) diff --git a/src/draive/agents/state.py b/src/draive/agents/state.py new file mode 100644 index 0000000..c83fa80 --- /dev/null +++ b/src/draive/agents/state.py @@ -0,0 +1,63 @@ +from asyncio import Lock +from collections.abc import Callable +from typing import Any + +from draive.parameters import ParametrizedData +from draive.types import MultimodalContent, MultimodalContentItem, merge_multimodal_content + +__all__ = [ + "AgentState", +] + + +class AgentState[State: ParametrizedData]: + def __init__( + self, + initial: State, + scratchpad: MultimodalContent | None = None, + ) -> None: + self._lock: Lock = Lock() + self._current: State = initial + self._scratchpad: tuple[MultimodalContentItem, ...] + match scratchpad: + case None: + self._scratchpad = () + case [*items]: + self._scratchpad = tuple(items) + case item: + self._scratchpad = (item,) + + @property + async def current(self) -> State: + async with self._lock: + return self._current + + @property + async def scratchpad(self) -> MultimodalContent: + async with self._lock: + return self._scratchpad + + async def extend_scratchpad( + self, + content: MultimodalContent, + ) -> MultimodalContent: + async with self._lock: + self._scratchpad = merge_multimodal_content(self._scratchpad, content) + return self._scratchpad + + async def apply( + self, + patch: Callable[[State], State], + ) -> State: + async with self._lock: + self._current = patch(self._current) + return self._current + + # TODO: find a way to generate signature Based on ParametrizedData + async def update( + self, + **kwargs: Any, + ) -> State: + async with self._lock: + self._current = self._current.updated(**kwargs) + return self._current diff --git a/src/draive/agents/types.py b/src/draive/agents/types.py new file mode 100644 index 0000000..f1c81c8 --- /dev/null +++ b/src/draive/agents/types.py @@ -0,0 +1,15 @@ +from typing import Protocol + +from draive.agents.state import AgentState +from draive.parameters import ParametrizedData + +__all__ = [ + "AgentInvocation", +] + + +class AgentInvocation[State: ParametrizedData](Protocol): + async def __call__( + self, + state: AgentState[State], + ) -> None: ... diff --git a/src/draive/metrics/function.py b/src/draive/metrics/function.py index a81aab0..56cd55d 100644 --- a/src/draive/metrics/function.py +++ b/src/draive/metrics/function.py @@ -68,7 +68,7 @@ def __add__(self, other: Self) -> Self: name="ExceptionGroup", exception=BaseExceptionGroup( "Multiple errors", - (*self.exception.exceptions, other.exception), # pyright: ignore[reportUnknownMemberType] + (*self.exception.exceptions, other.exception), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] ), ) else: diff --git a/src/draive/parameters/specification.py b/src/draive/parameters/specification.py index 94c35aa..21e559c 100644 --- a/src/draive/parameters/specification.py +++ b/src/draive/parameters/specification.py @@ -232,11 +232,11 @@ def parameter_specification( # noqa: C901, PLR0912 | typing.List # pyright: ignore[reportUnknownMemberType] # noqa: UP006 ): match get_args(resolved_annotation): - case (list_annotation,): + case (tuple_annotation,): specification = { "type": "array", "items": parameter_specification( - annotation=list_annotation, + annotation=tuple_annotation, description=None, globalns=globalns, localns=localns, @@ -248,6 +248,28 @@ def parameter_specification( # noqa: C901, PLR0912 specification = { "type": "array", } + case ( + builtins.tuple # pyright: ignore[reportUnknownMemberType] + | typing.Tuple # pyright: ignore[reportUnknownMemberType] # noqa: UP006 + ): + match get_args(resolved_annotation): + case (tuple_annotation, builtins.Ellipsis | types.EllipsisType): + specification = { + "type": "array", + "items": parameter_specification( + annotation=tuple_annotation, + description=None, + globalns=globalns, + localns=localns, + recursion_guard=recursion_guard, + ), + } + + # TODO: represent element type for finite tuples + case other: + specification = { + "type": "array", + } case typing.Literal: options: tuple[Any, ...] = get_args(resolved_annotation) diff --git a/src/draive/parameters/validation.py b/src/draive/parameters/validation.py index 10abf71..c5c148a 100644 --- a/src/draive/parameters/validation.py +++ b/src/draive/parameters/validation.py @@ -156,35 +156,81 @@ def validated(value: Any) -> Any: return validated -def _tuple_validator( +def _tuple_validator( # noqa: C901 options: tuple[Any, ...], + globalns: dict[str, Any] | None, + localns: dict[str, Any] | None, + recursion_guard: frozenset[type[Any]], verifier: ValueVerifier | None, ) -> ValueValidator: - # TODO: validate tuple elements - # | types.EllipsisType - # element_validators: list[Callable[[Any], Any]] = [ - # parameter_validator( - # option, - # verifier=verifier, - # module=module, - # ) - # for option in options - # ] - if verify := verifier: + match options: + case [element_annotation, builtins.Ellipsis | types.EllipsisType]: + validate_element: Callable[[Any], Any] = parameter_validator( + element_annotation, + globalns=globalns, + localns=localns, + recursion_guard=recursion_guard, + verifier=None, + ) - def validated(value: Any) -> Any: - if isinstance(value, tuple): - verify(value) - return value # pyright: ignore[reportUnknownVariableType] - else: - raise TypeError("Invalid value", value) - else: + if verify := verifier: + + def validated(value: Any) -> Any: + if isinstance(value, list | tuple): + validated: tuple[Any, ...] = tuple( + validate_element(element) + for element in value # pyright: ignore[reportUnknownVariableType] + ) + verify(validated) + return validated # pyright: ignore[reportUnknownVariableType] + else: + raise TypeError("Invalid value", value) + else: + + def validated(value: Any) -> Any: + if isinstance(value, list | tuple): + validated: tuple[Any, ...] = tuple( + validate_element(element) + for element in value # pyright: ignore[reportUnknownVariableType] + ) + return validated # pyright: ignore[reportUnknownVariableType] + else: + raise TypeError("Invalid value", value) + + case [*annotations]: + element_validators: list[Callable[[Any], Any]] = [ + parameter_validator( + annotation, + verifier=verifier, + globalns=globalns, + localns=localns, + recursion_guard=recursion_guard, + ) + for annotation in annotations + ] - def validated(value: Any) -> Any: - if isinstance(value, tuple): - return value # pyright: ignore[reportUnknownVariableType] + if verify := verifier: + + def validated(value: Any) -> Any: + if isinstance(value, list | tuple): + if len(value) != len(element_validators): # pyright: ignore[reportUnknownArgumentType] + raise TypeError("Invalid value", value) # pyright: ignore[reportUnknownArgumentType] + + validated: tuple[Any, ...] = tuple( + validation(value[idx]) + for idx, validation in enumerate(element_validators) + ) + verify(validated) + return validated # pyright: ignore[reportUnknownVariableType] + else: + raise TypeError("Invalid value", value) else: - raise TypeError("Invalid value", value) + + def validated(value: Any) -> Any: + if isinstance(value, tuple): + return value # pyright: ignore[reportUnknownVariableType] + else: + raise TypeError("Invalid value", value) return validated @@ -831,6 +877,9 @@ def parameter_validator[Value]( # noqa: PLR0911, C901 case builtins.tuple: return _tuple_validator( options=get_args(annotation), + globalns=globalns, + localns=localns, + recursion_guard=recursion_guard, verifier=verifier, ) @@ -907,6 +956,9 @@ def validated_missing(value: Any) -> Any: if draive_missing.is_missing(value): return value + elif value is None: + return draive_missing.MISSING + else: raise TypeError("Invalid value", value) diff --git a/src/draive/types/__init__.py b/src/draive/types/__init__.py index 63d54fb..5945fc2 100644 --- a/src/draive/types/__init__.py +++ b/src/draive/types/__init__.py @@ -2,7 +2,11 @@ from draive.types.images import ImageBase64Content, ImageContent, ImageURLContent from draive.types.memory import Memory, ReadOnlyMemory from draive.types.model import Model -from draive.types.multimodal import MultimodalContent +from draive.types.multimodal import ( + MultimodalContent, + MultimodalContentItem, + merge_multimodal_content, +) from draive.types.state import State from draive.types.updates import UpdateSend from draive.types.video import VideoBase64Content, VideoContent, VideoURLContent @@ -15,8 +19,10 @@ "ImageContent", "ImageURLContent", "Memory", + "merge_multimodal_content", "Model", "MultimodalContent", + "MultimodalContentItem", "ReadOnlyMemory", "State", "UpdateSend", diff --git a/src/draive/types/audio.py b/src/draive/types/audio.py index 786e543..6ef645c 100644 --- a/src/draive/types/audio.py +++ b/src/draive/types/audio.py @@ -1,3 +1,4 @@ +from draive.helpers import MISSING, Missing from draive.types.model import Model __all__ = [ @@ -9,10 +10,12 @@ class AudioURLContent(Model): audio_url: str + audio_transcription: str | Missing = MISSING class AudioBase64Content(Model): audio_base64: str + audio_transcription: str | Missing = MISSING AudioContent = AudioURLContent | AudioBase64Content diff --git a/src/draive/types/images.py b/src/draive/types/images.py index d035698..7d783c2 100644 --- a/src/draive/types/images.py +++ b/src/draive/types/images.py @@ -1,3 +1,4 @@ +from draive.helpers import MISSING, Missing from draive.types.model import Model __all__ = [ @@ -9,10 +10,12 @@ class ImageURLContent(Model): image_url: str + image_description: str | Missing = MISSING class ImageBase64Content(Model): image_base64: str + image_description: str | Missing = MISSING ImageContent = ImageURLContent | ImageBase64Content diff --git a/src/draive/types/model.py b/src/draive/types/model.py index 41eda6e..717ad3d 100644 --- a/src/draive/types/model.py +++ b/src/draive/types/model.py @@ -4,7 +4,7 @@ from typing import Any, Self from uuid import UUID -from draive.helpers import not_missing +from draive.helpers import Missing from draive.parameters import ParametrizedData __all__ = [ @@ -14,19 +14,15 @@ class ModelJSONEncoder(json.JSONEncoder): def default(self, o: object) -> Any: - if isinstance(o, datetime): + if isinstance(o, Missing): + return None + elif isinstance(o, datetime): return o.isoformat() elif isinstance(o, UUID): return o.hex else: return json.JSONEncoder.default(self, o) - def encode(self, o: object) -> Any: - if isinstance(o, dict): - return json.JSONEncoder.encode(self, {k: v for k, v in o.items() if not_missing(v)}) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] - else: - return json.JSONEncoder.encode(self, o) - class Model(ParametrizedData): @classmethod diff --git a/src/draive/types/multimodal.py b/src/draive/types/multimodal.py index f7b567d..ed74d2d 100644 --- a/src/draive/types/multimodal.py +++ b/src/draive/types/multimodal.py @@ -3,8 +3,24 @@ from draive.types.video import VideoContent __all__ = [ + "merge_multimodal_content", "MultimodalContent", + "MultimodalContentItem", ] MultimodalContentItem = VideoContent | ImageContent | AudioContent | str -MultimodalContent = list[MultimodalContentItem] | MultimodalContentItem +MultimodalContent = tuple[MultimodalContentItem, ...] | MultimodalContentItem + + +def merge_multimodal_content( + *content: MultimodalContent, +) -> tuple[MultimodalContentItem, ...]: + result: list[MultimodalContentItem] = [] + for part in content: + match part: + case [*parts]: + result.extend(parts) + case part: + result.append(part) + + return tuple(result) diff --git a/src/draive/types/video.py b/src/draive/types/video.py index 8436a8e..dee5dca 100644 --- a/src/draive/types/video.py +++ b/src/draive/types/video.py @@ -1,3 +1,4 @@ +from draive.helpers import MISSING, Missing from draive.types.model import Model __all__ = [ @@ -9,10 +10,12 @@ class VideoURLContent(Model): video_url: str + video_transcription: str | Missing = MISSING class VideoBase64Content(Model): video_base64: str + video_transcription: str | Missing = MISSING VideoContent = VideoURLContent | VideoBase64Content diff --git a/src/draive/utils/stream.py b/src/draive/utils/stream.py index 1e299fa..3f9a7cd 100644 --- a/src/draive/utils/stream.py +++ b/src/draive/utils/stream.py @@ -19,7 +19,7 @@ ] -class AsyncStream[Element: Model | str](AsyncIterator[Element]): +class AsyncStream[Element](AsyncIterator[Element]): def __init__(self) -> None: self._buffer: deque[Element] = deque() self._waiting: Future[Element] | None = None diff --git a/tests/test_model.py b/tests/test_model.py index 8ebc8af..0eff845 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -99,7 +99,7 @@ class MissingModel(Model): missing_model_instance: MissingModel = MissingModel( value=MISSING, ) -missing_model_json: str = "{}" +missing_model_json: str = '{"value": null}' def test_missing_encoding() -> None: @@ -162,30 +162,34 @@ def test_basic_decoding() -> None: content=ImageURLContent(image_url="https://miquido.com/image"), ) image_lmm_message_json: str = ( - '{"role": "assistant", "content": {"image_url": "https://miquido.com/image"}}' + '{"role": "assistant",' + ' "content": {"image_url": "https://miquido.com/image", "image_description": null}' + "}" ) audio_lmm_message_instance: LMMCompletionMessage = LMMCompletionMessage( role="assistant", content=AudioURLContent(audio_url="https://miquido.com/audio"), ) audio_lmm_message_json: str = ( - '{"role": "assistant", "content": {"audio_url": "https://miquido.com/audio"}}' + '{"role": "assistant",' + ' "content": {"audio_url": "https://miquido.com/audio", "audio_transcription": null}' + "}" ) mixed_lmm_message_instance: LMMCompletionMessage = LMMCompletionMessage( role="assistant", - content=[ + content=( AudioURLContent(audio_url="https://miquido.com/audio"), "string", ImageURLContent(image_url="https://miquido.com/image"), "content", - ], + ), ) mixed_lmm_message_json: str = ( '{"role": "assistant",' ' "content": [' - '{"audio_url": "https://miquido.com/audio"},' + '{"audio_url": "https://miquido.com/audio", "audio_transcription": null},' ' "string",' - ' {"image_url": "https://miquido.com/image"},' + ' {"image_url": "https://miquido.com/image", "image_description": null},' ' "content"' "]}" )