Commit

Merge pull request #9 from Zipstack/fix/test-case-outputs-updated
fix: Updated test case outputs and fuzzy assertion
jaseemjaskp authored Nov 1, 2024
2 parents 2929529 + e30178e commit 0395464
Showing 8 changed files with 230 additions and 236 deletions.
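
The "fuzzy assertion" referenced in the commit message replaces exact string equality in tests/client_test.py with a difflib.SequenceMatcher similarity check: the extracted text only fails the test when its similarity to the expected output drops below a threshold (0.97 for OCR-based processing, 0.99 for native text in text output mode, per the diff below). A minimal standalone sketch of that idea follows; the helper name and the sample invoice strings are illustrative only and are not part of the commit.

from difflib import SequenceMatcher

# Sketch of the fuzzy-assertion idea used in the updated tests; the helper
# name and the sample strings below are illustrative, not from the commit.
def texts_are_similar(expected: str, extracted: str, threshold: float = 0.97) -> bool:
    """Return True when the two texts reach the required similarity ratio."""
    return SequenceMatcher(None, extracted, expected).ratio() >= threshold

expected = "Invoice #1001\nRestaurant: Indiranagar Cafe\nTotal due: 1,234.00"
extracted = "Invoice #1001\nRestaurant: Indiranagar Cafe\nTotal due: 1.234.00"  # one OCR slip
print(SequenceMatcher(None, extracted, expected).ratio())  # ~0.984
print(texts_are_similar(expected, extracted))  # True at the 0.97 threshold
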
2 changes: 1 addition & 1 deletion src/unstract/llmwhisperer/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.21.0"
__version__ = "0.22.0"

from .client import LLMWhispererClient # noqa: F401

94 changes: 21 additions & 73 deletions tests/client_test.py
@@ -1,12 +1,9 @@
import logging
import os
import unittest
from difflib import SequenceMatcher, unified_diff
from pathlib import Path

import pytest
import requests

from unstract.llmwhisperer import LLMWhispererClient

logger = logging.getLogger(__name__)

@@ -23,9 +20,7 @@ def test_get_usage_info(client):
"subscription_plan",
"today_page_count",
]
assert set(usage_info.keys()) == set(
expected_keys
), f"usage_info {usage_info} does not contain the expected keys"
assert set(usage_info.keys()) == set(expected_keys), f"usage_info {usage_info} does not contain the expected keys"


@pytest.mark.parametrize(
@@ -41,85 +36,38 @@ def test_get_usage_info(client):
)
def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
    file_path = os.path.join(data_dir, input_file)
    response = client.whisper(
    whisper_result = client.whisper(
        processing_mode=processing_mode,
        output_mode=output_mode,
        file_path=file_path,
        timeout=200,
    )
    logger.debug(response)
    logger.debug(whisper_result)

    exp_basename = f"{Path(input_file).stem}.{processing_mode}.{output_mode}.txt"
    exp_file = os.path.join(data_dir, "expected", exp_basename)
    with open(exp_file, encoding="utf-8") as f:
        exp = f.read()

    assert isinstance(response, dict)
    assert response["status_code"] == 200
    assert response["extracted_text"] == exp
    assert_extracted_text(exp_file, whisper_result, processing_mode, output_mode)


# TODO: Review and port to pytest based tests
class TestLLMWhispererClient(unittest.TestCase):
    @unittest.skip("Skipping test_whisper")
    def test_whisper(self):
        client = LLMWhispererClient()
        # response = client.whisper(
        # url="https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
        # )
        response = client.whisper(
            file_path="test_data/restaurant_invoice_photo.pdf",
            timeout=200,
            store_metadata_for_highlighting=True,
        )
        print(response)
        # self.assertIsInstance(response, dict)
def assert_extracted_text(file_path, whisper_result, mode, output_mode):
    with open(file_path, encoding="utf-8") as f:
        exp = f.read()

    # @unittest.skip("Skipping test_whisper")
    def test_whisper_stream(self):
        client = LLMWhispererClient()
        download_url = (
            "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
        )
        # Create a stream of download_url and pass it to whisper
        response_download = requests.get(download_url, stream=True)
        response_download.raise_for_status()
        response = client.whisper(
            stream=response_download.iter_content(chunk_size=1024),
            timeout=200,
            store_metadata_for_highlighting=True,
        )
        print(response)
        # self.assertIsInstance(response, dict)
    assert isinstance(whisper_result, dict)
    assert whisper_result["status_code"] == 200

    @unittest.skip("Skipping test_whisper_status")
    def test_whisper_status(self):
        client = LLMWhispererClient()
        response = client.whisper_status(
            whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
        )
        logger.info(response)
        self.assertIsInstance(response, dict)
    # For OCR based processing
    threshold = 0.97

    @unittest.skip("Skipping test_whisper_retrieve")
    def test_whisper_retrieve(self):
        client = LLMWhispererClient()
        response = client.whisper_retrieve(
            whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
        )
        logger.info(response)
        self.assertIsInstance(response, dict)
    # For text based processing
    if mode == "native_text" and output_mode == "text":
        threshold = 0.99
    extracted_text = whisper_result["extracted_text"]
    similarity = SequenceMatcher(None, extracted_text, exp).ratio()

    @unittest.skip("Skipping test_whisper_highlight_data")
    def test_whisper_highlight_data(self):
        client = LLMWhispererClient()
        response = client.highlight_data(
            whisper_hash="9924d865|5f1d285a7cf18d203de7af1a1abb0a3a",
            search_text="Indiranagar",
    if similarity < threshold:
        diff = "\n".join(
            unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
        )
        logger.info(response)
        self.assertIsInstance(response, dict)


if __name__ == "__main__":
    unittest.main()
        pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
