Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Video/Audio File Size limits in search and embed #1012

Open
wants to merge 5 commits into
base: mainline
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/marqo/api/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def default_env_vars() -> dict:
EnvVars.MARQO_LOG_LEVEL: "info",
EnvVars.MARQO_MEDIA_DOWNLOAD_THREAD_COUNT_PER_REQUEST: 5,
EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST: 20,
EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE: 370 * 1024 * 1024, # 370 megabytes in bytes
# This env variable is set to "info" by default in run_marqo.sh, which overrides this value
EnvVars.MARQO_MAX_CPU_MODEL_MEMORY: 4,
EnvVars.MARQO_MAX_CUDA_MODEL_MEMORY: 4, # For multi-GPU, this is the max memory for each GPU.
Expand Down
26 changes: 21 additions & 5 deletions src/marqo/s2_inference/clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from marqo import marqo_docs
from marqo.api.exceptions import InternalError
from marqo.tensor_search.enums import EnvVars
from marqo.core.inference.models.abstract_clip_model import AbstractCLIPModel
from marqo.core.inference.models.open_clip_model_properties import OpenCLIPModelProperties, ImagePreprocessor
from marqo.s2_inference.configs import ModelCache
Expand Down Expand Up @@ -150,7 +151,7 @@ def validate_url(url: str) -> bool:



