-
Notifications
You must be signed in to change notification settings - Fork 20
/
serializer.py
182 lines (142 loc) · 5.56 KB
/
serializer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright 2024 mbodi ai
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from pydantic import ConfigDict, model_serializer, model_validator
from mbodied.types.message import Message
from mbodied.types.sample import Sample
from mbodied.types.sense.vision import Image
class Serializer(Sample):
    """Serialize messages and samples into chat-API content structures.

    Wraps a Message, a Sample, a string, an Image, or a list of these, and
    converts it into the dictionary/list layout used by e.g. OpenAI,
    Anthropic, or other APIs:

    * a Message becomes ``{"role": ..., "content": [...]}``;
    * an image becomes ``{"type": "image", "image_url": "data:image/<enc>;base64,..."}``;
    * anything else becomes ``{"type": "text", "text": ...}``;
    * a list is serialized item by item into a list.

    Attributes:
        wrapped: The message, sample, list, string, or image to be serialized.
        model_config: The Pydantic configuration for the Serializer model
            (arbitrary types are allowed so non-Pydantic values can be wrapped).
    """

    wrapped: Any | None = None
    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

    def __init__(
        self,
        wrapped: Message | Sample | list[Any] | str | None = None,
        *,
        message: Message | None = None,
        sample: Sample | None = None,
        **data,
    ):
        """Initialize the Serializer with the value to wrap.

        Exactly one source is used, with precedence ``wrapped`` > ``message``
        > ``sample``; the others are ignored if more than one is given.

        Args:
            wrapped: A Message, Sample, string, Image, list of these, or None.
            message: An optional Message to be wrapped (keyword-only).
            sample: An optional Sample to be wrapped (keyword-only).
            **data: Additional data forwarded to the Sample base class.
        """
        if wrapped is not None:
            data["wrapped"] = wrapped
        elif message is not None:
            data["wrapped"] = message
        elif sample is not None:
            data["wrapped"] = sample
        super().__init__(**data)

    @model_validator(mode="before")
    @classmethod
    def validate_model(cls, values: Any) -> Any:
        """Validate the type of the 'wrapped' field before model construction.

        Args:
            values: The raw input passed to the model. In ``mode="before"``
                Pydantic hands over the input untouched, so it is not
                guaranteed to be a dict.

        Returns:
            The input, unchanged.

        Raises:
            ValueError: If 'wrapped' is present, not None, and not a Message,
                Sample, list, str, or Image.
        """
        # Non-dict input carries no "wrapped" key; let Pydantic's own model
        # validation report it instead of crashing on the membership test.
        if not isinstance(values, dict):
            return values
        wrapped = values.get("wrapped")
        if wrapped is not None and not isinstance(
            wrapped,
            Message | Sample | list | str | Image,
        ):
            raise ValueError(
                f"Invalid wrapped type {type(wrapped)}",
            )
        return values

    def serialize_sample(self, sample: Any) -> dict[str, Any]:
        """Serialize a single item into an API content dictionary.

        Messages are delegated to ``serialize_msg``. Images (or values that
        ``Image.supports`` accepts) become image entries. Everything else
        becomes a text entry.

        Args:
            sample: The item to serialize (Message, Sample, Image, or a raw value).

        Returns:
            A dictionary representing the serialized item.
        """
        if isinstance(sample, Message):
            return self.serialize_msg(sample)
        if not isinstance(sample, Sample):
            # Wrap raw values in a Sample before the image/text checks below.
            sample = Sample(sample)
        if isinstance(sample, Image):
            return self.serialize_image(sample)
        if Image.supports(sample):
            return self.serialize_image(Image(sample))
        # Prefer the raw string payload over the Sample's string form.
        if hasattr(sample, "datum") and isinstance(sample.datum, str):
            return self.serialize_text(sample.datum)
        return self.serialize_text(str(sample))

    @model_serializer(when_used="always")
    def serialize(self) -> dict[str, Any] | list[Any]:
        """Serialize the wrapped content of this Serializer.

        Returns:
            A list of serialized items when ``wrapped`` is a list, otherwise a
            single serialized dictionary.
        """
        if isinstance(self.wrapped, list):
            # serialize_sample dispatches Message items to serialize_msg, so
            # lists of messages and mixed lists share one path.
            return [self.serialize_sample(item) for item in self.wrapped]
        return self.serialize_sample(self.wrapped)

    def serialize_msg(self, message: Message) -> dict[str, Any]:
        """Serialize a Message into a ``{"role", "content"}`` dictionary.

        Args:
            message: The Message to be serialized.

        Returns:
            A dictionary with the message role and a list of serialized
            content items.
        """
        content = message.content
        if content is None:
            content = []
        elif not isinstance(content, list | tuple):
            # A bare value is a single content item; iterating it directly
            # would, for a string, emit one text entry per character.
            content = [content]
        return {
            "role": message.role,
            "content": [self.serialize_sample(item) for item in content],
        }

    @classmethod
    def serialize_image(cls, image: Image) -> dict[str, Any]:
        """Serialize an Image as a base64 data URL entry.

        Args:
            image: The Image to be serialized.

        Returns:
            ``{"type": "image", "image_url": "data:image/<encoding>;base64,<data>"}``.
        """
        return {
            "type": "image",
            "image_url": f"data:image/{image.encoding};base64," + image.base64,
        }

    @classmethod
    def serialize_text(cls, text: str) -> dict[str, Any]:
        """Serialize a text string.

        Args:
            text: The text to be serialized.

        Returns:
            ``{"type": "text", "text": text}``.
        """
        return {"type": "text", "text": text}

    def __call__(self) -> dict[str, Any] | list[Any]:
        """Serialize via ``model_dump`` (which routes through ``serialize``).

        Returns:
            A dictionary, or a list of dictionaries when ``wrapped`` is a list.
        """
        return self.model_dump()