Skip to content

Commit

Permalink
Rework LMM interface
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ committed May 20, 2024
1 parent 324bbfd commit 639d510
Show file tree
Hide file tree
Showing 61 changed files with 2,111 additions and 2,082 deletions.
10 changes: 5 additions & 5 deletions constraints
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mdurl==0.1.2
nodeenv==1.8.0
# via pyright
numpy==1.26.4
openai==1.26.0
openai==1.30.1
packaging==24.0
# via pytest
pbr==6.0.0
Expand All @@ -51,7 +51,7 @@ pydantic-core==2.18.2
# via pydantic
pygments==2.18.0
# via rich
pyright==1.1.361
pyright==1.1.363
pytest==7.4.4
# via
# pytest-asyncio
Expand All @@ -60,13 +60,13 @@ pytest-asyncio==0.23.6
pytest-cov==4.1.0
pyyaml==6.0.1
# via bandit
regex==2024.4.28
regex==2024.5.15
# via tiktoken
requests==2.31.0
# via tiktoken
rich==13.7.1
# via bandit
ruff==0.4.3
ruff==0.4.4
setuptools==69.5.1
# via nodeenv
sniffio==1.3.1
Expand All @@ -76,7 +76,7 @@ sniffio==1.3.1
# openai
stevedore==5.2.0
# via bandit
tiktoken==0.6.0
tiktoken==0.7.0
tqdm==4.66.4
# via openai
typing-extensions==4.11.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "draive"
version = "0.11.0"
version = "0.13.0"
readme = "README.md"
maintainers = [
{name = "Kacper Kaliński", email = "[email protected]"}
Expand Down
69 changes: 39 additions & 30 deletions src/draive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from draive.conversation import (
Conversation,
ConversationCompletion,
ConversationCompletionStream,
ConversationMessage,
ConversationMessageChunk,
ConversationResponseStream,
conversation_completion,
lmm_conversation_completion,
)
Expand Down Expand Up @@ -42,14 +43,7 @@
split_sequence,
when_missing,
)
from draive.lmm import (
LMM,
LMMCompletion,
LMMCompletionStream,
LMMMessage,
LMMStreamingUpdate,
lmm_completion,
)
from draive.lmm import LMM, lmm_invocation
from draive.metrics import (
Metric,
MetricsTrace,
Expand Down Expand Up @@ -84,38 +78,46 @@
ScopeState,
ctx,
)
from draive.similarity import mmr_similarity, similarity
from draive.similarity import mmr_similarity_search, similarity_score, similarity_search
from draive.splitters import split_text
from draive.tokenization import TextTokenizer, Tokenization, count_text_tokens, tokenize_text
from draive.tools import (
Tool,
Toolbox,
ToolCallContext,
ToolCallStatus,
ToolCallUpdate,
ToolException,
ToolsUpdatesContext,
ToolStatusStreaming,
tool,
)
from draive.types import (
AudioBase64Content,
AudioContent,
AudioDataContent,
AudioURLContent,
ImageBase64Content,
ImageContent,
ImageDataContent,
ImageURLContent,
Instruction,
LMMCompletion,
LMMCompletionChunk,
LMMContextElement,
LMMInput,
LMMInstruction,
LMMOutputStream,
LMMOutputStreamChunk,
LMMToolRequest,
LMMToolResponse,
Memory,
Model,
MultimodalContent,
ReadOnlyMemory,
State,
ToolCallStatus,
VideoBase64Content,
VideoContent,
VideoDataContent,
VideoURLContent,
has_media,
is_multimodal_content,
merge_multimodal_content,
multimodal_content_string,
)
from draive.utils import (
AsyncStream,
Expand All @@ -142,6 +144,7 @@
"AsyncStreamTask",
"AudioBase64Content",
"AudioContent",
"AudioDataContent",
"AudioURLContent",
"auto_retry",
"BaseAgent",
Expand All @@ -151,7 +154,8 @@
"Conversation",
"Conversation",
"ConversationCompletion",
"ConversationCompletionStream",
"ConversationMessageChunk",
"ConversationResponseStream",
"ConversationMessage",
"count_text_tokens",
"ctx",
Expand All @@ -168,24 +172,28 @@
"getenv_float",
"getenv_int",
"getenv_str",
"has_media",
"ImageBase64Content",
"ImageContent",
"ImageDataContent",
"ImageGeneration",
"ImageGenerator",
"ImageURLContent",
"Instruction",
"is_missing",
"is_multimodal_content",
"lmm_completion",
"lmm_invocation",
"lmm_conversation_completion",
"LMM",
"LMMCompletion",
"LMMMessage",
"LMMCompletionStream",
"LMMStreamingUpdate",
"LMMCompletionChunk",
"LMMContextElement",
"LMMInput",
"LMMInstruction",
"LMMOutputStream",
"LMMOutputStreamChunk",
"LMMToolRequest",
"LMMToolResponse",
"load_env",
"Memory",
"merge_multimodal_content",
"Metric",
"metrics_log_reporter",
"MetricsTrace",
Expand All @@ -199,11 +207,10 @@
"MistralClient",
"MistralEmbeddingConfig",
"MistralException",
"mmr_similarity",
"mmr_similarity_search",
"Model",
"ModelGeneration",
"ModelGenerator",
"multimodal_content_string",
"MultimodalContent",
"not_missing",
"openai_embed_text",
Expand All @@ -221,7 +228,8 @@
"ScopeDependency",
"ScopeState",
"setup_logging",
"similarity",
"similarity_score",
"similarity_search",
"split_sequence",
"split_text",
"State",
Expand All @@ -237,13 +245,14 @@
"Toolbox",
"ToolCallContext",
"ToolCallStatus",
"ToolCallUpdate",
"ToolCallStatus",
"ToolException",
"ToolException",
"ToolsUpdatesContext",
"ToolStatusStreaming",
"traced",
"VideoBase64Content",
"VideoContent",
"VideoDataContent",
"VideoURLContent",
"when_missing",
"with_early_exit",
Expand Down
6 changes: 3 additions & 3 deletions src/draive/agents/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from draive.helpers import freeze
from draive.parameters import ParametrizedData
from draive.scope import ctx
from draive.types import MultimodalContent, merge_multimodal_content
from draive.types import MultimodalContent

__all__ = [
"AgentFlow",
Expand Down Expand Up @@ -42,7 +42,7 @@ async def __call__(
with ctx.updated(current_scratchpad):
match agent:
case [*agents]:
merged_note: MultimodalContent = merge_multimodal_content(
merged_note: MultimodalContent = MultimodalContent.of(
*[
scratchpad_note
for scratchpad_note in await gather(
Expand All @@ -61,4 +61,4 @@ async def __call__(
current_scratchpad = current_scratchpad.extended(scratchpad_note)
scratchpad_notes.append(scratchpad_note)

return merge_multimodal_content(*scratchpad_notes)
return MultimodalContent.of(*scratchpad_notes)
15 changes: 6 additions & 9 deletions src/draive/agents/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from draive.parameters import ParametrizedData
from draive.scope import ctx
from draive.types import MultimodalContent, MultimodalContentItem, State
from draive.types import MultimodalContent, State

__all__ = [
"AgentState",
Expand All @@ -24,13 +24,12 @@ def prepare(
) -> Self:
match content:
case None:
return cls(content=())
case [*items]:
return cls(content=tuple(items))
return cls(content=MultimodalContent.of())

case item:
return cls(content=(item,))
return cls(content=item)

content: tuple[MultimodalContentItem, ...] = ()
content: MultimodalContent = MultimodalContent.of()

def extended(
self,
Expand All @@ -39,10 +38,8 @@ def extended(
match content:
case None:
return self
case [*items]:
return self.__class__(content=(*self.content, *items))
case item:
return self.__class__(content=(*self.content, item))
return self.__class__(content=MultimodalContent.of(*self.content, item))


class AgentState[State: ParametrizedData]:
Expand Down
11 changes: 8 additions & 3 deletions src/draive/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from draive.conversation.call import conversation_completion
from draive.conversation.completion import ConversationCompletion, ConversationCompletionStream
from draive.conversation.completion import ConversationCompletion
from draive.conversation.lmm import lmm_conversation_completion
from draive.conversation.message import ConversationMessage
from draive.conversation.model import (
ConversationMessage,
ConversationMessageChunk,
ConversationResponseStream,
)
from draive.conversation.state import Conversation

__all__ = [
"conversation_completion",
"Conversation",
"ConversationCompletion",
"ConversationCompletionStream",
"ConversationMessageChunk",
"ConversationResponseStream",
"ConversationMessage",
"lmm_conversation_completion",
]
Loading

0 comments on commit 639d510

Please sign in to comment.