Skip to content

Commit

Permalink
[FEAT] Variable replacement support in prompt studio (#600)
Browse files Browse the repository at this point in the history
* Feat/Variable replacement support in prompt studio

* Rename method name

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: harini-venkataraman <[email protected]>

* Update prompt-service/src/unstract/prompt_service/main.py

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: harini-venkataraman <[email protected]>

* Output variable

* Exception handling for API exceptions

* Handling variable replacement method in else

* enhancements to exit method early

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: harini-venkataraman <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Method renames

---------

Signed-off-by: harini-venkataraman <[email protected]>
Co-authored-by: Chandrasekharan M <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 21, 2024
1 parent b96537f commit 4de6ab9
Show file tree
Hide file tree
Showing 11 changed files with 315 additions and 2 deletions.
1 change: 1 addition & 0 deletions backend/prompt_studio/prompt_studio_core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class ToolStudioPromptKeys:
EXTRACT = "extract"
PLATFORM_POSTAMBLE = "platform_postamble"
SUMMARIZE_AS_SOURCE = "summarize_as_source"
VARIABLE_MAP = "variable_map"


class FileViewTypes:
Expand Down
9 changes: 9 additions & 0 deletions backend/prompt_studio/prompt_studio_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,12 @@ class OperationNotSupported(APIException):
"Please check our cloud or enterprise on-premise offering "
"for access to this functionality."
)


class PromptNotRun(APIException):
    """Raised when a prompt used as a variable has not produced an output yet.

    Responds with HTTP status 403 and a message asking the user to execute
    the referenced prompt before using it as a variable in another prompt.
    """

    status_code = 403
    default_detail = (
        "The prompt must be executed before "
        "it can be used as a variable in another prompt. "
        "Please execute the prompt first and try again."
    )
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
)
from prompt_studio.prompt_studio_core.models import CustomTool
from prompt_studio.prompt_studio_core.prompt_ide_base_tool import PromptIdeBaseTool
from prompt_studio.prompt_studio_core.prompt_variable_service import (
PromptStudioVariableService,
)
from prompt_studio.prompt_studio_document_manager.models import DocumentManager
from prompt_studio.prompt_studio_index_manager.prompt_studio_index_helper import ( # noqa: E501
PromptStudioIndexHelper,
Expand Down Expand Up @@ -765,7 +768,11 @@ def _fetch_response(
output = PromptStudioHelper.fetch_table_settings_if_enabled(
doc_name, prompt, org_id, user_id, tool_id, output
)

variable_map = PromptStudioVariableService.frame_variable_replacement_map(
doc_id=document_id, prompt_object=prompt
)
if variable_map:
output[TSPKeys.VARIABLE_MAP] = variable_map
outputs.append(output)

tool_settings = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import re
from enum import Enum
from typing import Any

from prompt_studio.prompt_studio.models import ToolStudioPrompt
from prompt_studio.prompt_studio_core.exceptions import PromptNotRun
from prompt_studio.prompt_studio_output_manager.models import PromptStudioOutputManager


class VariableType(str, Enum):
    """Kinds of ``{{...}}`` variables that can appear in a prompt.

    In this module a variable is classified as DYNAMIC when its text
    contains a URL and as STATIC otherwise.
    """

    STATIC = "STATIC"
    DYNAMIC = "DYNAMIC"


class VariableConstants:
    """Regular expressions used to locate and classify prompt variables."""

    # Lazily captures the text between "{{" and the nearest following "}}".
    VARIABLE_REGEX = "{{(.+?)}}"
    # Lazily captures the text enclosed in a "[...]" pair.
    DYNAMIC_VARIABLE_DATA_REGEX = r"\[(.*?)\]"
    # Case-insensitive URL matcher (http(s)://, www. and bare-domain forms).
    DYNAMIC_VARIABLE_URL_REGEX = r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’]))" # noqa: E501


