Skip to content

Commit

Permalink
Add a default User-Agent header when downloading images using pycurl (#…
Browse files Browse the repository at this point in the history
…932)

* Add a default 'User-Agent: Marqobot/1.0' header to all request to download images
* Follow redirections when downloading images
  • Loading branch information
papa99do authored Aug 14, 2024
1 parent f2667bd commit cfac46d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 9 deletions.
19 changes: 12 additions & 7 deletions src/marqo/s2_inference/clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
BICUBIC = InterpolationMode.BICUBIC
DEFAULT_HEADERS = {'User-Agent': 'Marqobot/1.0'}


def get_allowed_image_types():
Expand Down Expand Up @@ -153,16 +154,20 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
f"The url could not be encoded properly. Original error: {e}")
buffer = BytesIO()
c = pycurl.Curl()
c.setopt(c.CAINFO, certifi.where())
c.setopt(c.URL, encoded_url)
c.setopt(c.WRITEDATA, buffer)
c.setopt(c.TIMEOUT_MS, timeout_ms)
c.setopt(c.HTTPHEADER, [f"{k}: {v}" for k, v in image_download_headers.items()])
c.setopt(pycurl.CAINFO, certifi.where())
c.setopt(pycurl.URL, encoded_url)
c.setopt(pycurl.WRITEDATA, buffer)
c.setopt(pycurl.TIMEOUT_MS, timeout_ms)
c.setopt(pycurl.FOLLOWLOCATION, 1)

headers = DEFAULT_HEADERS.copy()
headers.update(image_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])

try:
c.perform()
if c.getinfo(c.RESPONSE_CODE) != 200:
raise ImageDownloadError(f"image url `{image_path}` returned {c.getinfo(c.RESPONSE_CODE)}")
if c.getinfo(pycurl.RESPONSE_CODE) != 200:
raise ImageDownloadError(f"image url `{image_path}` returned {c.getinfo(pycurl.RESPONSE_CODE)}")
except pycurl.error as e:
raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. "
f"The original error is: {e}")
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.11.0"
__version__ = "2.11.1"

def get_version() -> str:
return f"{__version__}"
41 changes: 41 additions & 0 deletions tests/marqo_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import contextlib
import socket
import threading
import time
import unittest
import uuid
from typing import Generator
from unittest.mock import patch, Mock

import uvicorn
import vespa.application as pyvespa
from starlette.applications import Starlette

from marqo import config, version
from marqo.vespa.zookeeper_client import ZookeeperClient
Expand Down Expand Up @@ -307,3 +313,38 @@ def assertRaisesStrict(self, expected_exception):

class AsyncMarqoTestCase(unittest.IsolatedAsyncioTestCase, MarqoTestCase):
pass


class MockHttpServer:
"""
A MockHttpServer that takes a Starlette app as input, start the uvicorn server
in a thread, and yield the server url (with random port binding). After the test,
it automatically shuts down the server.
This can be used in individual tests, or as a test fixture in class or module scope.
Example usage:
app = Starlette(routes=[
Route('/path1', lambda _: Response({"a":"b"}, status_code=200)),
Route('/image.jpg', lambda _: Response(b'\x00\x00\x00\xff', media_type='image/png')),
])
with MockHttpServer(app).run_in_thread() as base_url:
run_some_tests
"""
def __init__(self, app: Starlette):
self.server = uvicorn.Server(config=uvicorn.Config(app=app))

@contextlib.contextmanager
def run_in_thread(self) -> Generator[str, None, None]:
(sock := socket.socket()).bind(("127.0.0.1", 0))
thread = threading.Thread(target=self.server.run, kwargs={"sockets": [sock]})
thread.start()
try:
while not self.server.started:
time.sleep(1)
address, port = sock.getsockname()
yield f'http://{address}:{port}'
finally:
self.server.should_exit = True
thread.join()
44 changes: 43 additions & 1 deletion tests/s2_inference/test_image_downloading.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from unittest import TestCase
from unittest.mock import patch, MagicMock

import pycurl
import pytest
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Route

from marqo.s2_inference.clip_utils import encode_url, download_image_from_url
from marqo.s2_inference.errors import ImageDownloadError
from tests.marqo_test import MockHttpServer


@pytest.mark.unittest
Expand Down Expand Up @@ -47,4 +53,40 @@ def test_download_image_from_url_handleDifferentUrlsCorrectly(self):
for url, expected, msg in self.test_cases:
with self.subTest(url=url, expected=expected, msg=msg):
with self.assertRaises(ImageDownloadError) as cm:
download_image_from_url(image_path=url + ".jpg", image_download_headers={})
download_image_from_url(image_path=url + ".jpg", image_download_headers={})

def test_download_image_from_url_handlesUrlRequiringUserAgentHeader(self):
url_requiring_user_agent_header = "https://docs.marqo.ai/2.0.0/Examples/marqo.jpg"
try:
download_image_from_url(image_path=url_requiring_user_agent_header, image_download_headers={})
except Exception as e:
self.fail(f"Exception was raised when downloading {url_requiring_user_agent_header}: {e}")

@patch('pycurl.Curl')
def test_download_image_from_url_mergesDefaultHeadersWithCustomHeaders(self, mock_curl):
mock_curl_instance = mock_curl.return_value
mock_curl_instance.setopt = MagicMock()
mock_curl_instance.perform = MagicMock()
mock_curl_instance.getinfo = MagicMock(return_value=200)

test_cases = [
({}, ['User-Agent: Marqobot/1.0'], "Empty header"),
({'a': 'b'}, ['User-Agent: Marqobot/1.0', 'a: b'], "Basic headers"),
({'User-Agent': 'Marqobot-Image/1.0'}, ['User-Agent: Marqobot-Image/1.0'], "Headers with override"),
]

for (headers, expected_headers, msg) in test_cases:
with self.subTest(headers=headers, expected_headers=expected_headers, msg=msg):
download_image_from_url('http://example.com/image.jpg', image_download_headers=headers)
mock_curl_instance.setopt.assert_called_with(pycurl.HTTPHEADER, expected_headers)

def test_download_image_from_url_handlesRedirection(self):
image_content = b'\x00\x00\x00\xff'
app = Starlette(routes=[
Route('/missing_image.jpg', lambda _: Response(status_code=301, headers={'Location': '/image.jpg'})),
Route('/image.jpg', lambda _: Response(image_content, media_type='image/png')),
])

with MockHttpServer(app).run_in_thread() as base_url:
result = download_image_from_url(f'{base_url}/missing_image.jpg', image_download_headers={})
self.assertEqual(result.getvalue(), image_content)

0 comments on commit cfac46d

Please sign in to comment.