Skip to content

Commit

Permalink
Merge pull request #11 from ShoaibMajidDar/main
Browse files Browse the repository at this point in the history
added encoder to whisper function in LLMWhisperClient
  • Loading branch information
hari-kuriakose authored Oct 30, 2024
2 parents a7a58d4 + be02721 commit 2929529
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/unstract/llmwhisperer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def whisper(
ocr_provider: str = "advanced",
line_splitter_tolerance: float = 0.4,
horizontal_stretch_factor: float = 1.0,
encoding: str = "utf-8"
) -> dict:
"""
Sends a request to the LLMWhisperer API to process a document.
Expand All @@ -190,6 +191,7 @@ def whisper(
ocr_provider (str, optional): The OCR provider. Can be "advanced" or "basic". Defaults to "advanced".
line_splitter_tolerance (float, optional): The line splitter tolerance. Defaults to 0.4.
horizontal_stretch_factor (float, optional): The horizontal stretch factor. Defaults to 1.0.
encoding (str): The character encoding to use for processing the text. Defaults to "utf-8".
Returns:
dict: The response from the API as a dictionary.
Expand Down Expand Up @@ -268,6 +270,7 @@ def generate():
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=self.api_timeout, stream=should_stream)
response.encoding = encoding
if response.status_code != 200 and response.status_code != 202:
message = json.loads(response.text)
message["status_code"] = response.status_code
Expand Down Expand Up @@ -318,7 +321,7 @@ def whisper_status(self, whisper_hash: str) -> dict:
message["status_code"] = response.status_code
return message

def whisper_retrieve(self, whisper_hash: str) -> dict:
def whisper_retrieve(self, whisper_hash: str, encoding: str = "utf-8") -> dict:
"""Retrieves the result of the whisper operation from the LLMWhisperer
API.
Expand All @@ -329,6 +332,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:
Args:
whisper_hash (str): The hash of the whisper operation.
encoding (str): The character encoding to use for processing the text. Defaults to "utf-8".
Returns:
dict: A dictionary containing the status code and the extracted text from the whisper operation.
Expand All @@ -345,6 +349,7 @@ def whisper_retrieve(self, whisper_hash: str) -> dict:
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=self.api_timeout)
response.encoding = encoding
if response.status_code != 200:
err = json.loads(response.text)
err["status_code"] = response.status_code
Expand Down

0 comments on commit 2929529

Please sign in to comment.