diff --git a/haystack_experimental/components/generators/anthropic/chat/chat_generator.py b/haystack_experimental/components/generators/anthropic/chat/chat_generator.py index 8ec67df5..b660c71e 100644 --- a/haystack_experimental/components/generators/anthropic/chat/chat_generator.py +++ b/haystack_experimental/components/generators/anthropic/chat/chat_generator.py @@ -4,6 +4,7 @@ import json import logging +from base64 import b64encode from typing import Any, Callable, Dict, List, Optional, Tuple, Type from haystack import component, default_from_dict @@ -11,7 +12,7 @@ from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace -from haystack_experimental.dataclasses import ChatMessage, ToolCall +from haystack_experimental.dataclasses import ChatMessage, ToolCall, ByteStream from haystack_experimental.dataclasses.chat_message import ChatRole, ToolCallResult from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace @@ -38,7 +39,9 @@ # - AnthropicChatGenerator fails with ImportError at init (due to anthropic_integration_import.check()). if anthropic_integration_import.is_successful(): - chatgenerator_base_class: Type[AnthropicChatGeneratorBase] = AnthropicChatGeneratorBase + chatgenerator_base_class: Type[AnthropicChatGeneratorBase] = ( + AnthropicChatGeneratorBase + ) else: chatgenerator_base_class: Type[object] = object # type: ignore[no-redef] @@ -57,7 +60,9 @@ def _update_anthropic_message_with_tool_call_results( for tool_call_result in tool_call_results: if tool_call_result.origin.id is None: - raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.") + raise ValueError( + "`ToolCall` must have a non-null `id` attribute to be used with Anthropic." + ) anthropic_msg["content"].append( { "type": "tool_result", @@ -68,7 +73,9 @@ def _update_anthropic_message_with_tool_call_results( ) -def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[Dict[str, Any]]: +def _convert_tool_calls_to_anthropic_format( + tool_calls: List[ToolCall], +) -> List[Dict[str, Any]]: """ Convert a list of tool calls to the format expected by Anthropic Chat API. @@ -78,7 +85,9 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[ anthropic_tool_calls = [] for tc in tool_calls: if tc.id is None: - raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.") + raise ValueError( + "`ToolCall` must have a non-null `id` attribute to be used with Anthropic." + ) anthropic_tool_calls.append( { "type": "tool_use", @@ -90,6 +99,44 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[ return anthropic_tool_calls +def _convert_media_to_anthropic_format(media: List[ByteStream]) -> List[Dict[str, Any]]: + """ + Convert a list of media to the format expected by Anthropic Chat API. + + :param media: The list of ByteStreams to convert. + :return: A list of dictionaries in the format expected by Anthropic API. + """ + anthropic_media = [] + for item in media: + if item.type == "image": + anthropic_media.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": item.mime_type, + "data": b64encode(item.data).decode("utf-8"), + }, + } + ) + elif item.type == "application" and item.subtype == "pdf": + anthropic_media.append( + { + "type": "document", + "source": { + "type": "base64", + "media_type": item.mime_type, + "data": b64encode(item.data).decode("utf-8"), + }, + } + ) + else: + raise ValueError( + f"Unsupported media type '{item.mime_type}' for Anthropic completions." + ) + return anthropic_media + + def _convert_messages_to_anthropic_format( messages: List[ChatMessage], ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: @@ -119,10 +166,17 @@ def _convert_messages_to_anthropic_format( anthropic_msg: Dict[str, Any] = {"role": message._role.value, "content": []} - if message.texts and message.texts[0]: - anthropic_msg["content"].append({"type": "text", "text": message.texts[0]}) + if message.texts: + for item in message.texts: + anthropic_msg["content"].append({"type": "text", "text": item}) + if message.media: + anthropic_msg["content"] += _convert_media_to_anthropic_format( + message.media + ) if message.tool_calls: - anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(message.tool_calls) + anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format( + message.tool_calls + ) if message.tool_call_results: results = message.tool_call_results.copy() @@ -136,7 +190,8 @@ def _convert_messages_to_anthropic_format( if not anthropic_msg["content"]: raise ValueError( - "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + "A `ChatMessage` must contain at least one `TextContent`, `MediaContent`, " + "`ToolCall`, or `ToolCallResult`." ) anthropic_non_system_messages.append(anthropic_msg) @@ -250,7 +305,9 @@ def to_dict(self) -> Dict[str, Any]: The serialized component as a dictionary. """ serialized = super(AnthropicChatGenerator, self).to_dict() - serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None + serialized["init_parameters"]["tools"] = ( + [tool.to_dict() for tool in self.tools] if self.tools else None + ) return serialized @classmethod @@ -267,11 +324,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicChatGenerator": init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + data["init_parameters"]["streaming_callback"] = deserialize_callable( + serialized_callback_handler + ) return default_from_dict(cls, data) - def _convert_chat_completion_to_chat_message(self, anthropic_response: Any) -> ChatMessage: + def _convert_chat_completion_to_chat_message( + self, anthropic_response: Any + ) -> ChatMessage: """ Converts the response from the Anthropic API to a ChatMessage. """ @@ -343,15 +404,22 @@ def _convert_streaming_chunks_to_chat_message( full_content += delta.get("text", "") elif delta.get("type") == "input_json_delta" and current_tool_call: current_tool_call["arguments"] += delta.get("partial_json", "") - elif chunk_type == "message_delta": # noqa: SIM102 (prefer nested if statement here for readability) - if chunk.meta.get("delta", {}).get("stop_reason") == "tool_use" and current_tool_call: + elif ( + chunk_type == "message_delta" + ): # noqa: SIM102 (prefer nested if statement here for readability) + if ( + chunk.meta.get("delta", {}).get("stop_reason") == "tool_use" + and current_tool_call + ): try: # arguments is a string, convert to json tool_calls.append( ToolCall( id=current_tool_call.get("id"), tool_name=str(current_tool_call.get("name")), - arguments=json.loads(current_tool_call.get("arguments", {})), + arguments=json.loads( + current_tool_call.get("arguments", {}) + ), ) ) except json.JSONDecodeError: @@ -370,7 +438,9 @@ def _convert_streaming_chunks_to_chat_message( { "model": model, "index": 0, - "finish_reason": last_chunk_meta.get("delta", {}).get("stop_reason", None), + "finish_reason": last_chunk_meta.get("delta", {}).get( + "stop_reason", None + ), "usage": last_chunk_meta.get("usage", {}), } ) @@ -405,12 +475,16 @@ def run( disallowed_params, self.ALLOWED_PARAMS, ) - generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS} + generation_kwargs = { + k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS + } tools = tools or self.tools if tools: _check_duplicate_tool_names(tools) - system_messages, non_system_messages = _convert_messages_to_anthropic_format(messages) + system_messages, non_system_messages = _convert_messages_to_anthropic_format( + messages + ) anthropic_tools = ( [ { @@ -447,7 +521,9 @@ def run( "content_block_delta", "message_delta", ]: - streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk(chunk) + streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk( + chunk + ) chunks.append(streaming_chunk) if streaming_callback: streaming_callback(streaming_chunk) @@ -455,4 +531,6 @@ def run( completion = self._convert_streaming_chunks_to_chat_message(chunks, model) return {"replies": [completion]} else: - return {"replies": [self._convert_chat_completion_to_chat_message(response)]} + return { + "replies": [self._convert_chat_completion_to_chat_message(response)] + } diff --git a/haystack_experimental/components/generators/chat/openai.py b/haystack_experimental/components/generators/chat/openai.py index 19337b46..c49c7ca4 100644 --- a/haystack_experimental/components/generators/chat/openai.py +++ b/haystack_experimental/components/generators/chat/openai.py @@ -4,6 +4,7 @@ import json import os +from base64 import b64encode from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -19,7 +20,14 @@ from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice -from haystack_experimental.dataclasses import ChatMessage, Tool, ToolCall +from haystack_experimental.dataclasses import ( + ChatMessage, + Tool, + ToolCall, + TextContent, + ChatRole, + MediaContent, +) from haystack_experimental.dataclasses.streaming_chunk import ( AsyncStreamingCallbackT, StreamingCallbackT, @@ -34,53 +42,81 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]: """ Convert a message to the format expected by OpenAI's Chat API. """ - text_contents = message.texts - tool_calls = message.tool_calls - tool_call_results = message.tool_call_results - - if not text_contents and not tool_calls and not tool_call_results: - raise ValueError( - "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." - ) - elif len(text_contents) + len(tool_call_results) > 1: + openai_msg: Dict[str, Any] = {"role": message.role.value} + if len(message) == 0: raise ValueError( - "A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`." + "ChatMessage must contain at least one `TextContent`, " + "`MediaContent`, `ToolCall`, or `ToolCallResult`." ) - - openai_msg: Dict[str, Any] = {"role": message._role.value} - - if tool_call_results: - result = tool_call_results[0] - if result.origin.id is None: + if len(message) == 1 and isinstance(message.content[0], TextContent): + openai_msg["content"] = message.content[0].text + elif message.tool_call_result: + # Tool call results should only be included for ChatRole.TOOL messages + # and should not include any other content + if message.role != ChatRole.TOOL: + raise ValueError( + "Tool call results should only be included for tool messages." + ) + if len(message) > 1: + raise ValueError( + "Tool call results should not be included with other content." + ) + if message.tool_call_result.origin.id is None: raise ValueError( "`ToolCall` must have a non-null `id` attribute to be used with OpenAI." ) - openai_msg["content"] = result.result - openai_msg["tool_call_id"] = result.origin.id - # OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field - return openai_msg - - if text_contents: - openai_msg["content"] = text_contents[0] - if tool_calls: - openai_tool_calls = [] - for tc in tool_calls: - if tc.id is None: + openai_msg["content"] = message.tool_call_result.result + openai_msg["tool_call_id"] = message.tool_call_result.origin.id + else: + openai_msg["content"] = [] + for item in message.content: + if isinstance(item, TextContent): + openai_msg["content"].append({"type": "text", "text": item.text}) + elif isinstance(item, MediaContent): + match item.media.type: + case "image": + base64_data = b64encode(item.media.data).decode("utf-8") + url = f"data:{item.media.mime_type};base64,{base64_data}" + openai_msg["content"].append( + { + "type": "image_url", + "image_url": { + "url": url, + "detail": item.media.meta.get("detail", "auto"), + }, + } + ) + case _: + raise ValueError( + f"Unsupported media type '{item.media.mime_type}' for OpenAI completions." + ) + elif isinstance(item, ToolCall): + if message.role != ChatRole.ASSISTANT: + raise ValueError( + "Tool calls should only be included for assistant messages." + ) + if item.id is None: + raise ValueError( + "`ToolCall` must have a non-null `id` attribute to be used with OpenAI." + ) + openai_msg.setdefault("tool_calls", []).append( + { + "id": item.id, + "type": "function", + "function": { + "name": item.tool_name, + "arguments": json.dumps(item.arguments, ensure_ascii=False), + }, + } + ) + else: raise ValueError( - "`ToolCall` must have a non-null `id` attribute to be used with OpenAI." + f"Unsupported content type '{type(item).__name__}' for OpenAI completions." ) - openai_tool_calls.append( - { - "id": tc.id, - "type": "function", - # We disable ensure_ascii so special chars like emojis are not converted - "function": { - "name": tc.tool_name, - "arguments": json.dumps(tc.arguments, ensure_ascii=False), - }, - } - ) - openai_msg["tool_calls"] = openai_tool_calls + + if message.name: + openai_msg["name"] = message.name + return openai_msg diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py index 78a97618..2777d92d 100644 --- a/haystack_experimental/dataclasses/__init__.py +++ b/haystack_experimental/dataclasses/__init__.py @@ -2,10 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 +from haystack_experimental.dataclasses.byte_stream import ByteStream from haystack_experimental.dataclasses.chat_message import ( ChatMessage, ChatMessageContentT, ChatRole, + MediaContent, TextContent, ToolCall, ToolCallResult, @@ -18,12 +20,14 @@ __all__ = [ "AsyncStreamingCallbackT", + "ByteStream", "ChatMessage", + "ChatMessageContentT", "ChatRole", + "MediaContent", "StreamingCallbackT", + "TextContent", "ToolCall", "ToolCallResult", - "TextContent", - "ChatMessageContentT", "Tool", ] diff --git a/haystack_experimental/dataclasses/byte_stream.py b/haystack_experimental/dataclasses/byte_stream.py new file mode 100644 index 00000000..268b3161 --- /dev/null +++ b/haystack_experimental/dataclasses/byte_stream.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +""" +Data classes for representing binary data in the Haystack API. The ByteStream class can be used to represent binary data +in the API, and can be converted to and from base64 encoded strings, dictionaries, and files. This is particularly +useful for representing media files in chat messages. +""" + +import logging +import mimetypes +from base64 import b64encode, b64decode +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional + + +logger = logging.getLogger(__name__) + + +@dataclass +class ByteStream: + """ + Base data class representing a binary object in the Haystack API. + """ + + data: bytes + meta: Dict[str, Any] = field(default_factory=dict, hash=False) + mime_type: Optional[str] = field(default=None) + + @property + def type(self) -> Optional[str]: + """ + Return the type of the ByteStream. This is the first part of the mime type, or None if the mime type is not set. + + :return: The type of the ByteStream. + """ + if self.mime_type: + return self.mime_type.split("/", maxsplit=1)[0] + return None + + @property + def subtype(self) -> Optional[str]: + """ + Return the subtype of the ByteStream. This is the second part of the mime type, + or None if the mime type is not set. + + :return: The subtype of the ByteStream. + """ + if self.mime_type: + return self.mime_type.split("/", maxsplit=1)[-1] + return None + + def to_file(self, destination_path: Path): + """ + Write the ByteStream to a file. Note: the metadata will be lost. + + :param destination_path: The path to write the ByteStream to. + """ + with open(destination_path, "wb") as fd: + fd.write(self.data) + + @classmethod + def from_file_path( + cls, filepath: Path, mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None + ) -> "ByteStream": + """ + Create a ByteStream from the contents read from a file. + + :param filepath: A valid path to a file. + :param mime_type: The mime type of the file. + :param meta: Additional metadata to be stored with the ByteStream. + """ + if mime_type is None: + mime_type = mimetypes.guess_type(filepath)[0] + if mime_type is None: + logger.warning("Could not determine mime type for file %s", filepath) + + with open(filepath, "rb") as fd: + return cls(data=fd.read(), mime_type=mime_type, meta=meta or {}) + + @classmethod + def from_string( + cls, text: str, encoding: str = "utf-8", mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None + ) -> "ByteStream": + """ + Create a ByteStream encoding a string. + + :param text: The string to encode + :param encoding: The encoding used to convert the string into bytes + :param mime_type: The mime type of the file. + :param meta: Additional metadata to be stored with the ByteStream. + """ + return cls(data=text.encode(encoding), mime_type=mime_type, meta=meta or {}) + + def to_string(self, encoding: str = "utf-8") -> str: + """ + Convert the ByteStream to a string, metadata will not be included. + + :param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8". + :returns: The string representation of the ByteStream. + :raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding. + """ + return self.data.decode(encoding) + + @classmethod + def from_base64( + cls, + base64_string: str, + encoding: str = "utf-8", + meta: Optional[Dict[str, Any]] = None, + mime_type: Optional[str] = None, + ) -> "ByteStream": + """ + Create a ByteStream from a base64 encoded string. + + :param base64_string: The base64 encoded string representation of the ByteStream data. + :param encoding: The encoding used to convert the base64 string into bytes. + :param meta: Additional metadata to be stored with the ByteStream. + :param mime_type: The mime type of the file. + :returns: A new ByteStream instance. + """ + return cls(data=b64decode(base64_string.encode(encoding)), meta=meta or {}, mime_type=mime_type) + + def to_base64(self, encoding: str = "utf-8") -> str: + """ + Convert the ByteStream data to a base64 encoded string. + + :returns: The base64 encoded string representation of the ByteStream data. + """ + return b64encode(self.data).decode(encoding) + + @classmethod + def from_dict(cls, data: Dict[str, Any], encoding: str = "utf-8") -> "ByteStream": + """ + Create a ByteStream from a dictionary. + + :param data: The dictionary representation of the ByteStream. + :param encoding: The encoding used to convert the base64 string into bytes. + :returns: A new ByteStream instance. + """ + return cls.from_base64(data["data"], encoding=encoding, meta=data.get("meta"), mime_type=data.get("mime_type")) + + def to_dict(self, encoding: str = "utf-8"): + """ + Convert the ByteStream to a dictionary. + + :param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8". + :returns: The dictionary representation of the ByteStream. + :raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding. + """ + return {"data": self.to_base64(encoding=encoding), "meta": self.meta, "mime_type": self.mime_type} diff --git a/haystack_experimental/dataclasses/chat_message.py b/haystack_experimental/dataclasses/chat_message.py index 8490483a..24bebf14 100644 --- a/haystack_experimental/dataclasses/chat_message.py +++ b/haystack_experimental/dataclasses/chat_message.py @@ -6,6 +6,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Union +from haystack_experimental.dataclasses import ByteStream + class ChatRole(str, Enum): """ @@ -66,7 +68,18 @@ class TextContent: text: str -ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult] +@dataclass +class MediaContent: + """ + The media content of a chat message. + + :param media: The media content of the message. + """ + + media: ByteStream + + +ChatMessageContentT = Union[TextContent, MediaContent, ToolCall, ToolCallResult] @dataclass @@ -82,10 +95,26 @@ class ChatMessage: _role: ChatRole _content: Sequence[ChatMessageContentT] _meta: Dict[str, Any] = field(default_factory=dict, hash=False) + _name: Optional[str] = None def __len__(self): return len(self._content) + @property + def name(self) -> Optional[str]: + """ + Returns the name for the message participant, if provided. + + """ + return self._name + + @property + def content(self) -> Sequence[ChatMessageContentT]: + """ + Returns the content of the message. + """ + return self._content + @property def role(self) -> ChatRole: """ @@ -116,6 +145,15 @@ def text(self) -> Optional[str]: return texts[0] return None + @property + def media(self) -> List[ByteStream]: + """ + Returns the list of all media content contained in the message. + + :return: List of ByteStream objects. + """ + return [content.media for content in self._content if isinstance(content, MediaContent)] + @property def tool_calls(self) -> List[ToolCall]: """ @@ -161,37 +199,47 @@ def is_from(self, role: ChatRole) -> bool: def from_user( cls, text: str, + media: Optional[Sequence[ByteStream]] = None, + name: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, ) -> "ChatMessage": """ Create a message from the user. :param text: The text content of the message. + :param media: The media contents of the message, if any. + :param name: An optional name for the message participant. :param meta: Additional metadata associated with the message. :returns: A new ChatMessage instance. """ - return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}) + media_contents = [MediaContent(media=media) for media in media] if media else [] + return cls( + _role=ChatRole.USER, _content=[TextContent(text=text), *media_contents], _name=name, _meta=meta or {} + ) @classmethod def from_system( cls, text: str, + name: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, ) -> "ChatMessage": """ Create a message from the system. :param text: The text content of the message. + :param name: An optional name for the message participant. :param meta: Additional metadata associated with the message. :returns: A new ChatMessage instance. """ - return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}) + return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _name=name, _meta=meta or {}) @classmethod def from_assistant( cls, text: Optional[str] = None, tool_calls: Optional[List[ToolCall]] = None, + name: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, ) -> "ChatMessage": """ @@ -199,6 +247,7 @@ def from_assistant( :param text: The text content of the message. :param tool_calls: The Tool calls to include in the message. + :param name: An optional name for the message participant. :param meta: Additional metadata associated with the message. :returns: A new ChatMessage instance. """ @@ -208,7 +257,7 @@ def from_assistant( if tool_calls: content.extend(tool_calls) - return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}) + return cls(_role=ChatRole.ASSISTANT, _content=content, _name=name, _meta=meta or {}) @classmethod def from_tool( @@ -240,14 +289,14 @@ def to_dict(self) -> Dict[str, Any]: :returns: Serialized version of the object. """ - serialized: Dict[str, Any] = {} - serialized["_role"] = self._role.value - serialized["_meta"] = self._meta + serialized: Dict[str, Any] = {"_role": self._role.value, "_name": self._name, "_meta": self._meta} content: List[Dict[str, Any]] = [] for part in self._content: if isinstance(part, TextContent): content.append({"text": part.text}) + elif isinstance(part, MediaContent): + content.append({"media": part.media.to_dict()}) elif isinstance(part, ToolCall): content.append({"tool_call": asdict(part)}) elif isinstance(part, ToolCallResult): @@ -275,6 +324,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": for part in data["_content"]: if "text" in part: content.append(TextContent(text=part["text"])) + elif "media" in part: + content.append(MediaContent(media=ByteStream.from_dict(part["media"]))) elif "tool_call" in part: content.append(ToolCall(**part["tool_call"])) elif "tool_call_result" in part: diff --git a/test/components/generators/anthropic/test_anthropic.py b/test/components/generators/anthropic/test_anthropic.py index 718b4790..8f81d274 100644 --- a/test/components/generators/anthropic/test_anthropic.py +++ b/test/components/generators/anthropic/test_anthropic.py @@ -4,6 +4,7 @@ import json import logging import os +from base64 import b64encode from unittest.mock import patch import pytest @@ -19,7 +20,7 @@ AnthropicChatGenerator, _convert_messages_to_anthropic_format, ) -from haystack_experimental.dataclasses import ChatMessage, ChatRole, Tool, ToolCall +from haystack_experimental.dataclasses import ChatMessage, ChatRole, Tool, ToolCall, ByteStream from haystack_experimental.dataclasses.chat_message import ToolCallResult @@ -627,6 +628,41 @@ def test_convert_message_to_anthropic_format(self): [{"role": "assistant", "content": [{"type": "text", "text": "I have an answer"}]}], ) + messages = [ + ChatMessage.from_user( + text="Multimodal example", + media=[ByteStream(b"data", mime_type="image/png"), ByteStream(b"data2", mime_type="application/pdf")], + ) + ] + result = _convert_messages_to_anthropic_format(messages) + assert result == ( + [], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Multimodal example"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": b64encode(b"data").decode("utf-8"), + }, + }, + { + "type": "document", + "source": { + "type": "base64", + "media_type": "application/pdf", + "data": b64encode(b"data2").decode("utf-8"), + }, + }, + ], + }, + ], + ) + messages = [ ChatMessage.from_assistant( tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] @@ -655,7 +691,10 @@ def test_convert_message_to_anthropic_format(self): [ { "role": "assistant", - "content": [{"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}], + "content": [ + {"type": "text", "text": ""}, + {"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}, + ], } ], ) @@ -739,6 +778,7 @@ def test_convert_message_to_anthropic_format_complex(self): { "role": "assistant", "content": [ + {"type": "text", "text": ""}, {"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}, {"type": "tool_use", "id": "456", "name": "math", "input": {"expression": "2+2"}}, ], diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 4a835f9f..3ec66d79 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from base64 import b64encode from unittest.mock import MagicMock, patch import pytest @@ -31,6 +32,7 @@ ToolCall, ChatRole, TextContent, + ByteStream, ) from haystack_experimental.components.generators.chat.openai import ( OpenAIChatGenerator, @@ -719,12 +721,36 @@ def test_convert_message_to_openai_format(self): "content": "I have an answer", } + message = ChatMessage.from_user( + text="Hello", + media=[ + ByteStream( + data=b"test data", meta={"detail": "low"}, mime_type="image/png" + ) + ], + ) + assert _convert_message_to_openai_format(message) == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64," + + b64encode(b"test data").decode("utf-8"), + "detail": "low", + }, + }, + ], + } + message = ChatMessage.from_assistant( tool_calls=[ ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) ] ) assert _convert_message_to_openai_format(message) == { + "content": [], "role": "assistant", "tool_calls": [ { @@ -751,16 +777,6 @@ def test_convert_message_to_openai_invalid(self): with pytest.raises(ValueError): _convert_message_to_openai_format(message) - message = ChatMessage( - _role=ChatRole.ASSISTANT, - _content=[ - TextContent(text="I have an answer"), - TextContent(text="I have another answer"), - ], - ) - with pytest.raises(ValueError): - _convert_message_to_openai_format(message) - tool_call_null_id = ToolCall( id=None, tool_name="weather", arguments={"city": "Paris"} ) diff --git a/test/dataclasses/__init__.py b/test/dataclasses/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/dataclasses/test_byte_stream.py b/test/dataclasses/test_byte_stream.py new file mode 100644 index 00000000..1f38e8bb --- /dev/null +++ b/test/dataclasses/test_byte_stream.py @@ -0,0 +1,91 @@ +import pytest +from base64 import b64encode +from pathlib import Path +from unittest.mock import mock_open, patch + +from haystack_experimental.dataclasses.byte_stream import ByteStream + +@pytest.fixture +def byte_stream(): + test_data = b"test data" + test_meta = {"key": "value"} + test_mime = "text/plain" + return ByteStream(data=test_data, meta=test_meta, mime_type=test_mime) + +def test_init(byte_stream): + assert byte_stream.data == b"test data" + assert byte_stream.meta == {"key": "value"} + assert byte_stream.mime_type == "text/plain" + +def test_type_property(byte_stream): + assert byte_stream.type == "text" + stream_without_mime = ByteStream(data=b"test data") + assert stream_without_mime.type is None + +def test_subtype_property(byte_stream): + assert byte_stream.subtype == "plain" + stream_without_mime = ByteStream(data=b"test data") + assert stream_without_mime.subtype is None + +@patch("builtins.open", new_callable=mock_open) +def test_to_file(mock_file, byte_stream): + path = Path("test.txt") + byte_stream.to_file(path) + mock_file.assert_called_once_with(path, "wb") + mock_file().write.assert_called_once_with(b"test data") + +@patch("builtins.open", new_callable=mock_open, read_data=b"test data") +def test_from_file_path(mock_file): + path = Path("test.txt") + with patch("mimetypes.guess_type", return_value=("text/plain", None)): + byte_stream = ByteStream.from_file_path(path) + assert byte_stream.data == b"test data" + assert byte_stream.mime_type == "text/plain" + +@patch("mimetypes.guess_type", return_value=(None, None)) +@patch("haystack_experimental.dataclasses.byte_stream.logger.warning") +def test_from_file_path_unknown_mime(mock_warning, _, byte_stream): + path = Path("test.txt") + with patch("builtins.open", new_callable=mock_open, read_data=b"test data"): + byte_stream = ByteStream.from_file_path(path) + assert byte_stream.mime_type is None + mock_warning.assert_called_once() + +def test_from_string(): + text = "Hello, World!" + byte_stream = ByteStream.from_string(text, mime_type="text/plain") + assert byte_stream.data == text.encode("utf-8") + assert byte_stream.mime_type == "text/plain" + +def test_to_string(): + byte_stream = ByteStream(data=b"Hello, World!") + assert byte_stream.to_string() == "Hello, World!" + +def test_from_base64(): + base64_string = b64encode(b"test data").decode("utf-8") + byte_stream = ByteStream.from_base64(base64_string, mime_type="text/plain") + assert byte_stream.data == b"test data" + assert byte_stream.mime_type == "text/plain" + +def test_to_base64(byte_stream): + expected = b64encode(b"test data").decode("utf-8") + assert byte_stream.to_base64() == expected + +def test_from_dict(): + data = { + "data": b64encode(b"test data").decode("utf-8"), + "meta": {"key": "value"}, + "mime_type": "text/plain", + } + byte_stream = ByteStream.from_dict(data) + assert byte_stream.data == b"test data" + assert byte_stream.meta == {"key": "value"} + assert byte_stream.mime_type == "text/plain" + +def test_to_dict(byte_stream): + expected = { + "data": b64encode(b"test data").decode("utf-8"), + "meta": {"key": "value"}, + "mime_type": "text/plain", + } + assert byte_stream.to_dict() == expected \ No newline at end of file diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 90f046f9..f46fca74 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -1,9 +1,13 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from base64 import b64encode + import pytest -from haystack_experimental.dataclasses import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent +from haystack_experimental.dataclasses import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent, \ + MediaContent, ByteStream + def test_tool_call_init(): tc = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) @@ -21,6 +25,10 @@ def test_text_content_init(): tc = TextContent(text="Hello") assert tc.text == "Hello" +def test_media_content_init(): + mc = MediaContent(media=ByteStream(data=b"media data", mime_type="image/png")) + assert mc.media.data == b"media data" + assert mc.media.mime_type == "image/png" def test_from_assistant_with_valid_content(): text = "Hello, how can I assist you?" @@ -32,6 +40,7 @@ def test_from_assistant_with_valid_content(): assert message.text == text assert message.texts == [text] + assert not message.media assert not message.tool_calls assert not message.tool_call assert not message.tool_call_results @@ -51,6 +60,7 @@ def test_from_assistant_with_tool_calls(): assert not message.texts assert not message.text + assert not message.media assert not message.tool_call_results assert not message.tool_call_result @@ -65,6 +75,24 @@ def test_from_user_with_valid_content(): assert message.text == text assert message.texts == [text] + assert not message.media + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + +def test_from_user_with_media(): + text = "This is a multimodal message!" + media = [ByteStream(data=b"media data", mime_type="image/png")] + message = ChatMessage.from_user(text=text, media=media) + + assert message.role == ChatRole.USER + assert message._content == [TextContent(text="This is a multimodal message!"), MediaContent(media[0])] + + assert message.text == text + assert message.texts == [text] + assert message.media == media + assert not message.tool_calls assert not message.tool_call assert not message.tool_call_results @@ -80,6 +108,7 @@ def test_from_system_with_valid_content(): assert message.text == text assert message.texts == [text] + assert not message.media assert not message.tool_calls assert not message.tool_call assert not message.tool_call_results @@ -98,6 +127,7 @@ def test_from_tool_with_valid_content(): assert message.tool_call_result == tcr assert message.tool_call_results == [tcr] + assert not message.media assert not message.tool_calls assert not message.tool_call assert not message.texts @@ -131,19 +161,36 @@ def test_serde(): role = ChatRole.ASSISTANT text_content = TextContent(text="Hello") + media_content = MediaContent(media=ByteStream(data=b"media_data", mime_type="image/png")) tool_call = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) tool_call_result = ToolCallResult(result="result", origin=tool_call, error=False) meta = {"some": "info"} - message = ChatMessage(_role=role, _content=[text_content, tool_call, tool_call_result], _meta=meta) + message = ChatMessage( + _role=role, + _content=[text_content, media_content, tool_call, tool_call_result], + _name="my_message", + _meta=meta, + ) serialized_message = message.to_dict() - assert serialized_message == {"_content": - [{"text": "Hello"}, - {"tool_call": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}}, - {"tool_call_result": {"result": "result", "error":False, - "origin": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}}}], - "_role": "assistant", "_meta": {"some": "info"}} + assert serialized_message == { + "_content": [ + {"text": "Hello"}, + {"media": {"data": b64encode(b"media_data").decode(), "meta": {}, "mime_type": "image/png"}}, + {"tool_call": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}}, + { + "tool_call_result": { + "result": "result", + "error": False, + "origin": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}, + } + }, + ], + "_role": "assistant", + "_name": "my_message", + "_meta": {"some": "info"}, + } deserialized_message = ChatMessage.from_dict(serialized_message) assert deserialized_message == message