Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Changes for line-item prompt type #880

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"boolean":"boolean",
"json":"json",
"table":"table",
"record":"record"
"record":"record",
"line_item":"line-item"
},
"output_processing":{
"DEFAULT":"Default"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 4.2.1 on 2024-12-10 04:13

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("prompt_studio_v2", "0002_alter_toolstudioprompt_enforce_type"),
]

operations = [
migrations.AlterField(
model_name="toolstudioprompt",
name="enforce_type",
field=models.TextField(
blank=True,
choices=[
("Text", "Response sent as Text"),
("number", "Response sent as number"),
("email", "Response sent as email"),
("date", "Response sent as date"),
("boolean", "Response sent as boolean"),
("json", "Response sent as json"),
("table", "Response sent as table"),
(
"record",
"Response sent for records. Entries of records are list of logical and organized individual entities with distint values",
),
("line-item", "Response sent as line-item"),
],
db_comment="Field to store the type in which the response to be returned.",
default="Text",
),
),
]
1 change: 1 addition & 0 deletions backend/prompt_studio/prompt_studio_v2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class EnforceType(models.TextChoices):
"logical and organized individual "
"entities with distint values"
)
LINE_ITEM = "line-item", ("Response sent as line-item")

class PromptType(models.TextChoices):
PROMPT = "PROMPT", "Response sent as Text"
Expand Down
1 change: 1 addition & 0 deletions prompt-service/src/unstract/prompt_service/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class PromptServiceContants:
ENABLE_HIGHLIGHT = "enable_highlight"
FILE_PATH = "file_path"
HIGHLIGHT_DATA = "highlight_data"
LINE_ITEM = "line-item"


class RunLevel(Enum):
Expand Down
64 changes: 62 additions & 2 deletions prompt-service/src/unstract/prompt_service/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from unstract.sdk.exceptions import SdkError
from unstract.sdk.llm import LLM

PAID_FEATURE_MSG = (
"It is a cloud / enterprise feature. If you have purchased a plan and still "
"face this issue, please contact support"
)

load_dotenv()

# Global variable to store plugins
Expand Down Expand Up @@ -295,8 +300,8 @@ def run_completion(
extract_json=prompt_type.lower() != PSKeys.TEXT,
)
answer: str = completion[PSKeys.RESPONSE].text
highlight_data = completion.get(PSKeys.HIGHLIGHT_DATA)
if all([metadata, highlight_data, prompt_key]):
highlight_data = completion.get(PSKeys.HIGHLIGHT_DATA, [])
if all([metadata, prompt_key]):
metadata.setdefault(PSKeys.HIGHLIGHT_DATA, {})[prompt_key] = highlight_data
return answer
# TODO: Catch and handle specific exception here
Expand Down Expand Up @@ -333,3 +338,58 @@ def extract_table(
except table_extractor["exception_cls"] as e:
msg = f"Couldn't extract table. {e}"
raise APIError(message=msg)


def extract_line_item(
tool_settings: dict[str, Any],
output: dict[str, Any],
plugins: dict[str, dict[str, Any]],
structured_output: dict[str, Any],
llm: LLM,
file_path: str,
) -> dict[str, Any]:
# Adjust file path to read from the extract folder
base_name = os.path.splitext(os.path.basename(file_path))[
0
] # Get the base name without extension
extract_file_path = os.path.join(
os.path.dirname(file_path), "extract", f"{base_name}.txt"
)

# Read file content into context
if not os.path.exists(extract_file_path):
raise FileNotFoundError(
f"The file at path '{extract_file_path}' does not exist."
)

with open(extract_file_path, encoding="utf-8") as file:
context = file.read()

prompt = construct_prompt(
preamble=tool_settings.get(PSKeys.PREAMBLE, ""),
prompt=output["promptx"],
postamble=tool_settings.get(PSKeys.POSTAMBLE, ""),
grammar_list=tool_settings.get(PSKeys.GRAMMAR, []),
context=context,
platform_postamble="",
)
line_item_extraction_plugin: dict[str, Any] = plugins.get(
"line-item-extraction", {}
)
if not line_item_extraction_plugin:
raise APIError(PAID_FEATURE_MSG)
try:
line_item_extraction = line_item_extraction_plugin["entrypoint_cls"](
llm=llm,
tool_settings=tool_settings,
output=output,
prompt=prompt,
structured_output=structured_output,
logger=current_app.logger,
)
answer = line_item_extraction.run()
structured_output[output[PSKeys.NAME]] = answer
return structured_output
except line_item_extraction_plugin["exception_cls"] as e:
msg = f"Couldn't extract table. {e}"
raise APIError(message=msg)
35 changes: 35 additions & 0 deletions prompt-service/src/unstract/prompt_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from unstract.prompt_service.exceptions import APIError, ErrorResponse, NoPayloadError
from unstract.prompt_service.helper import (
construct_and_run_prompt,
extract_line_item,
extract_table,
extract_variable,
get_cleaned_context,
Expand Down Expand Up @@ -250,6 +251,40 @@ def prompt_processor() -> Any:
"Error while extracting table for the prompt",
)
raise api_error
elif output[PSKeys.TYPE] == PSKeys.LINE_ITEM:
try:
structured_output = extract_line_item(
tool_settings=tool_settings,
output=output,
plugins=plugins,
structured_output=structured_output,
llm=llm,
file_path=file_path,
)
metadata = query_usage_metadata(token=platform_key, metadata=metadata)
response = {
PSKeys.METADATA: metadata,
PSKeys.OUTPUT: structured_output,
}
return response
except APIError as e:
app.logger.error(
"Failed to extract line-item for the prompt %s: %s",
output[PSKeys.NAME],
str(e),
)
publish_log(
log_events_id,
{
"tool_id": tool_id,
"prompt_key": prompt_name,
"doc_name": doc_name,
},
LogLevel.ERROR,
RunLevel.RUN,
"Error while extracting line-item for the prompt",
)
raise e

try:
context: set[str] = set()
Expand Down