Skip to content

Commit

Permalink
chore: Added more LLM configuration parameters.
Browse files Browse the repository at this point in the history
FIXME: Code evaluation does not work.

FIXME: Code output varies widely based on the LLM.
  • Loading branch information
anirbanbasu committed Jul 25, 2024
1 parent d416128 commit bc9190d
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 21 deletions.
14 changes: 14 additions & 0 deletions src/code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import constants

from langchain_experimental.utilities.python import PythonREPL

multiprocessing.set_start_method("fork", force=True)
# WARNING
# This program exists to execute untrusted model-generated code. Although
Expand Down Expand Up @@ -105,6 +107,18 @@ def _exec_program(self, q, program, input_data, expected_output, timeout):
finally:
self._add_execution_time(time.time() - start_time)

@staticmethod
def sanitise_code(code: str) -> str:
    """
    Sanitise model-generated code before it is handed to the executor.

    Delegates to LangChain's ``PythonREPL.sanitize_input``, which strips
    surrounding whitespace and Markdown code-fence backticks (including a
    leading ``python`` language tag) from the snippet.

    Args:
        code (str): The raw code string, possibly wrapped in Markdown
            code fences by the LLM.

    Returns:
        str: The sanitised code string.
    """
    # @staticmethod: this helper reads no instance state and is invoked
    # via the class (CodeExecutor.sanitise_code) from coder_agent.py;
    # without the decorator, an instance call would bind the instance
    # to `code` and silently corrupt the input.
    # NOTE(review): sanitize_input only removes fencing/whitespace — it
    # does NOT make untrusted code safe to run; sandboxing is still the
    # executor's responsibility (see the WARNING at the top of this file).
    return PythonREPL.sanitize_input(code)

