diff --git a/src/unstract/sdk/prompt.py b/src/unstract/sdk/prompt.py index 71ec5404..fd91762f 100644 --- a/src/unstract/sdk/prompt.py +++ b/src/unstract/sdk/prompt.py @@ -146,29 +146,19 @@ def get_exported_tool( """ platform_host = tool.get_env_or_die(ToolEnv.PLATFORM_HOST) platform_port = tool.get_env_or_die(ToolEnv.PLATFORM_PORT) - - tool.stream_log("Connecting to DB and getting exported tool metadata") base_url = SdkHelper.get_platform_base_url(platform_host, platform_port) bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY) - url = f"{base_url}/custom_tool_instance" query_params = {PromptStudioKeys.PROMPT_REGISTRY_ID: prompt_registry_id} headers = {"Authorization": f"Bearer {bearer_token}"} response = requests.get(url, headers=headers, params=query_params) if response.status_code == 200: - adapter_data: dict[str, Any] = response.json() - tool.stream_log( - "Successfully retrieved metadata for the exported " - f"tool: {prompt_registry_id}" - ) - return adapter_data - + return response.json() elif response.status_code == 404: tool.stream_error_and_exit( - f"Exported tool {prompt_registry_id} is not found" + f"Exported tool '{prompt_registry_id}' is not found" ) return None - else: tool.stream_error_and_exit( f"Error while retrieving tool metadata " diff --git a/src/unstract/sdk/tool/stream.py b/src/unstract/sdk/tool/stream.py index 76c91b73..c536a560 100644 --- a/src/unstract/sdk/tool/stream.py +++ b/src/unstract/sdk/tool/stream.py @@ -1,5 +1,6 @@ import datetime import json +import logging import os from typing import Any @@ -7,6 +8,7 @@ from unstract.sdk.constants import Command, LogLevel, LogStage, ToolEnv from unstract.sdk.utils import ToolUtils +from unstract.sdk.utils.common_utils import UNSTRACT_TO_PY_LOG_LEVEL class StreamMixin: @@ -30,8 +32,35 @@ def __init__(self, log_level: LogLevel = LogLevel.INFO, **kwargs) -> None: self._exec_by_tool = ToolUtils.str_to_bool( os.environ.get(ToolEnv.EXECUTION_BY_TOOL, "False") ) + if self.is_exec_by_tool: + self._configure_logger() super().__init__(**kwargs) + @property + def is_exec_by_tool(self): + """Flag to determine if SDK library is used in a tool's context. + + Returns: + bool: True if SDK is used by a tool else False + """ + return self._exec_by_tool + + def _configure_logger(self) -> None: + """Helps configure the logger for the tool run.""" + rootlogger = logging.getLogger("") + # Avoids adding multiple handlers + if rootlogger.hasHandlers(): + return + handler = logging.StreamHandler() + handler.setLevel(level=UNSTRACT_TO_PY_LOG_LEVEL[self.log_level]) + handler.setFormatter( + logging.Formatter( + "[%(asctime)s] %(levelname)s in %(module)s: %(message)s", + ) + ) + rootlogger.addHandler(handler) + rootlogger.setLevel(level=UNSTRACT_TO_PY_LOG_LEVEL[self.log_level]) + def stream_log( self, log: str, diff --git a/src/unstract/sdk/utils/common_utils.py b/src/unstract/sdk/utils/common_utils.py index 568b9e2c..f2866c00 100644 --- a/src/unstract/sdk/utils/common_utils.py +++ b/src/unstract/sdk/utils/common_utils.py @@ -23,6 +23,14 @@ def generate_uuid() -> str: logging.ERROR: LogLevel.ERROR, } +# Mapping from Unstract log level to python counterpart +UNSTRACT_TO_PY_LOG_LEVEL = { + LogLevel.DEBUG: logging.DEBUG, + LogLevel.INFO: logging.INFO, + LogLevel.WARN: logging.WARNING, + LogLevel.ERROR: logging.ERROR, +} + def log_elapsed(operation): """Adds an elapsed time log.