class PromptStudioVariableService:
    """Resolves ``{{...}}`` variables in a prompt to stored prompt outputs."""

    @staticmethod
    def fetch_variable_outputs(variable: str, doc_id: str, tool_id: str) -> Any:
        """Return the stored output of the prompt that ``variable`` names.

        Args:
            variable: Prompt key of the referenced prompt.
            doc_id: Document whose output record is looked up.
            tool_id: Tool the referenced prompt belongs to.

        Returns:
            The ``output`` attribute of the matching output record.

        Raises:
            PromptNotRun: If no output record exists for the referenced
                prompt, document and profile.
        """
        variable_prompt: ToolStudioPrompt = ToolStudioPrompt.objects.get(
            prompt_key=variable, tool_id=tool_id
        )
        # ``objects.get`` signals a missing row by raising ``DoesNotExist``;
        # it never returns a falsy value, so the "not run yet" case has to be
        # detected through the exception.
        try:
            output = PromptStudioOutputManager.objects.get(
                prompt_id=variable_prompt.prompt_id,
                document_manager=doc_id,
                tool_id=variable_prompt.tool_id,
                profile_manager=variable_prompt.profile_manager,
            )
        except PromptStudioOutputManager.DoesNotExist as exc:
            raise PromptNotRun() from exc
        return output.output

    @staticmethod
    def identify_variable_type(variable: str) -> VariableType:
        """Classify ``variable`` as DYNAMIC if it contains a URL, else STATIC."""
        if re.search(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable):
            return VariableType.DYNAMIC
        return VariableType.STATIC

    @staticmethod
    def extract_variables_from_prompt(prompt: str) -> list[str]:
        """Return the text of every ``{{...}}`` marker in ``prompt``, in order."""
        return re.findall(VariableConstants.VARIABLE_REGEX, prompt)

    @staticmethod
    def frame_variable_replacement_map(
        doc_id: str, prompt_object: ToolStudioPrompt
    ) -> dict[str, Any]:
        """Map each prompt key referenced by ``prompt_object`` to its output.

        Static variables are keyed by the variable text itself. For dynamic
        variables the key is the text inside the first ``[...]`` segment.

        Args:
            doc_id: Document whose outputs are looked up.
            prompt_object: Prompt whose text is scanned for variables.

        Returns:
            A dict of prompt key -> stored output; empty when the prompt
            contains no variables.

        Raises:
            PromptNotRun: If a referenced prompt has no stored output.
            IndexError: If a dynamic variable has no ``[...]`` segment.
        """
        variable_output_map: dict[str, Any] = {}
        variables = PromptStudioVariableService.extract_variables_from_prompt(
            prompt=prompt_object.prompt
        )
        for variable in variables:
            variable_type: VariableType = (
                PromptStudioVariableService.identify_variable_type(variable=variable)
            )
            prompt_key = variable
            if variable_type == VariableType.DYNAMIC:
                # Only the bracketed part of a dynamic variable names a prompt.
                prompt_key = re.findall(
                    VariableConstants.DYNAMIC_VARIABLE_DATA_REGEX, variable
                )[0]
            if prompt_key in variable_output_map:
                # Already resolved for an earlier occurrence; skip the queries.
                continue
            variable_output_map[prompt_key] = (
                PromptStudioVariableService.fetch_variable_outputs(
                    variable=prompt_key,
                    doc_id=doc_id,
                    tool_id=prompt_object.tool_id.tool_id,
                )
            )
        return variable_output_map