def check_correctness(
self,
program: str,
Expand Down
36 changes: 20 additions & 16 deletions src/coder_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,34 @@ class codeOutput(BaseModel):
class CoderAgent:
def __init__(self, llm: BaseChatModel, prompt: ChatPromptTemplate):
# Retrieve the USA Computing Olympiad dataset
# This block is disabled, only used for inspecting the dataset and perhaps, for retrieval in the future
# usaco_url = "https://storage.googleapis.com/benchmarks-artifacts/usaco/usaco_sampled_with_tests.zip"
# zip_path = "usaco.zip"
# extract_path = "usaco_datasets"

# response = requests.get(usaco_url)
# with open(zip_path, "wb") as file:
# file.write(response.content)
# if not os.path.exists(extract_path):
# response = requests.get(usaco_url)
# with open(zip_path, "wb") as file:
# file.write(response.content)

# with zipfile.ZipFile(zip_path, "r") as zip_ref:
# zip_ref.extractall(extract_path)
# with zipfile.ZipFile(zip_path, "r") as zip_ref:
# zip_ref.extractall(extract_path)

# os.remove(zip_path)
# os.remove(zip_path)

# ds = datasets.load_from_disk(
# os.path.join(extract_path, "usaco_v3_sampled_with_tests")
# )

# test_case_0 = ds[0][constants.AGENT_STATE__KEY_TEST_CASES]
# ic(
# type(test_case_0),
# (
# (len(test_case_0), test_case_0[0])
# if type(test_case_0) is list
# else test_case_0
# ),
# )
# # We will test our agent on index 0 (the same as above).
# # Later, we will test on index 2 (the first 'silver difficulty' question)
# test_indices = [0, 2]
Expand Down Expand Up @@ -179,14 +190,6 @@ def solve(self, state: AgentState) -> dict:
}
# Have we been presented with examples?
has_examples = bool(state.get(constants.AGENT_STATE__KEY_EXAMPLES))
# Have we been presented with test cases?
has_test_cases = bool(state.get(constants.AGENT_STATE__KEY_TEST_CASES))
if has_test_cases:
inputs[constants.AGENT_STATE__KEY_MESSAGES].append(
HumanMessage(
f"Use the following test cases to ensure your code is correct.\n{self.format_test_cases(state[constants.AGENT_STATE__KEY_TEST_CASES])}"
)
)
ic(state)
# If `draft`` is requested in the state then output a candidate solution
output_key = (
Expand Down Expand Up @@ -272,8 +275,9 @@ def evaluate(self, state: AgentState) -> dict:
code: str = ai_message.tool_calls[0][constants.AGENT_TOOL_CALL__ARGS][
constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE
]
# FIXME: This is hacky. We should only replace the triple backticks at the start and the end of the code, nowhere in between.
code = code.replace("```", "")
# Use PythonREPL to sanitise the code
# See: https://api.python.langchain.com/en/latest/utilities/langchain_experimental.utilities.python.PythonREPL.html
code = CodeExecutor.sanitise_code(code)
except Exception as e:
# If there was an error extracting the code, return an error message as state.
return {
Expand Down
16 changes: 11 additions & 5 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,15 @@
ENV_VAR_NAME__LLM_LLAMAFILE_URL = "LLM__LLAMAFILE_URL"
ENV_VAR_VALUE__LLM_LLAMAFILE_URL = "http://localhost:8080"
ENV_VAR_NAME__LLM_TEMPERATURE = "LLM__TEMPERATURE"
ENV_VAR_VALUE__LLM_TEMPERATURE = "0.0"
ENV_VAR_VALUE__LLM_TEMPERATURE = "0.4"
ENV_VAR_NAME__LLM_TOP_P = "LLM__TOP_P"
ENV_VAR_VALUE__LLM_TOP_P = "0.4"
ENV_VAR_NAME__LLM_TOP_K = "LLM__TOP_K"
ENV_VAR_VALUE__LLM_TOP_K = "40"
ENV_VAR_NAME__LLM_REPEAT_PENALTY = "LLM__REPEAT_PENALTY"
ENV_VAR_VALUE__LLM_REPEAT_PENALTY = "1.1"
ENV_VAR_NAME__LLM_SEED = "LLM__SEED"
ENV_VAR_VALUE__LLM_SEED = "1"
ENV_VAR_NAME__LLM_SYSTEM_PROMPT = "LLM__SYSTEM_PROMPT"
ENV_VAR_VALUE__LLM_SYSTEM_PROMPT = """
You are a world-class competitive Python programmer. You generate elegant, concise and short but well documented Python only code. You follow the PEP8 style guide.
Expand All @@ -106,15 +114,13 @@
Output the pseudocode in Markdown format.
Finally output the working Python code for your solution, ensuring to fix any errors uncovered while writing pseudocode.
Encapsulate your Python code in a class called `Solution`.
Your code must be able to execute as a process. It must be able to accept input as a single string passed through the command line.
If multiple inputs are necessary, your code must parse the single input string accordingly. Lastly, your code must output the result to the console.
Do NOT format the code you wrote using Markdown code formatting. Output the code as unformatted plain text.
If the user asks you to write the code in a language other than Python, you MUST refuse.
No outside libraries are allowed.
Do not use external libraries.
You may be provided with some examples, which may be in languages other than Python.
{examples}
"""


CSS__GRADIO_APP = """
#ui_header * {
margin-top: auto;
Expand Down
20 changes: 20 additions & 0 deletions src/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,26 @@ def __init__(self):
default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
type_cast=float,
),
top_p=self.parse_env(
constants.ENV_VAR_NAME__LLM_TOP_P,
default_value=constants.ENV_VAR_VALUE__LLM_TOP_P,
type_cast=float,
),
top_k=self.parse_env(
constants.ENV_VAR_NAME__LLM_TOP_K,
default_value=constants.ENV_VAR_VALUE__LLM_TOP_K,
type_cast=int,
),
repeat_penalty=self.parse_env(
constants.ENV_VAR_NAME__LLM_REPEAT_PENALTY,
default_value=constants.ENV_VAR_VALUE__LLM_REPEAT_PENALTY,
type_cast=float,
),
seed=self.parse_env(
constants.ENV_VAR_NAME__LLM_SEED,
default_value=constants.ENV_VAR_VALUE__LLM_SEED,
type_cast=int,
),
format="json",
)
else:
Expand Down

0 comments on commit bc9190d

Please sign in to comment.