diff --git a/src/marqo/api/configs.py b/src/marqo/api/configs.py index e8bd7b23c..48d9a9204 100644 --- a/src/marqo/api/configs.py +++ b/src/marqo/api/configs.py @@ -31,6 +31,7 @@ def default_env_vars() -> dict: EnvVars.MARQO_LOG_LEVEL: "info", EnvVars.MARQO_MEDIA_DOWNLOAD_THREAD_COUNT_PER_REQUEST: 5, EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST: 20, + EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE: 370 * 1024 * 1024, # 370 megabytes in bytes # This env variable is set to "info" by default in run_marqo.sh, which overrides this value EnvVars.MARQO_MAX_CPU_MODEL_MEMORY: 4, EnvVars.MARQO_MAX_CUDA_MODEL_MEMORY: 4, # For multi-GPU, this is the max memory for each GPU. diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py index 46d9aa7c1..eeb23a44f 100644 --- a/src/marqo/s2_inference/clip_utils.py +++ b/src/marqo/s2_inference/clip_utils.py @@ -21,6 +21,7 @@ from marqo import marqo_docs from marqo.api.exceptions import InternalError +from marqo.tensor_search.enums import EnvVars from marqo.core.inference.models.abstract_clip_model import AbstractCLIPModel from marqo.core.inference.models.open_clip_model_properties import OpenCLIPModelProperties, ImagePreprocessor from marqo.s2_inference.configs import ModelCache @@ -150,7 +151,7 @@ def validate_url(url: str) -> bool: -def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO: +def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000, modality: Optional[str] = None) -> BytesIO: """Download an image from a URL and return a PIL image using pycurl. Args: @@ -171,7 +172,7 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo try: encoded_url = encode_url(image_path) except UnicodeEncodeError as e: - raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. " + raise ImageDownloadError(f"Marqo encountered an error when downloading the media url {image_path}. " f"The url could not be encoded properly. Original error: {e}") buffer = BytesIO() c = pycurl.Curl() @@ -185,15 +186,30 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo headers.update(image_download_headers) c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()]) + # callback to check file size for video and audio + if modality in ["video", "audio"]: + max_size = EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE + def progress(download_total, downloaded, upload_total, uploaded): + if downloaded > max_size: + return 1 # Abort the download + c.setopt(pycurl.NOPROGRESS, False) + c.setopt(pycurl.XFERINFOFUNCTION, progress) + try: c.perform() if c.getinfo(pycurl.RESPONSE_CODE) != 200: - raise ImageDownloadError(f"image url `{image_path}` returned {c.getinfo(pycurl.RESPONSE_CODE)}") + raise ImageDownloadError(f"media url `{image_path}` returned {c.getinfo(pycurl.RESPONSE_CODE)}") except pycurl.error as e: - raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. " - f"The original error is: {e}") + error_message = str(e) + if len(e.args) > 0: + error_code = e.args[0] + if error_code == pycurl.E_ABORTED_BY_CALLBACK: + error_message = f"Media file `{image_path}` exceeds the maximum allowed size for {modality}." + raise ImageDownloadError(f"Marqo encountered an error when downloading the media url {image_path}. " + f"The original error is: {error_message}") finally: c.close() + buffer.seek(0) return buffer diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py index 173630c22..ae44e99e7 100644 --- a/src/marqo/s2_inference/multimodal_model_load.py +++ b/src/marqo/s2_inference/multimodal_model_load.py @@ -278,7 +278,7 @@ def encode(self, content, modality, normalize=True, **kwargs): if isinstance(content, str) and "http" in content: suffix = ".mp4" if modality == Modality.VIDEO else ".wav" with self._temp_file(suffix) as temp_filename: - self._download_content(content, temp_filename) + self._download_content(content, temp_filename, modality) preprocessed_content = self.preprocessor(modality)([temp_filename], return_tensors='pt') inputs[modality.value] = to_device(preprocessed_content, self.model.device)['pixel_values'] @@ -300,11 +300,11 @@ def encode(self, content, modality, normalize=True, **kwargs): return embeddings.cpu().numpy() - def _download_content(self, url, filename): + def _download_content(self, url, filename, modality): # 3 seconds for images, 20 seconds for audio and video timeout_ms = 3000 if filename.endswith(('.png', '.jpg', '.jpeg')) else 20000 - buffer = download_image_from_url(url, {}, timeout_ms) + buffer = download_image_from_url(url, {}, timeout_ms, modality) with open(filename, 'wb') as f: - f.write(buffer.getvalue()) + f.write(buffer.getvalue()) \ No newline at end of file diff --git a/src/marqo/tensor_search/enums.py b/src/marqo/tensor_search/enums.py index feec79d17..6be7e69ca 100644 --- a/src/marqo/tensor_search/enums.py +++ b/src/marqo/tensor_search/enums.py @@ -64,6 +64,7 @@ class EnvVars: MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST = "MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST" MARQO_ROOT_PATH = "MARQO_ROOT_PATH" MARQO_MAX_CPU_MODEL_MEMORY = "MARQO_MAX_CPU_MODEL_MEMORY" + MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE = "MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE" MARQO_MAX_CUDA_MODEL_MEMORY = "MARQO_MAX_CUDA_MODEL_MEMORY" MARQO_EF_CONSTRUCTION_MAX_VALUE = "MARQO_EF_CONSTRUCTION_MAX_VALUE" MARQO_MAX_VECTORISE_BATCH_SIZE = "MARQO_MAX_VECTORISE_BATCH_SIZE" diff --git a/tests/s2_inference/test_image_downloading.py b/tests/s2_inference/test_image_downloading.py index 89f88200f..6b575bd3a 100644 --- a/tests/s2_inference/test_image_downloading.py +++ b/tests/s2_inference/test_image_downloading.py @@ -1,5 +1,5 @@ from unittest import TestCase -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, ANY import pycurl import pytest @@ -10,6 +10,8 @@ from marqo.s2_inference.clip_utils import encode_url, download_image_from_url from marqo.s2_inference.errors import ImageDownloadError from tests.marqo_test import MockHttpServer +from marqo.tensor_search.enums import EnvVars +from io import BytesIO @pytest.mark.unittest @@ -90,3 +92,108 @@ def test_download_image_from_url_handlesRedirection(self): with MockHttpServer(app).run_in_thread() as base_url: result = download_image_from_url(f'{base_url}/missing_image.jpg', image_download_headers={}) self.assertEqual(result.getvalue(), image_content) + + @patch('marqo.s2_inference.clip_utils.pycurl.Curl') + @patch('marqo.s2_inference.clip_utils.EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE', 5_000_000) # 5MB limit + def test_video_audio_file_size_check_over_limit(self, mock_curl): + # Setup + test_url = "http://ipv4.download.thinkbroadband.com:8080/5GB.zip" + mock_curl_instance = MagicMock() + mock_curl.return_value = mock_curl_instance + + # Store the progress callback + progress_callback = None + + def simulate_setopt(option, value): + nonlocal progress_callback + if option == pycurl.XFERINFOFUNCTION: + progress_callback = value + elif option == pycurl.WRITEFUNCTION: + # Simulate writing some data + value(b'Some data') + + mock_curl_instance.setopt.side_effect = simulate_setopt + + # Simulate pycurl.error with E_ABORTED_BY_CALLBACK + mock_curl_instance.perform.side_effect = pycurl.error(pycurl.E_ABORTED_BY_CALLBACK, "Callback aborted") + + # Test + with self.assertRaises(ImageDownloadError) as context: + download_image_from_url(test_url, {}, modality="video") + + # Simulate the progress callback after download_image_from_url has set it + if progress_callback: + # Simulate downloading more than the limit + progress_callback(0, 6_000_000, 0, 0) + + # Assert + self.assertIn("exceeds the maximum allowed size", str(context.exception)) + mock_curl_instance.setopt.assert_any_call(pycurl.NOPROGRESS, False) + mock_curl_instance.setopt.assert_any_call(pycurl.XFERINFOFUNCTION, ANY) + + @patch('marqo.s2_inference.clip_utils.pycurl.Curl') + @patch('marqo.s2_inference.clip_utils.EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE', 5_000_000) # 5MB limit + def test_video_audio_file_size_check_under_limit(self, mock_curl): + # Setup + test_url = "http://example.com/small_video.mp4" + mock_curl_instance = MagicMock() + mock_curl.return_value = mock_curl_instance + + # Store the progress callback + progress_callback = None + + def simulate_setopt(option, value): + nonlocal progress_callback + if option == pycurl.XFERINFOFUNCTION: + progress_callback = value + + mock_curl_instance.setopt.side_effect = simulate_setopt + mock_curl_instance.getinfo.return_value = 200 # Simulate successful HTTP response + + # Test + try: + result = download_image_from_url(test_url, {}, modality="audio") + + # Simulate the progress callback after download_image_from_url has set it + if progress_callback: + progress_callback(0, 3_000_000, 0, 0) + + self.assertIsInstance(result, BytesIO) + except ImageDownloadError: + self.fail("ImageDownloadError raised unexpectedly for file under size limit") + + # Assert + mock_curl_instance.setopt.assert_any_call(pycurl.NOPROGRESS, False) + mock_curl_instance.setopt.assert_any_call(pycurl.XFERINFOFUNCTION, ANY) + mock_curl_instance.perform.assert_called_once() + + @patch('marqo.s2_inference.clip_utils.pycurl.Curl') + @patch('marqo.s2_inference.clip_utils.EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE', 5_000_000) # 5MB limit + def test_image_file_size_not_checked(self, mock_curl): + # Setup + test_url = "http://example.com/large_image.jpg" + mock_curl_instance = MagicMock() + mock_curl.return_value = mock_curl_instance + + # Simulate successful download + mock_curl_instance.getinfo.return_value = 200 + + # Test + try: + result = download_image_from_url(test_url, {}, modality="image") + self.assertIsInstance(result, BytesIO) + except ImageDownloadError: + self.fail("ImageDownloadError raised unexpectedly for image modality") + + # Assert + mock_curl_instance.setopt.assert_any_call(pycurl.URL, test_url) + mock_curl_instance.setopt.assert_any_call(pycurl.WRITEDATA, ANY) + + # Check that NOPROGRESS and XFERINFOFUNCTION are not set for image modality + with self.assertRaises(AssertionError): + mock_curl_instance.setopt.assert_any_call(pycurl.NOPROGRESS, False) + with self.assertRaises(AssertionError): + mock_curl_instance.setopt.assert_any_call(pycurl.XFERINFOFUNCTION, ANY) + + # Verify that perform was called + mock_curl_instance.perform.assert_called_once() \ No newline at end of file