From 113c49e2d6c79b31a7ca0e5a8c3509f935248b61 Mon Sep 17 00:00:00 2001 From: emekaokoli19 Date: Fri, 6 Sep 2024 23:38:42 +0100 Subject: [PATCH 1/2] Feat: Implement Lazy attribute loading for the image data --- mbodied/types/sense/vision.py | 198 +++++++++++++++------------------- tests/test_senses.py | 21 ++++ 2 files changed, 106 insertions(+), 113 deletions(-) diff --git a/mbodied/types/sense/vision.py b/mbodied/types/sense/vision.py index 288d43f4..0ef21679 100644 --- a/mbodied/types/sense/vision.py +++ b/mbodied/types/sense/vision.py @@ -26,8 +26,6 @@ ```python image = Image("path/to/image.png", size=new_size_tuple).save("path/to/new/image.jpg") image.save("path/to/new/image.jpg", quality=5) - -TODO: Implement Lazy attribute loading for the image data. """ import base64 as base64lib @@ -52,6 +50,7 @@ InstanceOf, model_serializer, model_validator, + PrivateAttr ) from typing_extensions import Literal @@ -90,18 +89,12 @@ class Image(Sample): model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True, extras="forbid", validate_assignment=False) - array: NumpyArray - size: tuple[int, int] - - pil: InstanceOf[PILImage] | None = Field( - None, - repr=False, - exclude=True, - description="The image represented as a PIL Image object.", - ) + _array: NumpyArray | None = PrivateAttr(default=None) + _base64: InstanceOf[Base64Str] | None = PrivateAttr(default=None) + _pil: InstanceOf[PILImage] | None = PrivateAttr(default=None) + _url: InstanceOf[AnyUrl] | str | None = PrivateAttr(default=None) + _size: tuple[int, int] | None = PrivateAttr(default=None) encoding: Literal["png", "jpeg", "jpg", "bmp", "gif"] - base64: InstanceOf[Base64Str] | None = None - url: InstanceOf[AnyUrl] | str | None = None path: FilePath | None = None @classmethod @@ -176,6 +169,75 @@ def __init__( kwargs["bytes"] = bytes_obj super().__init__(**kwargs) + self._array = kwargs.get("array", None) + self._base64 = kwargs.get("base64", None) + self._pil = kwargs.get("pil", None) + self._url = kwargs.get("url", None) + + @property + def array(self) -> np.ndarray | None: + """Lazily computes and returns the NumPy array.""" + if self._array is None: + if self._pil is not None: + # Convert the PIL image to a NumPy array + self._array = np.array(self._pil) + else: + self._array = np.array(self.pil) + return self._array + + @property + def base64(self) -> Base64Str | None: + """Lazily computes and returns the base64-encoded string.""" + if self._base64 is None: + buffer = io.BytesIO() + # Save the PIL image to a buffer in the specified encoding + if self._pil is not None: + self._pil.convert("RGB").save(buffer, format=self.encoding.upper()) + else: + self.pil.convert("RGB").save(buffer, format=self.encoding.upper()) + self._base64 = base64lib.b64encode(buffer.getvalue()).decode("utf-8") + return self._base64 + + @property + def pil(self) -> PILImage | None: + """Lazily loads and returns the PIL image.""" + if self._pil is None: + if self._array is not None: + self._pil = PILModule.fromarray(self._array).convert("RGB") + elif self._base64 is not None: + image_data = base64lib.b64decode(self._base64) + self._pil = PILModule.open(io.BytesIO(image_data)).convert("RGB") + elif self.path is not None: + self._pil = PILModule.open(self.path).convert("RGB") + elif self._url is not None: + self._pil = Image.load_url(self._url) + return self._pil + + @property + def url(self) -> AnyUrl | str | None: + """Lazily computes and returns the data URL.""" + if self._url is None: + # First convert the PIL image to a base64 string + buffer = io.BytesIO() + if self._pil is not None: + self._pil.convert("RGB").save(buffer, format=self.encoding.upper()) + else: + self.pil.convert("RGB").save(buffer, format=self.encoding.upper()) + self._base64 = base64lib.b64encode(buffer.getvalue()).decode("utf-8") + # Construct the data URL + self._url = f"data:image/{self.encoding};base64,{self._base64}" + return self._url + + @property + def size(self) -> tuple[int, int] | None: + "Lazily computes and returns the image size" + if self._size is None: + if self._pil is not None: + self._size = self._pil.size + else: + self._size = self.pil.size + return self._size + def __repr__(self): """Return a string representation of the image.""" if self.base64 is None: @@ -217,37 +279,6 @@ def open(path: str, encoding: str = "jpeg", size=None) -> "Image": image = PILModule.open(path).convert("RGB") return Image(image, encoding, size) - @staticmethod - def pil_to_data(image: PILImage, encoding: str, size=None) -> dict: - """Creates an Image instance from a PIL image. - - Args: - image (PIL.Image.Image): The source PIL image from which to create the Image instance. - encoding (str): The format used for encoding the image when converting to base64. - size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple. - - Returns: - Image: An instance of the Image class with populated fields. - """ - if encoding.lower() == "jpg": - encoding = "jpeg" - buffer = io.BytesIO() - image.convert("RGB").save(buffer, format=encoding.upper()) - base64_encoded = base64lib.b64encode(buffer.getvalue()).decode("utf-8") - data_url = f"data:image/{encoding};base64,{base64_encoded}" - if size is not None: - image = image.resize(size) - else: - size = image.size - return { - "array": np.array(image), - "base64": base64_encoded, - "pil": image, - "size": size, - "url": data_url, - "encoding": encoding.lower(), - } - @staticmethod def load_url(url: str, download=False) -> PILImage | None: """Downloads an image from a URL or decodes it from a base64 data URI. @@ -302,21 +333,6 @@ def from_bytes(cls, bytes_data: bytes, encoding: str = "jpeg", size=None) -> "Im image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB") return cls(image, encoding, size) - @staticmethod - def bytes_to_data(bytes_data: bytes, encoding: str = "jpeg", size=None) -> dict: - """Creates an Image instance from a bytes object. - - Args: - bytes_data (bytes): The bytes object to convert to an image. - encoding (str): The format used for encoding the image when converting to base64. - size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple. - - Returns: - Image: An instance of the Image class with populated fields. - """ - image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB") - return Image.pil_to_data(image, encoding, size) - @model_validator(mode="before") @classmethod def validate_kwargs(cls, values) -> dict: @@ -327,74 +343,30 @@ def validate_kwargs(cls, values) -> dict: if len(provided_fields) > 1: raise ValueError(f"Multiple image sources provided; only one is allowed but got: {provided_fields}") - # Initialize all fields to None or their default values + # Initialize all fields to their input values or None validated_values = { - "array": None, - "base64": None, + "array": values.get("array", None), + "base64": values.get("base64", None), "encoding": values.get("encoding", "jpeg").lower(), - "path": None, - "pil": None, - "url": None, + "path": values.get("path", None), + "pil": values.get("pil", None), + "url": values.get("url", None), "size": values.get("size", None), } # Validate the encoding first + if validated_values["encoding"] == "jpg": + validated_values["encoding"] = "jpeg" + if validated_values["encoding"] not in ["png", "jpeg", "jpg", "bmp", "gif"]: raise ValueError("The 'encoding' must be a valid image format (png, jpeg, jpg, bmp, gif).") - if "bytes" in values and values["bytes"] is not None: - validated_values.update(cls.bytes_to_data(values["bytes"], values["encoding"], values["size"])) - return validated_values - - if "pil" in values and values["pil"] is not None: - validated_values.update( - cls.pil_to_data(values["pil"], values["encoding"], values["size"]), - ) - return validated_values - # Process the provided image source - if "path" in provided_fields: - image = PILModule.open(values["path"]).convert("RGB") - validated_values["path"] = values["path"] - validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"])) - - elif "array" in provided_fields: - image = PILModule.fromarray(values["array"]).convert("RGB") - validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"])) - - elif "pil" in provided_fields: - validated_values.update( - cls.pil_to_data(values["pil"], validated_values["encoding"], validated_values["size"]), - ) - - elif "base64" in provided_fields: - validated_values.update( - cls.from_base64(values["base64"], validated_values["encoding"], validated_values["size"]), - ) - - elif "url" in provided_fields: + if "url" in provided_fields: url_path = urlparse(values["url"]).path file_extension = ( Path(url_path).suffix[1:].lower() if Path(url_path).suffix else validated_values["encoding"] ) validated_values["encoding"] = file_extension - validated_values["url"] = values["url"] - image = cls.load_url(values["url"]) - if image is None: - validated_values["array"] = np.zeros((224, 224, 3), dtype=np.uint8) - validated_values["size"] = (224, 224) - return validated_values - - validated_values.update(cls.pil_to_data(image, file_extension, validated_values["size"])) - validated_values["url"] = values["url"] - - elif "size" in values and values["size"] is not None: - array = np.zeros((values["size"][0], values["size"][1], 3), dtype=np.uint8) - image = PILModule.fromarray(array).convert("RGB") - validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"])) - if any(validated_values[k] is None for k in ["array", "base64", "pil", "url"]): - logging.warning( - f"Failed to validate image data. Could only fetch {[k for k in validated_values if validated_values[k] is not None]}", - ) return validated_values def save(self, path: str, encoding: str | None = None, quality: int = 10) -> None: diff --git a/tests/test_senses.py b/tests/test_senses.py index 41248a6c..c299262f 100644 --- a/tests/test_senses.py +++ b/tests/test_senses.py @@ -139,6 +139,27 @@ def test_image_model_dump_load_with_base64(): reconstructed_img = Image.model_validate_json(json) assert np.array_equal(reconstructed_img.array, array) + +def test_lazy_loading(): + # Create an image with a path + image_path = "resources/bridge_example.jpeg" + img = Image(image_path) + + # Test that attributes are lazily loaded + assert img._array is None + assert img._base64 is None + assert img._size is None + + # Access size, which should trigger lazy loading + assert img.size is not None + assert img._size is not None + + # Access array and base64 to ensure they are also lazily loaded + assert img.array is not None + assert img._array is not None + assert img.base64 is not None + assert img._base64 is not None + if __name__ == "__main__": pytest.main([__file__, "-vv"]) From 2f71dcf7a3beb18ec952b518d2f139d8eae50f13 Mon Sep 17 00:00:00 2001 From: emekaokoli19 Date: Sat, 7 Sep 2024 15:25:06 +0100 Subject: [PATCH 2/2] fix tests --- mbodied/types/sense/vision.py | 37 +++++++++++++++-------------------- tests/test_senses.py | 3 +++ 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/mbodied/types/sense/vision.py b/mbodied/types/sense/vision.py index 0ef21679..c371a14d 100644 --- a/mbodied/types/sense/vision.py +++ b/mbodied/types/sense/vision.py @@ -173,28 +173,27 @@ def __init__( self._base64 = kwargs.get("base64", None) self._pil = kwargs.get("pil", None) self._url = kwargs.get("url", None) + self._size = kwargs.get("size", None) + + if self._size is not None and self.pil is not None: + self._pil = self._pil.resize(self._size) + self._array = np.array(self._pil) @property def array(self) -> np.ndarray | None: """Lazily computes and returns the NumPy array.""" - if self._array is None: - if self._pil is not None: - # Convert the PIL image to a NumPy array - self._array = np.array(self._pil) - else: - self._array = np.array(self.pil) + if self._array is None and self.pil is not None: + # Convert the PIL image to a NumPy array + self._array = np.array(self._pil) return self._array @property def base64(self) -> Base64Str | None: """Lazily computes and returns the base64-encoded string.""" - if self._base64 is None: + if self._base64 is None and self.pil is not None: buffer = io.BytesIO() # Save the PIL image to a buffer in the specified encoding - if self._pil is not None: - self._pil.convert("RGB").save(buffer, format=self.encoding.upper()) - else: - self.pil.convert("RGB").save(buffer, format=self.encoding.upper()) + self._pil.convert("RGB").save(buffer, format=self.encoding.upper()) self._base64 = base64lib.b64encode(buffer.getvalue()).decode("utf-8") return self._base64 @@ -216,13 +215,12 @@ def pil(self) -> PILImage | None: @property def url(self) -> AnyUrl | str | None: """Lazily computes and returns the data URL.""" - if self._url is None: + if self._url is None and self._base64 is not None: + self._url = f"data:image/{self.encoding};base64,{self._base64}" + elif self._url is None and self.pil is not None: # First convert the PIL image to a base64 string buffer = io.BytesIO() - if self._pil is not None: - self._pil.convert("RGB").save(buffer, format=self.encoding.upper()) - else: - self.pil.convert("RGB").save(buffer, format=self.encoding.upper()) + self._pil.convert("RGB").save(buffer, format=self.encoding.upper()) self._base64 = base64lib.b64encode(buffer.getvalue()).decode("utf-8") # Construct the data URL self._url = f"data:image/{self.encoding};base64,{self._base64}" @@ -231,11 +229,8 @@ def url(self) -> AnyUrl | str | None: @property def size(self) -> tuple[int, int] | None: "Lazily computes and returns the image size" - if self._size is None: - if self._pil is not None: - self._size = self._pil.size - else: - self._size = self.pil.size + if self._size is None and self.pil is not None: + self._size = self._pil.size return self._size def __repr__(self): diff --git a/tests/test_senses.py b/tests/test_senses.py index c299262f..50e5e91a 100644 --- a/tests/test_senses.py +++ b/tests/test_senses.py @@ -149,6 +149,7 @@ def test_lazy_loading(): assert img._array is None assert img._base64 is None assert img._size is None + assert img._url is None # Access size, which should trigger lazy loading assert img.size is not None @@ -159,6 +160,8 @@ def test_lazy_loading(): assert img._array is not None assert img.base64 is not None assert img._base64 is not None + assert img.url is not None + assert img._url is not None if __name__ == "__main__":