Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for multimodal on ChatMessages with ContentPart #7913

Closed
wants to merge 8 commits into from
112 changes: 101 additions & 11 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Dict, List, Optional, Set

from jinja2 import Template, meta

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.dataclasses.content_part import ContentPart, ContentType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -127,10 +130,48 @@ def __init__(
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infere variables from template
msg_template = Template(message.content)
ast = msg_template.environment.parse(message.content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
if isinstance(message.content, str):
msg_template = Template(message.content)
ast = msg_template.environment.parse(message.content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
elif isinstance(message.content, ContentPart):
if message.content.type is ContentType.IMAGE_BASE64:
content = message.content.content.to_string()
else:
content = message.content.content
msg_template = Template(content)
ast = msg_template.environment.parse(content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
elif isinstance(message.content, list):
for part in message.content:
if isinstance(part, str):
part_template = Template(part)
ast = part_template.environment.parse(part)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
elif isinstance(part, ContentPart):
if part.type is ContentType.IMAGE_BASE64:
content = part.content.to_string()
else:
content = part.content
part_template = Template(content)
ast = part_template.environment.parse(content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
else:
raise ValueError(
"One of the elements of the content of one of the ChatMessages \
is not of a valid type."
"Valid types: str or ContentPart. Element: {part}"
)
else:
raise ValueError(
"The content of one of the messages in the template is not of a valid type."
"Valid types: str, ContentPart or list of str and ContentPart."
"Content: {self.content}"
)

# setup inputs
static_input_slots = {"template": Optional[str], "template_variables": Optional[Dict[str, Any]]}
Expand Down Expand Up @@ -194,13 +235,62 @@ def run(
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
self._validate_variables(set(template_variables_combined.keys()))
compiled_template = Template(message.content)
rendered_content = compiled_template.render(template_variables_combined)
rendered_message = (
ChatMessage.from_user(rendered_content)
if message.is_from(ChatRole.USER)
else ChatMessage.from_system(rendered_content)
)
if isinstance(message.content, str):
compiled_template = Template(message.content)
rendered_content = compiled_template.render(template_variables_combined)
rendered_message = (
ChatMessage.from_user(rendered_content)
if message.is_from(ChatRole.USER)
else ChatMessage.from_system(rendered_content)
)

elif isinstance(message.content, ContentPart):
content = message.content.content
if isinstance(content, str):
compiled_template = Template(content)
rendered_content = compiled_template.render(template_variables_combined)
rendered_message = deepcopy(message)
rendered_message.content.content = rendered_content
else: # ByteStream
compiled_template = Template(content.to_string())
rendered_content = ByteStream.from_string(compiled_template.render(template_variables_combined))
rendered_message = deepcopy(message)
rendered_message.content.content = rendered_content

elif isinstance(message.content, list):
rendered_parts = []
for part in message.content:
if isinstance(part, str):
compiled_template = Template(part)
rendered_part = compiled_template.render(template_variables_combined)
elif isinstance(part, ContentPart):
rendered_part = deepcopy(part)
if isinstance(part.content, str):
compiled_template = Template(part.content)
rendered_part.content = compiled_template.render(template_variables_combined)
else: # ByteStream
compiled_template = Template(part.content.to_string())
rendered_part.content = ByteStream.from_string(
compiled_template.render(template_variables_combined)
)
else:
raise ValueError(
"One of the elements of the content of one of the ChatMessages \
is not of a valid type."
"Valid types: str or ContentPart. Element: {part}"
)
rendered_parts.append(rendered_part)

rendered_message = deepcopy(message)
rendered_message.content = rendered_parts

else:
raise ValueError(
"The content of one of the messages in the template is not of a valid type."
"Valid types: str, ContentPart or list of str and ContentPart."
"Content: {self.content}"
)

processed_messages.append(rendered_message)
else:
processed_messages.append(message)
Expand Down
4 changes: 4 additions & 0 deletions haystack/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.dataclasses.content_part import ContentPart, ContentType, ImageDetail
from haystack.dataclasses.document import Document
from haystack.dataclasses.sparse_embedding import SparseEmbedding
from haystack.dataclasses.streaming_chunk import StreamingChunk
Expand All @@ -19,4 +20,7 @@
"ChatRole",
"StreamingChunk",
"SparseEmbedding",
"ContentPart",
"ContentType",
"ImageDetail",
]
2 changes: 2 additions & 0 deletions haystack/dataclasses/byte_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ def to_string(self, encoding: str = "utf-8") -> str:
:returns: The string representation of the ByteStream.
:raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding.
"""
if isinstance(self.data, str):
return self.data
return self.data.decode(encoding)
79 changes: 71 additions & 8 deletions haystack/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Union

from .byte_stream import ByteStream
from .content_part import ContentPart


class ChatRole(str, Enum):
Expand All @@ -21,13 +24,13 @@ class ChatMessage:
"""
Represents a message in a LLM chat conversation.

:param content: The text content of the message.
:param content: The content of the message.
:param role: The role of the entity sending the message.
:param name: The name of the function being called (only applicable for role FUNCTION).
:param meta: Additional metadata associated with the message.
"""

content: str
content: Union[str, ContentPart, List[Union[str, ContentPart]]]
role: ChatRole
name: Optional[str]
meta: Dict[str, Any] = field(default_factory=dict, hash=False)
Expand All @@ -43,10 +46,33 @@ def to_openai_format(self) -> Dict[str, Any]:
- `content`
- `name` (optional)
"""
msg = {"role": self.role.value, "content": self.content}
msg = {"role": self.role.value}
if self.name:
msg["name"] = self.name

if isinstance(self.content, str):
msg["content"] = self.content
elif isinstance(self.content, ContentPart):
msg["content"] = self.content.to_openai_format()
elif isinstance(self.content, list):
msg["content"] = []
for part in self.content:
if isinstance(part, str):
msg["content"].append(ContentPart.from_text(part).to_openai_format())
elif isinstance(part, ContentPart):
msg["content"].append(part.to_openai_format())
else:
raise ValueError(
"One of the elements of the content is not of a valid type."
"Valid types: str or ContentPart. Element: {part}"
)
else:
raise ValueError(
"The content of the message is not of a valid type."
"Valid types: str, ContentPart or list of str and ContentPart."
"Content: {self.content}"
)

return msg

def is_from(self, role: ChatRole) -> bool:
Expand All @@ -59,7 +85,9 @@ def is_from(self, role: ChatRole) -> bool:
return self.role == role

@classmethod
def from_assistant(cls, content: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage":
def from_assistant(
cls, content: Union[str, ContentPart, List[Union[str, ContentPart]]], meta: Optional[Dict[str, Any]] = None
) -> "ChatMessage":
"""
Create a message from the assistant.

Expand All @@ -70,7 +98,7 @@ def from_assistant(cls, content: str, meta: Optional[Dict[str, Any]] = None) ->
return cls(content, ChatRole.ASSISTANT, None, meta or {})

@classmethod
def from_user(cls, content: str) -> "ChatMessage":
def from_user(cls, content: Union[str, ContentPart, List[Union[str, ContentPart]]]) -> "ChatMessage":
"""
Create a message from the user.

Expand All @@ -80,7 +108,7 @@ def from_user(cls, content: str) -> "ChatMessage":
return cls(content, ChatRole.USER, None)

@classmethod
def from_system(cls, content: str) -> "ChatMessage":
def from_system(cls, content: Union[str, ContentPart, List[Union[str, ContentPart]]]) -> "ChatMessage":
"""
Create a message from the system.

Expand All @@ -90,7 +118,7 @@ def from_system(cls, content: str) -> "ChatMessage":
return cls(content, ChatRole.SYSTEM, None)

@classmethod
def from_function(cls, content: str, name: str) -> "ChatMessage":
def from_function(cls, content: Union[str, ContentPart, List[Union[str, ContentPart]]], name: str) -> "ChatMessage":
"""
Create a message from a function call.

Expand All @@ -110,6 +138,29 @@ def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data["role"] = self.role.value

if isinstance(self.content, str):
data["content"] = self.content
elif isinstance(self.content, ContentPart):
data["content"] = self.content.to_dict()
elif isinstance(self.content, list):
data["content"] = []
for part in self.content:
if isinstance(part, str):
data["content"].append(part)
elif isinstance(part, ContentPart):
data["content"].append(part.to_dict())
else:
raise ValueError(
"One of the elements of the content is not of a valid type."
"Valid types: str or ContentPart. Element: {part}"
)
else:
raise ValueError(
"The content of the message is not of a valid type."
"Valid types: str, ContentPart or list of str and ContentPart."
"Content: {self.content}"
)

return data

@classmethod
Expand All @@ -124,4 +175,16 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage":
"""
data["role"] = ChatRole(data["role"])

if "content" in data:
if isinstance(data["content"], dict): # Assume it is a ContentPart
data["content"] = ContentPart.from_dict(data["content"])
elif isinstance(data["content"], list):
content = data.pop("content")
data["content"] = []
for part in content:
if isinstance(part, str):
data["content"].append(part)
else:
data["content"].append(ContentPart.from_dict(part))

return cls(**data)
Loading
Loading