Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Implement Lazy attribute loading for the image data #86

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 80 additions & 113 deletions mbodied/types/sense/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
```python
image = Image("path/to/image.png", size=new_size_tuple).save("path/to/new/image.jpg")
image.save("path/to/new/image.jpg", quality=5)

TODO: Implement Lazy attribute loading for the image data.
"""

import base64 as base64lib
Expand All @@ -52,6 +50,7 @@
InstanceOf,
model_serializer,
model_validator,
PrivateAttr
)
from typing_extensions import Literal

Expand Down Expand Up @@ -90,18 +89,12 @@ class Image(Sample):

model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True, extras="forbid", validate_assignment=False)

array: NumpyArray
size: tuple[int, int]

pil: InstanceOf[PILImage] | None = Field(
None,
repr=False,
exclude=True,
description="The image represented as a PIL Image object.",
)
_array: NumpyArray | None = PrivateAttr(default=None)
_base64: InstanceOf[Base64Str] | None = PrivateAttr(default=None)
_pil: InstanceOf[PILImage] | None = PrivateAttr(default=None)
_url: InstanceOf[AnyUrl] | str | None = PrivateAttr(default=None)
_size: tuple[int, int] | None = PrivateAttr(default=None)
encoding: Literal["png", "jpeg", "jpg", "bmp", "gif"]
base64: InstanceOf[Base64Str] | None = None
url: InstanceOf[AnyUrl] | str | None = None
path: FilePath | None = None

@classmethod
Expand Down Expand Up @@ -176,6 +169,70 @@ def __init__(
kwargs["bytes"] = bytes_obj
super().__init__(**kwargs)

self._array = kwargs.get("array", None)
self._base64 = kwargs.get("base64", None)
self._pil = kwargs.get("pil", None)
self._url = kwargs.get("url", None)
self._size = kwargs.get("size", None)

if self._size is not None and self.pil is not None:
self._pil = self._pil.resize(self._size)
self._array = np.array(self._pil)

@property
def array(self) -> np.ndarray | None:
"""Lazily computes and returns the NumPy array."""
if self._array is None and self.pil is not None:
# Convert the PIL image to a NumPy array
self._array = np.array(self._pil)
return self._array

@property
def base64(self) -> Base64Str | None:
"""Lazily computes and returns the base64-encoded string."""
if self._base64 is None and self.pil is not None:
buffer = io.BytesIO()
# Save the PIL image to a buffer in the specified encoding
self._pil.convert("RGB").save(buffer, format=self.encoding.upper())
self._base64 = base64lib.b64encode(buffer.getvalue()).decode("utf-8")
return self._base64

@property
def pil(self) -> PILImage | None:
"""Lazily loads and returns the PIL image."""
if self._pil is None:
if self._array is not None:
self._pil = PILModule.fromarray(self._array).convert("RGB")
elif self._base64 is not None:
image_data = base64lib.b64decode(self._base64)
self._pil = PILModule.open(io.BytesIO(image_data)).convert("RGB")
elif self.path is not None:
self._pil = PILModule.open(self.path).convert("RGB")
elif self._url is not None:
self._pil = Image.load_url(self._url)
return self._pil

@property
def url(self) -> AnyUrl | str | None:
"""Lazily computes and returns the data URL."""
if self._url is None and self._base64 is not None:
self._url = f"data:image/{self.encoding};base64,{self._base64}"
elif self._url is None and self.pil is not None:
# First convert the PIL image to a base64 string
buffer = io.BytesIO()
self._pil.convert("RGB").save(buffer, format=self.encoding.upper())
self._base64 = base64lib.b64encode(buffer.getvalue()).decode("utf-8")
# Construct the data URL
self._url = f"data:image/{self.encoding};base64,{self._base64}"
return self._url

@property
def size(self) -> tuple[int, int] | None:
"Lazily computes and returns the image size"
if self._size is None and self.pil is not None:
self._size = self._pil.size
return self._size

def __repr__(self):
"""Return a string representation of the image."""
if self.base64 is None:
Expand Down Expand Up @@ -217,37 +274,6 @@ def open(path: str, encoding: str = "jpeg", size=None) -> "Image":
image = PILModule.open(path).convert("RGB")
return Image(image, encoding, size)

@staticmethod
def pil_to_data(image: PILImage, encoding: str, size=None) -> dict:
"""Creates an Image instance from a PIL image.

Args:
image (PIL.Image.Image): The source PIL image from which to create the Image instance.
encoding (str): The format used for encoding the image when converting to base64.
size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.

Returns:
Image: An instance of the Image class with populated fields.
"""
if encoding.lower() == "jpg":
encoding = "jpeg"
buffer = io.BytesIO()
image.convert("RGB").save(buffer, format=encoding.upper())
base64_encoded = base64lib.b64encode(buffer.getvalue()).decode("utf-8")
data_url = f"data:image/{encoding};base64,{base64_encoded}"
if size is not None:
image = image.resize(size)
else:
size = image.size
return {
"array": np.array(image),
"base64": base64_encoded,
"pil": image,
"size": size,
"url": data_url,
"encoding": encoding.lower(),
}

@staticmethod
def load_url(url: str, download=False) -> PILImage | None:
"""Downloads an image from a URL or decodes it from a base64 data URI.
Expand Down Expand Up @@ -302,21 +328,6 @@ def from_bytes(cls, bytes_data: bytes, encoding: str = "jpeg", size=None) -> "Im
image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB")
return cls(image, encoding, size)

@staticmethod
def bytes_to_data(bytes_data: bytes, encoding: str = "jpeg", size=None) -> dict:
"""Creates an Image instance from a bytes object.

Args:
bytes_data (bytes): The bytes object to convert to an image.
encoding (str): The format used for encoding the image when converting to base64.
size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.

Returns:
Image: An instance of the Image class with populated fields.
"""
image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB")
return Image.pil_to_data(image, encoding, size)

@model_validator(mode="before")
@classmethod
def validate_kwargs(cls, values) -> dict:
Expand All @@ -327,74 +338,30 @@ def validate_kwargs(cls, values) -> dict:
if len(provided_fields) > 1:
raise ValueError(f"Multiple image sources provided; only one is allowed but got: {provided_fields}")

# Initialize all fields to None or their default values
# Initialize all fields to their input values or None
validated_values = {
"array": None,
"base64": None,
"array": values.get("array", None),
"base64": values.get("base64", None),
"encoding": values.get("encoding", "jpeg").lower(),
"path": None,
"pil": None,
"url": None,
"path": values.get("path", None),
"pil": values.get("pil", None),
"url": values.get("url", None),
"size": values.get("size", None),
}

# Validate the encoding first
if validated_values["encoding"] == "jpg":
validated_values["encoding"] = "jpeg"

if validated_values["encoding"] not in ["png", "jpeg", "jpg", "bmp", "gif"]:
raise ValueError("The 'encoding' must be a valid image format (png, jpeg, jpg, bmp, gif).")

if "bytes" in values and values["bytes"] is not None:
validated_values.update(cls.bytes_to_data(values["bytes"], values["encoding"], values["size"]))
return validated_values

if "pil" in values and values["pil"] is not None:
validated_values.update(
cls.pil_to_data(values["pil"], values["encoding"], values["size"]),
)
return validated_values
# Process the provided image source
if "path" in provided_fields:
image = PILModule.open(values["path"]).convert("RGB")
validated_values["path"] = values["path"]
validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"]))

elif "array" in provided_fields:
image = PILModule.fromarray(values["array"]).convert("RGB")
validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"]))

elif "pil" in provided_fields:
validated_values.update(
cls.pil_to_data(values["pil"], validated_values["encoding"], validated_values["size"]),
)

elif "base64" in provided_fields:
validated_values.update(
cls.from_base64(values["base64"], validated_values["encoding"], validated_values["size"]),
)

elif "url" in provided_fields:
if "url" in provided_fields:
url_path = urlparse(values["url"]).path
file_extension = (
Path(url_path).suffix[1:].lower() if Path(url_path).suffix else validated_values["encoding"]
)
validated_values["encoding"] = file_extension
validated_values["url"] = values["url"]
image = cls.load_url(values["url"])
if image is None:
validated_values["array"] = np.zeros((224, 224, 3), dtype=np.uint8)
validated_values["size"] = (224, 224)
return validated_values

validated_values.update(cls.pil_to_data(image, file_extension, validated_values["size"]))
validated_values["url"] = values["url"]

elif "size" in values and values["size"] is not None:
array = np.zeros((values["size"][0], values["size"][1], 3), dtype=np.uint8)
image = PILModule.fromarray(array).convert("RGB")
validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"]))
if any(validated_values[k] is None for k in ["array", "base64", "pil", "url"]):
logging.warning(
f"Failed to validate image data. Could only fetch {[k for k in validated_values if validated_values[k] is not None]}",
)
return validated_values

def save(self, path: str, encoding: str | None = None, quality: int = 10) -> None:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_senses.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,30 @@ def test_image_model_dump_load_with_base64():
reconstructed_img = Image.model_validate_json(json)
assert np.array_equal(reconstructed_img.array, array)


def test_lazy_loading():
# Create an image with a path
image_path = "resources/bridge_example.jpeg"
img = Image(image_path)

# Test that attributes are lazily loaded
assert img._array is None
assert img._base64 is None
assert img._size is None
assert img._url is None

# Access size, which should trigger lazy loading
assert img.size is not None
assert img._size is not None

# Access array and base64 to ensure they are also lazily loaded
assert img.array is not None
assert img._array is not None
assert img.base64 is not None
assert img._base64 is not None
assert img.url is not None
assert img._url is not None


if __name__ == "__main__":
pytest.main([__file__, "-vv"])
Loading