3 changes: 2 additions & 1 deletion backend/utils/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import requests as pyrequests
from requests.exceptions import RequestException
from unstract.prompt_service.exceptions import APIError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,7 +44,7 @@ def make_http_request(
return return_val
except RequestException as e:
logger.error(f"HTTP request error: {e}")
raise e
raise APIError(f"Error occured while invoking POST API Variable : {str(e)}")
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
raise e
1 change: 1 addition & 0 deletions prompt-service/src/unstract/prompt_service/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class PromptServiceContants:
EXTRACT_EPILOGUE = "extract-epilogue"
CLEAN_CONTEXT = "clean-context"
SUMMARIZE_AS_SOURCE = "summarize_as_source"
VARIABLE_MAP = "variable_map"


class LogLevel(Enum):
Expand Down
27 changes: 27 additions & 0 deletions prompt-service/src/unstract/prompt_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
run_completion,
)
from unstract.prompt_service.prompt_ide_base_tool import PromptServiceBaseTool
from unstract.prompt_service.variable_extractor.base import VariableExtractor
from unstract.sdk.constants import LogLevel
from unstract.sdk.embedding import Embedding
from unstract.sdk.exceptions import SdkError
Expand Down Expand Up @@ -113,6 +114,32 @@ def prompt_processor() -> Any:
util = PromptServiceBaseTool(log_level=LogLevel.INFO, platform_key=platform_key)
index = Index(tool=util)

app.logger.info(f"[{tool_id}] Replacing variables in prompt : {prompt_name}")
_publish_log(
log_events_id,
{
"tool_id": tool_id,
"prompt_key": prompt_name,
"doc_name": doc_name,
},
LogLevel.DEBUG,
RunLevel.RUN,
"Replacing variables in prompt",
)
try:
variable_map = output[PSKeys.VARIABLE_MAP]
VariableExtractor.execute_variable_replacement(
prompt=promptx, variable_map=variable_map
)
except KeyError:
# Executed in case of structured tool and
# APIs where we do not set the variable map
VariableExtractor.execute_variable_replacement(
prompt=promptx, variable_map=structured_output
)
except APIError as api_error:
raise api_error

