From f977a6ab16369fd7e959e9d249774621bca0a0fd Mon Sep 17 00:00:00 2001
From: Anirban Basu
Date: Wed, 7 Aug 2024 22:18:38 +0900
Subject: [PATCH] fix: Fixed the parse_env function.

---
 src/utils.py  | 59 ++++++++++++++++++++++++++++++
 src/webapp.py | 98 +++++++++++++++------------------------------------
 2 files changed, 87 insertions(+), 70 deletions(-)
 create mode 100644 src/utils.py

diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..a3ab455
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Anirban Basu
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Various utility functions used in the project."""
+
+import os
+from typing import Any
+import constants
+
+
+def parse_env(
+    var_name: str,
+    default_value: str | None = None,
+    type_cast=str,
+    convert_to_list=False,
+    list_split_char=constants.SPACE_STRING,
+) -> Any | list[Any]:
+    """
+    Parse the environment variable and return the value.
+
+    Args:
+        var_name (str): The name of the environment variable.
+        default_value (str | None): The default value to use if the environment variable is not set. Defaults to None.
+        type_cast (str): The type to cast the value to.
+        convert_to_list (bool): Whether to convert the value to a list.
+        list_split_char (str): The character to split the list on.
+
+    Returns:
+        (Any | list[Any]) The parsed value, either as a single value or a list.
The type of the returned single + value or individual elements in the list depends on the supplied type_cast parameter. + """ + if os.getenv(var_name) is None and default_value is None: + raise ValueError( + f"Environment variable {var_name} does not exist and a default value has not been provided." + ) + parsed_value = None + if type_cast is bool: + parsed_value = ( + os.getenv(var_name, default_value).lower() in constants.TRUE_VALUES_LIST + ) + else: + parsed_value = os.getenv(var_name, default_value) + + value: Any | list[Any] = ( + type_cast(parsed_value) + if not convert_to_list + else [type_cast(v) for v in parsed_value.split(list_split_char)] + ) + return value diff --git a/src/webapp.py b/src/webapp.py index bcc06d1..dbdbccc 100644 --- a/src/webapp.py +++ b/src/webapp.py @@ -13,11 +13,10 @@ # limitations under the License. """The main web application module for the Gradio app.""" -import os from dotenv import load_dotenv -from typing import Any from coder_agent import CoderAgent, TestCase +from utils import parse_env try: from icecream import ic @@ -57,51 +56,51 @@ class GradioApp: def __init__(self): """Default constructor for the Gradio app.""" ic(load_dotenv()) - self._gradio_host: str = self.parse_env( + self._gradio_host: str = parse_env( constants.ENV_VAR_NAME__GRADIO_SERVER_HOST, default_value=constants.ENV_VAR_VALUE__GRADIO_SERVER_HOST, ) - self._gradio_port: int = self.parse_env( + self._gradio_port: int = parse_env( constants.ENV_VAR_NAME__GRADIO_SERVER_PORT, default_value=constants.ENV_VAR_VALUE__GRADIO_SERVER_PORT, type_cast=int, ) ic(self._gradio_host, self._gradio_port) - self._llm_provider = self.parse_env( + self._llm_provider = parse_env( constants.ENV_VAR_NAME__LLM_PROVIDER, default_value=constants.ENV_VAR_VALUE__LLM_PROVIDER, ) if self._llm_provider == "Ollama": self._llm = ChatOllama( - base_url=self.parse_env( + base_url=parse_env( constants.ENV_VAR_NAME__LLM_OLLAMA_URL, default_value=constants.ENV_VAR_VALUE__LLM_OLLAMA_URL, ), - 
model=self.parse_env( + model=parse_env( constants.ENV_VAR_NAME__LLM_OLLAMA_MODEL, default_value=constants.ENV_VAR_VALUE__LLM_OLLAMA_MODEL, ), - temperature=self.parse_env( + temperature=parse_env( constants.ENV_VAR_NAME__LLM_TEMPERATURE, default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE, type_cast=float, ), - top_p=self.parse_env( + top_p=parse_env( constants.ENV_VAR_NAME__LLM_TOP_P, default_value=constants.ENV_VAR_VALUE__LLM_TOP_P, type_cast=float, ), - top_k=self.parse_env( + top_k=parse_env( constants.ENV_VAR_NAME__LLM_TOP_K, default_value=constants.ENV_VAR_VALUE__LLM_TOP_K, type_cast=int, ), - repeat_penalty=self.parse_env( + repeat_penalty=parse_env( constants.ENV_VAR_NAME__LLM_REPEAT_PENALTY, default_value=constants.ENV_VAR_VALUE__LLM_REPEAT_PENALTY, type_cast=float, ), - seed=self.parse_env( + seed=parse_env( constants.ENV_VAR_NAME__LLM_SEED, default_value=constants.ENV_VAR_VALUE__LLM_SEED, type_cast=int, @@ -110,33 +109,33 @@ def __init__(self): ) elif self._llm_provider == "Groq": self._llm = ChatGroq( - api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_GROQ_API_KEY), - model=self.parse_env( + api_key=parse_env(constants.ENV_VAR_NAME__LLM_GROQ_API_KEY), + model=parse_env( constants.ENV_VAR_NAME__LLM_GROQ_MODEL, default_value=constants.ENV_VAR_VALUE__LLM_GROQ_MODEL, ), - temperature=self.parse_env( + temperature=parse_env( constants.ENV_VAR_NAME__LLM_TEMPERATURE, default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE, type_cast=float, ), # model_kwargs={ - # "top_p": self.parse_env( + # "top_p": parse_env( # constants.ENV_VAR_NAME__LLM_TOP_P, # default_value=constants.ENV_VAR_VALUE__LLM_TOP_P, # type_cast=float, # ), - # "top_k": self.parse_env( + # "top_k": parse_env( # constants.ENV_VAR_NAME__LLM_TOP_K, # default_value=constants.ENV_VAR_VALUE__LLM_TOP_K, # type_cast=int, # ), - # "repeat_penalty": self.parse_env( + # "repeat_penalty": parse_env( # constants.ENV_VAR_NAME__LLM_REPEAT_PENALTY, # 
default_value=constants.ENV_VAR_VALUE__LLM_REPEAT_PENALTY, # type_cast=float, # ), - # "seed": self.parse_env( + # "seed": parse_env( # constants.ENV_VAR_NAME__LLM_SEED, # default_value=constants.ENV_VAR_VALUE__LLM_SEED, # type_cast=int, @@ -149,12 +148,12 @@ def __init__(self): ) elif self._llm_provider == "Anthropic": self._llm = ChatAnthropic( - api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_ANTHROPIC_API_KEY), - model=self.parse_env( + api_key=parse_env(constants.ENV_VAR_NAME__LLM_ANTHROPIC_API_KEY), + model=parse_env( constants.ENV_VAR_NAME__LLM_ANTHROPIC_MODEL, default_value=constants.ENV_VAR_VALUE__LLM_ANTHROPIC_MODEL, ), - temperature=self.parse_env( + temperature=parse_env( constants.ENV_VAR_NAME__LLM_TEMPERATURE, default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE, type_cast=float, @@ -162,14 +161,12 @@ def __init__(self): ) elif self._llm_provider == "Cohere": self._llm = ChatCohere( - cohere_api_key=self.parse_env( - constants.ENV_VAR_NAME__LLM_COHERE_API_KEY - ), - model=self.parse_env( + cohere_api_key=parse_env(constants.ENV_VAR_NAME__LLM_COHERE_API_KEY), + model=parse_env( constants.ENV_VAR_NAME__LLM_COHERE_MODEL, default_value=constants.ENV_VAR_VALUE__LLM_COHERE_MODEL, ), - temperature=self.parse_env( + temperature=parse_env( constants.ENV_VAR_NAME__LLM_TEMPERATURE, default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE, type_cast=float, @@ -177,12 +174,12 @@ def __init__(self): ) elif self._llm_provider == "Open AI": self._llm = ChatOpenAI( - api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_OPENAI_API_KEY), - model=self.parse_env( + api_key=parse_env(constants.ENV_VAR_NAME__LLM_OPENAI_API_KEY), + model=parse_env( constants.ENV_VAR_NAME__LLM_OPENAI_MODEL, default_value=constants.ENV_VAR_VALUE__LLM_OPENAI_MODEL, ), - temperature=self.parse_env( + temperature=parse_env( constants.ENV_VAR_NAME__LLM_TEMPERATURE, default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE, type_cast=float, @@ -192,45 +189,6 @@ def __init__(self): raise 
ValueError(f"Unsupported LLM provider: {self._llm_provider}") ic(self._llm_provider, self._llm) - def parse_env( - self, - var_name: str, - default_value: str = None, - type_cast=str, - convert_to_list=False, - list_split_char=constants.SPACE_STRING, - ) -> Any | list[Any]: - """ - Parse the environment variable and return the value. - - Args: - var_name (str): The name of the environment variable. - default_value (str): The default value to use if the environment variable is not set. - type_cast (str): The type to cast the value to. - convert_to_list (bool): Whether to convert the value to a list. - list_split_char (str): The character to split the list on. - - Returns: - (Any | list[Any]) The parsed value, either as a single value or a list. The type of the returned single - value or individual elements in the list depends on the supplied type_cast parameter. - """ - if var_name not in os.environ and default_value is None: - raise ValueError(f"Environment variable {var_name} does not exist.") - parsed_value = None - if type_cast is bool: - parsed_value = ( - os.getenv(var_name, default_value).lower() in constants.TRUE_VALUES_LIST - ) - else: - parsed_value = os.getenv(var_name, default_value) - - value: Any | list[Any] = ( - type_cast(parsed_value) - if not convert_to_list - else [type_cast(v) for v in parsed_value.split(list_split_char)] - ) - return value - def find_solution( self, user_question: str, runtime_limit: int, test_cases: list[TestCase] = None ): @@ -253,7 +211,7 @@ def find_solution( }, messages=[ SystemMessagePromptTemplate.from_template( - template=self.parse_env( + template=parse_env( constants.ENV_VAR_NAME__LLM_SYSTEM_PROMPT, constants.ENV_VAR_VALUE__LLM_SYSTEM_PROMPT, )