def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000, modality: Optional[str] = None) -> BytesIO:
"""Download an image from a URL and return a PIL image using pycurl.

Args:
Expand All @@ -171,7 +172,7 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
try:
encoded_url = encode_url(image_path)
except UnicodeEncodeError as e:
raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. "
raise ImageDownloadError(f"Marqo encountered an error when downloading the media url {image_path}. "
f"The url could not be encoded properly. Original error: {e}")
buffer = BytesIO()
c = pycurl.Curl()
Expand All @@ -185,15 +186,30 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
headers.update(image_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])

# callback to check file size for video and audio
if modality in ["video", "audio"]:
max_size = EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE
def progress(download_total, downloaded, upload_total, uploaded):
if downloaded > max_size:
return 1 # Abort the download
c.setopt(pycurl.NOPROGRESS, False)
c.setopt(pycurl.XFERINFOFUNCTION, progress)

try:
c.perform()
if c.getinfo(pycurl.RESPONSE_CODE) != 200:
raise ImageDownloadError(f"image url `{image_path}` returned {c.getinfo(pycurl.RESPONSE_CODE)}")
raise ImageDownloadError(f"media url `{image_path}` returned {c.getinfo(pycurl.RESPONSE_CODE)}")
except pycurl.error as e:
raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. "
f"The original error is: {e}")
error_message = str(e)
if len(e.args) > 0:
error_code = e.args[0]
if error_code == pycurl.E_ABORTED_BY_CALLBACK:
error_message = f"Media file `{image_path}` exceeds the maximum allowed size for {modality}."
raise ImageDownloadError(f"Marqo encountered an error when downloading the media url {image_path}. "
f"The original error is: {error_message}")
finally:
c.close()

buffer.seek(0)
return buffer

Expand Down
8 changes: 4 additions & 4 deletions src/marqo/s2_inference/multimodal_model_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
if isinstance(content, str) and "http" in content:
suffix = ".mp4" if modality == Modality.VIDEO else ".wav"
with self._temp_file(suffix) as temp_filename:
self._download_content(content, temp_filename)
self._download_content(content, temp_filename, modality)
preprocessed_content = self.preprocessor(modality)([temp_filename], return_tensors='pt')
inputs[modality.value] = to_device(preprocessed_content, self.model.device)['pixel_values']

Expand All @@ -300,11 +300,11 @@ def encode(self, content, modality, normalize=True, **kwargs):

return embeddings.cpu().numpy()

def _download_content(self, url, filename):
def _download_content(self, url, filename, modality):
# 3 seconds for images, 20 seconds for audio and video
timeout_ms = 3000 if filename.endswith(('.png', '.jpg', '.jpeg')) else 20000

buffer = download_image_from_url(url, {}, timeout_ms)
buffer = download_image_from_url(url, {}, timeout_ms, modality)

with open(filename, 'wb') as f:
f.write(buffer.getvalue())
f.write(buffer.getvalue())
1 change: 1 addition & 0 deletions src/marqo/tensor_search/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class EnvVars:
MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST = "MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST"
MARQO_ROOT_PATH = "MARQO_ROOT_PATH"
MARQO_MAX_CPU_MODEL_MEMORY = "MARQO_MAX_CPU_MODEL_MEMORY"
MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE = "MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE"
MARQO_MAX_CUDA_MODEL_MEMORY = "MARQO_MAX_CUDA_MODEL_MEMORY"
MARQO_EF_CONSTRUCTION_MAX_VALUE = "MARQO_EF_CONSTRUCTION_MAX_VALUE"
MARQO_MAX_VECTORISE_BATCH_SIZE = "MARQO_MAX_VECTORISE_BATCH_SIZE"
Expand Down
109 changes: 108 additions & 1 deletion tests/s2_inference/test_image_downloading.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import TestCase
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, ANY

import pycurl
import pytest
Expand All @@ -10,6 +10,8 @@
from marqo.s2_inference.clip_utils import encode_url, download_image_from_url
from marqo.s2_inference.errors import ImageDownloadError
from tests.marqo_test import MockHttpServer
from marqo.tensor_search.enums import EnvVars
from io import BytesIO


@pytest.mark.unittest
Expand Down Expand Up @@ -90,3 +92,108 @@ def test_download_image_from_url_handlesRedirection(self):
with MockHttpServer(app).run_in_thread() as base_url:
result = download_image_from_url(f'{base_url}/missing_image.jpg', image_download_headers={})
self.assertEqual(result.getvalue(), image_content)

@patch('marqo.s2_inference.clip_utils.pycurl.Curl')
@patch('marqo.s2_inference.clip_utils.EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE', 5_000_000) # 5MB limit
def test_video_audio_file_size_check_over_limit(self, mock_curl):
# Setup
test_url = "http://ipv4.download.thinkbroadband.com:8080/5GB.zip"
mock_curl_instance = MagicMock()
mock_curl.return_value = mock_curl_instance

# Store the progress callback
progress_callback = None

def simulate_setopt(option, value):
nonlocal progress_callback
if option == pycurl.XFERINFOFUNCTION:
progress_callback = value
elif option == pycurl.WRITEFUNCTION:
# Simulate writing some data
value(b'Some data')

mock_curl_instance.setopt.side_effect = simulate_setopt

# Simulate pycurl.error with E_ABORTED_BY_CALLBACK
mock_curl_instance.perform.side_effect = pycurl.error(pycurl.E_ABORTED_BY_CALLBACK, "Callback aborted")

# Test
with self.assertRaises(ImageDownloadError) as context:
download_image_from_url(test_url, {}, modality="video")

# Simulate the progress callback after download_image_from_url has set it
if progress_callback:
# Simulate downloading more than the limit
progress_callback(0, 6_000_000, 0, 0)

# Assert
self.assertIn("exceeds the maximum allowed size", str(context.exception))
mock_curl_instance.setopt.assert_any_call(pycurl.NOPROGRESS, False)
mock_curl_instance.setopt.assert_any_call(pycurl.XFERINFOFUNCTION, ANY)

@patch('marqo.s2_inference.clip_utils.pycurl.Curl')
@patch('marqo.s2_inference.clip_utils.EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE', 5_000_000) # 5MB limit
def test_video_audio_file_size_check_under_limit(self, mock_curl):
# Setup
test_url = "http://example.com/small_video.mp4"
mock_curl_instance = MagicMock()
mock_curl.return_value = mock_curl_instance

# Store the progress callback
progress_callback = None

def simulate_setopt(option, value):
nonlocal progress_callback
if option == pycurl.XFERINFOFUNCTION:
progress_callback = value

mock_curl_instance.setopt.side_effect = simulate_setopt
mock_curl_instance.getinfo.return_value = 200 # Simulate successful HTTP response

# Test
try:
result = download_image_from_url(test_url, {}, modality="audio")

# Simulate the progress callback after download_image_from_url has set it
if progress_callback:
progress_callback(0, 3_000_000, 0, 0)

self.assertIsInstance(result, BytesIO)
except ImageDownloadError:
self.fail("ImageDownloadError raised unexpectedly for file under size limit")

# Assert
mock_curl_instance.setopt.assert_any_call(pycurl.NOPROGRESS, False)
mock_curl_instance.setopt.assert_any_call(pycurl.XFERINFOFUNCTION, ANY)
mock_curl_instance.perform.assert_called_once()

@patch('marqo.s2_inference.clip_utils.pycurl.Curl')
@patch('marqo.s2_inference.clip_utils.EnvVars.MARQO_MAX_VIDEO_AUDIO_SEARCH_FILE_SIZE', 5_000_000) # 5MB limit
def test_image_file_size_not_checked(self, mock_curl):
# Setup
test_url = "http://example.com/large_image.jpg"
mock_curl_instance = MagicMock()
mock_curl.return_value = mock_curl_instance

# Simulate successful download
mock_curl_instance.getinfo.return_value = 200

# Test
try:
result = download_image_from_url(test_url, {}, modality="image")
self.assertIsInstance(result, BytesIO)
except ImageDownloadError:
self.fail("ImageDownloadError raised unexpectedly for image modality")

# Assert
mock_curl_instance.setopt.assert_any_call(pycurl.URL, test_url)
mock_curl_instance.setopt.assert_any_call(pycurl.WRITEDATA, ANY)

# Check that NOPROGRESS and XFERINFOFUNCTION are not set for image modality
with self.assertRaises(AssertionError):
mock_curl_instance.setopt.assert_any_call(pycurl.NOPROGRESS, False)
with self.assertRaises(AssertionError):
mock_curl_instance.setopt.assert_any_call(pycurl.XFERINFOFUNCTION, ANY)

# Verify that perform was called
mock_curl_instance.perform.assert_called_once()
Loading