app.logger.info(f"[{tool_id}] Executing prompt: {prompt_name}")
_publish_log(
log_events_id,
Expand Down
49 changes: 49 additions & 0 deletions prompt-service/src/unstract/prompt_service/utils/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
from enum import Enum
from typing import Any, Optional

import requests as pyrequests
from requests.exceptions import RequestException

logger = logging.getLogger(__name__)


class HTTPMethod(str, Enum):
    """HTTP verbs that can be named when calling ``make_http_request``.

    Note that ``make_http_request`` in this module only dispatches GET, POST
    and DELETE; it raises ``ValueError`` for PUT and PATCH.
    """

    GET = "GET"
    POST = "POST"
    DELETE = "DELETE"
    PUT = "PUT"
    PATCH = "PATCH"


def make_http_request(
    verb: HTTPMethod,
    url: str,
    data: Optional[dict[str, Any]] = None,
    headers: Optional[dict[str, Any]] = None,
    params: Optional[dict[str, Any]] = None,
) -> str:
    """Send a GET, POST or DELETE request and return the response body.

    Args:
        verb: HTTP method to use; only GET, POST and DELETE are dispatched.
        url: Target URL.
        data: Payload sent as the JSON body of a POST request.
        headers: Optional request headers.
        params: Optional query-string parameters.

    Returns:
        The parsed JSON body when the response ``content-type`` header is
        exactly ``application/json``; otherwise the response text.

    Raises:
        RequestException: On transport errors or a non-success status code
            (logged, then re-raised).
        ValueError: For any other verb (logged, then re-raised).
    """
    shared_kwargs = {"params": params, "headers": headers}
    try:
        if verb == HTTPMethod.POST:
            reply = pyrequests.post(url, json=data, **shared_kwargs)
        elif verb == HTTPMethod.GET:
            reply = pyrequests.get(url, **shared_kwargs)
        elif verb == HTTPMethod.DELETE:
            reply = pyrequests.delete(url, **shared_kwargs)
        else:
            raise ValueError("Invalid HTTP verb. Supported verbs: GET, POST, DELETE")

        # Turn 4xx/5xx responses into exceptions before decoding the body.
        reply.raise_for_status()

        body: str
        if reply.headers.get("content-type") == "application/json":
            body = reply.json()
        else:
            body = reply.text
        return body
    except RequestException as http_error:
        logger.error(f"HTTP request error: {http_error}")
        raise http_error
    except Exception as unexpected_error:
        logger.error(f"An unexpected error occurred: {unexpected_error}")
        raise unexpected_error
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any

from .constants import VariableType
from .prompt_variable_service import VariableService


class VariableExtractor:
    """Entry point for substituting ``{{...}}`` variables in a prompt."""

    @staticmethod
    def execute_variable_replacement(prompt: str, variable_map: dict[str, Any]) -> str:
        """Return ``prompt`` with its ``{{...}}`` variables substituted.

        Static variables are looked up in ``variable_map``; dynamic variables
        are resolved through ``VariableService.replace_dynamic_variable``.
        The input string is not modified, so callers must use the return
        value.

        Args:
            prompt: Prompt text that may contain ``{{...}}`` markers.
            variable_map: Mapping of variable name -> value for static
                variables.

        Returns:
            The prompt text after all replacements have been applied.
        """
        variables: list[str] = VariableService.extract_variables_from_prompt(
            prompt=prompt
        )
        # One replacement pass already substitutes every occurrence of a
        # marker, so each distinct variable is processed once. This avoids
        # repeating the HTTP call for a dynamic variable that appears more
        # than once. ``dict.fromkeys`` keeps first-seen order.
        for variable in dict.fromkeys(variables):
            variable_type = VariableService.identify_variable_type(variable=variable)
            if variable_type == VariableType.STATIC:
                prompt = VariableService.replace_static_variable(
                    prompt=prompt, structured_output=variable_map, variable=variable
                )
            elif variable_type == VariableType.DYNAMIC:
                prompt = VariableService.replace_dynamic_variable(
                    prompt=prompt, variable=variable
                )
        return prompt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from enum import Enum


class VariableType(str, Enum):
    """Kinds of ``{{...}}`` variables that can appear in a prompt."""

    # Replaced from a name -> value mapping.
    STATIC = "STATIC"
    # Contains a URL; replaced with the response of an HTTP call.
    DYNAMIC = "DYNAMIC"


class VariableConstants:
    """Regular expressions used to locate and classify prompt variables."""

    # Lazily captures the text between "{{" and the nearest following "}}".
    VARIABLE_REGEX = "{{(.+?)}}"
    # Lazily captures the text enclosed in a "[...]" pair.
    DYNAMIC_VARIABLE_DATA_REGEX = r"\[(.*?)\]"
    # Case-insensitive URL matcher (http(s)://, www. and bare-domain forms).
    DYNAMIC_VARIABLE_URL_REGEX = r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’]))" # noqa: E501
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging
import re
from typing import Any

from utils.request import HTTPMethod, make_http_request

from .constants import VariableConstants, VariableType

logger = logging.getLogger(__name__)


