-
Notifications
You must be signed in to change notification settings - Fork 214
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] Variable replacement support in prompt studio (#600)
* Feat/Variable replacement support in prompt studio * Rename method name Co-authored-by: Chandrasekharan M <[email protected]> Signed-off-by: harini-venkataraman <[email protected]> * Update prompt-service/src/unstract/prompt_service/main.py Co-authored-by: Chandrasekharan M <[email protected]> Signed-off-by: harini-venkataraman <[email protected]> * Output variable * Exception handling for API exceptions * Handling variable replacement method in else * enhancements to exit method early Co-authored-by: Chandrasekharan M <[email protected]> Signed-off-by: harini-venkataraman <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Method renames --------- Signed-off-by: harini-venkataraman <[email protected]> Co-authored-by: Chandrasekharan M <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
b96537f
commit 4de6ab9
Showing
11 changed files
with
315 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
backend/prompt_studio/prompt_studio_core/prompt_variable_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import re | ||
from enum import Enum | ||
from typing import Any | ||
|
||
from prompt_studio.prompt_studio.models import ToolStudioPrompt | ||
from prompt_studio.prompt_studio_core.exceptions import PromptNotRun | ||
from prompt_studio.prompt_studio_output_manager.models import PromptStudioOutputManager | ||
|
||
|
||
class VariableType(str, Enum): | ||
STATIC = "STATIC" | ||
DYNAMIC = "DYNAMIC" | ||
|
||
|
||
class VariableConstants: | ||
|
||
VARIABLE_REGEX = "{{(.+?)}}" | ||
DYNAMIC_VARIABLE_DATA_REGEX = r"\[(.*?)\]" | ||
DYNAMIC_VARIABLE_URL_REGEX = r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’]))" # noqa: E501 | ||
|
||
|
||
class PromptStudioVariableService: | ||
|
||
@staticmethod | ||
def fetch_variable_outputs(variable: str, doc_id: str, tool_id: str) -> Any: | ||
variable_prompt: ToolStudioPrompt = ToolStudioPrompt.objects.get( | ||
prompt_key=variable, tool_id=tool_id | ||
) | ||
output = PromptStudioOutputManager.objects.get( | ||
prompt_id=variable_prompt.prompt_id, | ||
document_manager=doc_id, | ||
tool_id=variable_prompt.tool_id, | ||
profile_manager=variable_prompt.profile_manager, | ||
) | ||
if not output: | ||
raise PromptNotRun() | ||
return output.output | ||
|
||
@staticmethod | ||
def identify_variable_type(variable: str) -> VariableType: | ||
variable_type: VariableType | ||
pattern = re.compile(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX) | ||
if re.findall(pattern, variable): | ||
variable_type = VariableType.DYNAMIC | ||
else: | ||
variable_type = VariableType.STATIC | ||
return variable_type | ||
|
||
@staticmethod | ||
def extract_variables_from_prompt(prompt: str) -> list[str]: | ||
variable: list[str] = [] | ||
variable = re.findall(VariableConstants.VARIABLE_REGEX, prompt) | ||
return variable | ||
|
||
@staticmethod | ||
def frame_variable_replacement_map( | ||
doc_id: str, prompt_object: ToolStudioPrompt | ||
) -> dict[str, Any]: | ||
variable_output_map: dict[str, Any] = {} | ||
prompt = prompt_object.prompt | ||
variables = PromptStudioVariableService.extract_variables_from_prompt( | ||
prompt=prompt | ||
) | ||
for variable in variables: | ||
variable_type: VariableType = ( | ||
PromptStudioVariableService.identify_variable_type(variable=variable) | ||
) | ||
if variable_type == VariableType.STATIC: | ||
variable_output_map[variable] = ( | ||
PromptStudioVariableService.fetch_variable_outputs( | ||
variable=variable, | ||
doc_id=doc_id, | ||
tool_id=prompt_object.tool_id.tool_id, | ||
) | ||
) | ||
if variable_type == VariableType.DYNAMIC: | ||
variable = re.findall( | ||
VariableConstants.DYNAMIC_VARIABLE_DATA_REGEX, variable | ||
)[0] | ||
variable_output_map[variable] = ( | ||
PromptStudioVariableService.fetch_variable_outputs( | ||
variable=variable, | ||
doc_id=doc_id, | ||
tool_id=prompt_object.tool_id.tool_id, | ||
) | ||
) | ||
return variable_output_map |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
49 changes: 49 additions & 0 deletions
49
prompt-service/src/unstract/prompt_service/utils/request.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import logging | ||
from enum import Enum | ||
from typing import Any, Optional | ||
|
||
import requests as pyrequests | ||
from requests.exceptions import RequestException | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class HTTPMethod(str, Enum): | ||
GET = "GET" | ||
POST = "POST" | ||
DELETE = "DELETE" | ||
PUT = "PUT" | ||
PATCH = "PATCH" | ||
|
||
|
||
def make_http_request( | ||
verb: HTTPMethod, | ||
url: str, | ||
data: Optional[dict[str, Any]] = None, | ||
headers: Optional[dict[str, Any]] = None, | ||
params: Optional[dict[str, Any]] = None, | ||
) -> str: | ||
"""Generic helper function to help make a HTTP request.""" | ||
try: | ||
if verb == HTTPMethod.GET: | ||
response = pyrequests.get(url, params=params, headers=headers) | ||
elif verb == HTTPMethod.POST: | ||
response = pyrequests.post(url, json=data, params=params, headers=headers) | ||
elif verb == HTTPMethod.DELETE: | ||
response = pyrequests.delete(url, params=params, headers=headers) | ||
else: | ||
raise ValueError("Invalid HTTP verb. Supported verbs: GET, POST, DELETE") | ||
|
||
response.raise_for_status() | ||
return_val: str = ( | ||
response.json() | ||
if response.headers.get("content-type") == "application/json" | ||
else response.text | ||
) | ||
return return_val | ||
except RequestException as e: | ||
logger.error(f"HTTP request error: {e}") | ||
raise e | ||
except Exception as e: | ||
logger.error(f"An unexpected error occurred: {e}") | ||
raise e |
25 changes: 25 additions & 0 deletions
25
prompt-service/src/unstract/prompt_service/variable_extractor/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import Any | ||
|
||
from .constants import VariableType | ||
from .prompt_variable_service import VariableService | ||
|
||
|
||
class VariableExtractor: | ||
|
||
@staticmethod | ||
def execute_variable_replacement(prompt: str, variable_map: dict[str, Any]) -> str: | ||
variables: list[str] = VariableService.extract_variables_from_prompt( | ||
prompt=prompt | ||
) | ||
for variable in variables: | ||
variable_type = VariableService.identify_variable_type(variable=variable) | ||
if variable_type == VariableType.STATIC: | ||
prompt = VariableService.replace_static_variable( | ||
prompt=prompt, structured_output=variable_map, variable=variable | ||
) | ||
|
||
if variable_type == VariableType.DYNAMIC: | ||
prompt = VariableService.replace_dynamic_variable( | ||
prompt=prompt, variable=variable | ||
) | ||
return prompt |
13 changes: 13 additions & 0 deletions
13
prompt-service/src/unstract/prompt_service/variable_extractor/constants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from enum import Enum | ||
|
||
|
||
class VariableType(str, Enum): | ||
STATIC = "STATIC" | ||
DYNAMIC = "DYNAMIC" | ||
|
||
|
||
class VariableConstants: | ||
|
||
VARIABLE_REGEX = "{{(.+?)}}" | ||
DYNAMIC_VARIABLE_DATA_REGEX = r"\[(.*?)\]" | ||
DYNAMIC_VARIABLE_URL_REGEX = r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’]))" # noqa: E501 |
93 changes: 93 additions & 0 deletions
93
prompt-service/src/unstract/prompt_service/variable_extractor/prompt_variable_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import logging | ||
import re | ||
from typing import Any | ||
|
||
from utils.request import HTTPMethod, make_http_request | ||
|
||
from .constants import VariableConstants, VariableType | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class VariableService: | ||
|
||
@staticmethod | ||
def replace_static_variable( | ||
prompt: str, structured_output: dict[str, Any], variable: str | ||
) -> str: | ||
output_value = VariableService.check_static_variable_run_status( | ||
structure_output=structured_output, variable=variable | ||
) | ||
if not output_value: | ||
return prompt | ||
static_variable_marker_string = "".join(["{{", variable, "}}"]) | ||
|
||
replaced_prompt: str = VariableService.replace_generic_string_value( | ||
prompt=prompt, variable=static_variable_marker_string, value=output_value | ||
) | ||
|
||
return replaced_prompt | ||
|
||
@staticmethod | ||
def check_static_variable_run_status( | ||
structure_output: dict[str, Any], variable: str | ||
) -> Any: | ||
output = None | ||
try: | ||
output = structure_output[variable] | ||
return output | ||
except KeyError: | ||
logger.warn( | ||
f"Prompt with {variable} is not executed yet." | ||
" Unable to replace the variable" | ||
) | ||
return output | ||
|
||
@staticmethod | ||
def replace_generic_string_value(prompt: str, variable: str, value: str) -> str: | ||
replaced_prompt = re.sub(variable, value, prompt) | ||
return replaced_prompt | ||
|
||
@staticmethod | ||
def identify_variable_type(variable: str) -> VariableType: | ||
variable_type: VariableType | ||
pattern = re.compile(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX) | ||
if re.findall(pattern, variable): | ||
variable_type = VariableType.DYNAMIC | ||
else: | ||
variable_type = VariableType.STATIC | ||
return variable_type | ||
|
||
@staticmethod | ||
def replace_dynamic_variable(prompt: str, variable: str) -> str: | ||
url = re.search(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable).group(0) | ||
data = re.findall(VariableConstants.DYNAMIC_VARIABLE_DATA_REGEX, variable)[0] | ||
api_response = VariableService.fetch_dynamic_variable_value(url=url, data=data) | ||
static_variable_marker_string = "".join(["{{", variable, "}}"]) | ||
replaced_prompt: str = VariableService.replace_generic_string_value( | ||
prompt=prompt, variable=static_variable_marker_string, value=api_response | ||
) | ||
return replaced_prompt | ||
|
||
@staticmethod | ||
def extract_variables_from_prompt(prompt: str) -> list[str]: | ||
variable: list[str] = [] | ||
variable = re.findall(VariableConstants.VARIABLE_REGEX, prompt) | ||
return variable | ||
|
||
@staticmethod | ||
def fetch_dynamic_variable_value(url: str, data: str) -> Any: | ||
|
||
# This prototype method currently supports | ||
# only endpoints that do not require authentication. | ||
# Additionally, it only accepts plain text | ||
# inputs for POST requests in this version. | ||
# Future versions may include support for | ||
# authentication and other input formats. | ||
|
||
verb: HTTPMethod = HTTPMethod.POST | ||
headers = {"Content-Type": "text/plain"} | ||
response: Any = make_http_request( | ||
verb=verb, url=url, data=data, headers=headers | ||
) | ||
return response |