class VariableService:
    """Helpers that resolve ``{{...}}`` variables inside a prompt string."""

    @staticmethod
    def replace_static_variable(
        prompt: str, structured_output: dict[str, Any], variable: str
    ) -> str:
        """Replace ``{{variable}}`` in ``prompt`` with its value.

        Args:
            prompt: Prompt text containing the marker.
            structured_output: Mapping of variable name -> value.
            variable: Variable name, without the surrounding braces.

        Returns:
            The prompt with every ``{{variable}}`` occurrence substituted, or
            the prompt unchanged when no value is available.
        """
        output_value = VariableService.check_static_variable_run_status(
            structure_output=structured_output, variable=variable
        )
        # Only a missing value leaves the marker in place. Falsy but valid
        # outputs such as 0 or False must still be substituted.
        if output_value is None:
            return prompt
        static_variable_marker_string = "".join(["{{", variable, "}}"])
        return VariableService.replace_generic_string_value(
            prompt=prompt, variable=static_variable_marker_string, value=output_value
        )

    @staticmethod
    def check_static_variable_run_status(
        structure_output: dict[str, Any], variable: str
    ) -> Any:
        """Return ``structure_output[variable]``, or ``None`` if it is absent.

        A warning is logged when the key is missing.
        """
        try:
            return structure_output[variable]
        except KeyError:
            logger.warning(
                f"Prompt with {variable} is not executed yet."
                " Unable to replace the variable"
            )
            return None

    @staticmethod
    def replace_generic_string_value(prompt: str, variable: str, value: Any) -> str:
        """Replace every literal occurrence of ``variable`` with ``value``.

        The marker is matched as plain text, not as a regular expression, so
        characters such as ``.``, ``?`` or ``[`` in it cannot break the match.
        The value is likewise inserted verbatim rather than being read as a
        regex replacement template.

        Args:
            prompt: Text to operate on.
            variable: Exact marker text to look for, e.g. ``"{{name}}"``.
            value: Replacement. Non-string values are serialized as JSON,
                falling back to ``str()`` if they are not JSON-serializable.

        Returns:
            The prompt with all occurrences of ``variable`` replaced.
        """
        import json

        if not isinstance(value, str):
            try:
                value = json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError):
                value = str(value)
        return prompt.replace(variable, value)

    @staticmethod
    def identify_variable_type(variable: str) -> VariableType:
        """Classify ``variable`` as DYNAMIC if it contains a URL, else STATIC."""
        if re.search(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable):
            return VariableType.DYNAMIC
        return VariableType.STATIC

    @staticmethod
    def replace_dynamic_variable(prompt: str, variable: str) -> str:
        """Replace ``{{variable}}`` with the response of its HTTP endpoint.

        The URL is taken from ``variable`` and the text inside its first
        ``[...]`` segment is sent as the request data.

        Args:
            prompt: Prompt text containing the marker.
            variable: Variable text, without the surrounding braces.

        Returns:
            The prompt with every ``{{variable}}`` occurrence replaced by the
            endpoint's response.

        Raises:
            IndexError: If ``variable`` has no ``[...]`` segment.
            AttributeError: If no URL can be found in ``variable``.
        """
        data = re.findall(VariableConstants.DYNAMIC_VARIABLE_DATA_REGEX, variable)[0]
        # The URL pattern accepts "[" inside a URL, so for "url[data]" it
        # would swallow part of the data segment. Search with that segment
        # removed first and fall back to the raw text if that finds nothing.
        variable_without_data = re.sub(
            VariableConstants.DYNAMIC_VARIABLE_DATA_REGEX, "", variable, count=1
        )
        url_match = re.search(
            VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable_without_data
        ) or re.search(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable)
        url = url_match.group(0)
        api_response = VariableService.fetch_dynamic_variable_value(url=url, data=data)
        dynamic_variable_marker_string = "".join(["{{", variable, "}}"])
        return VariableService.replace_generic_string_value(
            prompt=prompt, variable=dynamic_variable_marker_string, value=api_response
        )

    @staticmethod
    def extract_variables_from_prompt(prompt: str) -> list[str]:
        """Return the text of every ``{{...}}`` marker in ``prompt``, in order."""
        return re.findall(VariableConstants.VARIABLE_REGEX, prompt)

    @staticmethod
    def fetch_dynamic_variable_value(url: str, data: str) -> Any:
        """POST ``data`` to ``url`` and return the response.

        This prototype method currently supports only endpoints that do not
        require authentication. Additionally, it only accepts plain text
        inputs for POST requests in this version. Future versions may include
        support for authentication and other input formats.
        """
        verb: HTTPMethod = HTTPMethod.POST
        headers = {"Content-Type": "text/plain"}
        response: Any = make_http_request(
            verb=verb, url=url, data=data, headers=headers
        )
        return response

0 comments on commit 4de6ab9

Please sign in to comment.