160 added test framework and proper error messaging in prompt (#165)
* Added testing framework to prompt.

* Cleaned up code and added a TODO.

* Added prompt template for failure analysis.

* Added failed test summary from LLM.

* Fixed parentheses.

* Added the new TOML settings file to the Makefile.
EmbeddedDevops1 authored Sep 26, 2024
1 parent 5264bb9 commit 043c901
Showing 11 changed files with 165 additions and 68 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -23,6 +23,7 @@ installer:
--add-data "cover_agent/settings/test_generation_prompt.toml:." \
--add-data "cover_agent/settings/analyze_suite_test_headers_indentation.toml:." \
--add-data "cover_agent/settings/analyze_suite_test_insert_line.toml:." \
--add-data "cover_agent/settings/analyze_test_run_failure.toml:." \
--add-data "$(SITE_PACKAGES)/vendor:wandb/vendor" \
--hidden-import=tiktoken_ext.openai_public \
--hidden-import=tiktoken_ext \
10 changes: 5 additions & 5 deletions cover_agent/AICaller.py
@@ -18,7 +18,7 @@ def __init__(self, model: str, api_base: str = ""):
self.model = model
self.api_base = api_base

def call_model(self, prompt: dict, max_tokens=4096):
def call_model(self, prompt: dict, max_tokens=4096, stream=True):
"""
Call the language model with the provided prompt and retrieve the response.
@@ -61,17 +61,17 @@ def call_model(self, prompt: dict, max_tokens=4096):
response = litellm.completion(**completion_params)

chunks = []
print("Streaming results from LLM model...")
print("Streaming results from LLM model...") if stream else None
try:
for chunk in response:
print(chunk.choices[0].delta.content or "", end="", flush=True)
print(chunk.choices[0].delta.content or "", end="", flush=True) if stream else None
chunks.append(chunk)
time.sleep(
0.01
) # Optional: Delay to simulate more 'natural' response pacing
except Exception as e:
print(f"Error during streaming: {e}")
print("\n")
print(f"Error during streaming: {e}") if stream else None
print("\n") if stream else None

model_response = litellm.stream_chunk_builder(chunks, messages=messages)

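A minimal usage sketch of the new `stream` flag on `call_model` (the model name, prompt text, and configured API credentials are assumptions; the three-value return matches how `UnitTestGenerator` unpacks it below):

```python
from cover_agent.AICaller import AICaller

# Hypothetical model; assumes litellm can find credentials for it in the environment.
caller = AICaller(model="gpt-4o")

# Same {"system": ..., "user": ...} shape that PromptBuilder produces.
prompt = {
    "system": "You are a concise code assistant.",
    "user": "Summarize why a unit test might fail on an assertion error.",
}

# stream=True (the default) prints chunks as they arrive; stream=False keeps the
# call quiet, which is what the new failure-analysis path uses.
response, prompt_tokens, response_tokens = caller.call_model(prompt=prompt, stream=False)
print(response)
```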
35 changes: 29 additions & 6 deletions cover_agent/PromptBuilder.py
@@ -42,6 +42,7 @@ def __init__(
additional_instructions: str = "",
failed_test_runs: str = "",
language: str = "python",
testing_framework: str = "NOT KNOWN",
):
"""
The `PromptBuilder` class is responsible for building a formatted prompt string by replacing placeholders with the actual content of files read during initialization. It takes in various paths and settings as parameters and provides a method to generate the prompt.
@@ -72,6 +73,8 @@ def __init__(
self.test_file = self._read_file(test_file_path)
self.code_coverage_report = code_coverage_report
self.language = language
self.testing_framework = testing_framework

# add line numbers to each line in 'source_file'. start from 1
self.source_file_numbered = "\n".join(
[f"{i + 1} {line}" for i, line in enumerate(self.source_file.split("\n"))]
@@ -99,6 +102,9 @@ def __init__(
else ""
)

self.stdout_from_run = ""
self.stderr_from_run = ""

def _read_file(self, file_path):
"""
Helper method to read file contents.
@@ -138,6 +144,9 @@ def build_prompt(self) -> dict:
"additional_instructions_text": self.additional_instructions,
"language": self.language,
"max_tests": MAX_TESTS_PER_RUN,
"testing_framework": self.testing_framework,
"stdout": self.stdout_from_run,
"stderr": self.stderr_from_run,
}
environment = Environment(undefined=StrictUndefined)
try:
@@ -155,6 +164,15 @@
return {"system": system_prompt, "user": user_prompt}

def build_prompt_custom(self, file) -> dict:
"""
Builds a custom prompt by replacing placeholders with actual content from files and settings.
Parameters:
file (str): The file to retrieve settings for building the prompt.
Returns:
dict: A dictionary containing the system and user prompts.
"""
variables = {
"source_file_name": self.source_file_name,
"test_file_name": self.test_file_name,
@@ -168,15 +186,20 @@ def build_prompt_custom(self, file) -> dict:
"additional_instructions_text": self.additional_instructions,
"language": self.language,
"max_tests": MAX_TESTS_PER_RUN,
"testing_framework": self.testing_framework,
"stdout": self.stdout_from_run,
"stderr": self.stderr_from_run,
}
environment = Environment(undefined=StrictUndefined)
try:
system_prompt = environment.from_string(
get_settings().get(file).system
).render(variables)
user_prompt = environment.from_string(get_settings().get(file).user).render(
variables
)
settings = get_settings().get(file)
if settings is None or not hasattr(settings, "system") or not hasattr(
settings, "user"
):
logging.error(f"Could not find settings for prompt file: {file}")
return {"system": "", "user": ""}
system_prompt = environment.from_string(settings.system).render(variables)
user_prompt = environment.from_string(settings.user).render(variables)
except Exception as e:
logging.error(f"Error rendering prompt: {e}")
return {"system": "", "user": ""}
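A short sketch of the reworked `build_prompt_custom` guard (file contents and the unknown prompt key are invented; the known-key case mirrors the end-to-end test added further down):

```python
from cover_agent.PromptBuilder import PromptBuilder

# Hypothetical paths; any readable files will do for this sketch.
builder = PromptBuilder(
    source_file_path="app.py",
    test_file_path="test_app.py",
    code_coverage_report="coverage report text",
    testing_framework="pytest",
)

# The new stdout/stderr fields default to "" in __init__ and can be filled in later.
builder.stdout_from_run = "1 failed, 3 passed"
builder.stderr_from_run = "AssertionError: expected 2, got 3"

# Known prompt file: {{ stdout }}, {{ stderr }} and {{ testing_framework }} are rendered.
prompt = builder.build_prompt_custom("analyze_test_run_failure")

# Unknown prompt file (hypothetical key): the guard logs an error and returns empty
# prompts instead of raising on a missing settings entry.
empty = builder.build_prompt_custom("no_such_prompt_file")
assert empty == {"system": "", "user": ""}
```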
83 changes: 48 additions & 35 deletions cover_agent/UnitTestGenerator.py
@@ -79,6 +79,7 @@ def __init__(
self.failed_test_runs = []
self.total_input_token_count = 0
self.total_output_token_count = 0
self.testing_framework = "Unknown"

# Read self.source_file_path into a string
with open(self.source_file_path, "r") as f:
@@ -269,10 +270,7 @@ def build_prompt(self):
continue
# dump dict to str
code = json.dumps(failed_test_dict)
if "error_message" in failed_test:
error_message = failed_test["error_message"]
else:
error_message = None
error_message = failed_test.get("error_message", None)
failed_test_runs_value += f"Failed Test:\n```\n{code}\n```\n"
if error_message:
failed_test_runs_value += (
@@ -296,6 +294,7 @@ def build_prompt(self):
additional_instructions=self.additional_instructions,
failed_test_runs=failed_test_runs_value,
language=self.language,
testing_framework=self.testing_framework,
)

return self.prompt_builder.build_prompt()
@@ -363,6 +362,7 @@ def initial_test_suite_analysis(self):
relevant_line_number_to_insert_imports_after = tests_dict.get(
"relevant_line_number_to_insert_imports_after", None
)
self.testing_framework = tests_dict.get("testing_framework", "Unknown")
counter_attempts += 1

if not relevant_line_number_to_insert_tests_after:
@@ -562,9 +562,9 @@ def validate_test(self, generated_test: dict, num_attempts=1):
"processed_test_file": processed_test,
}

error_message = extract_error_message_python(fail_details["stdout"])
error_message = self.extract_error_message(stderr=fail_details["stderr"], stdout=fail_details["stdout"])
if error_message:
logging.error(f"Error message:\n{error_message}")
logging.error(f"Error message summary:\n{error_message}")

self.failed_test_runs.append(
{"code": generated_test, "error_message": error_message}
@@ -647,7 +647,7 @@ def validate_test(self, generated_test: dict, num_attempts=1):
self.failed_test_runs.append(
{
"code": fail_details["test"],
"error_message": "did not increase code coverage",
"error_message": "Code coverage did not increase",
}
) # Append failure details to the list

@@ -686,7 +686,7 @@ def validate_test(self, generated_test: dict, num_attempts=1):
self.failed_test_runs.append(
{
"code": fail_details["test"],
"error_message": "coverage verification error",
"error_message": "Coverage verification error",
}
) # Append failure details to the list
return fail_details
@@ -762,30 +762,43 @@ def to_json(self):
return json.dumps(self.to_dict())


def extract_error_message_python(fail_message):
"""
Extracts and returns the error message from the provided failure message.
Parameters:
fail_message (str): The failure message containing the error message to be extracted.
Returns:
str: The extracted error message from the failure message, or an empty string if no error message is found.
"""
try:
# Define a regular expression pattern to match the error message
MAX_LINES = 20
pattern = r"={3,} FAILURES ={3,}(.*?)(={3,}|$)"
match = re.search(pattern, fail_message, re.DOTALL)
if match:
err_str = match.group(1).strip("\n")
err_str_lines = err_str.split("\n")
if len(err_str_lines) > MAX_LINES:
# show last MAX_lines lines
err_str = "...\n" + "\n".join(err_str_lines[-MAX_LINES:])
return err_str
return ""
except Exception as e:
logging.error(f"Error extracting error message: {e}")
return ""
def extract_error_message(self, stderr, stdout):
"""
Extracts the error message from the provided stderr and stdout outputs.
Updates the PromptBuilder object with the stderr and stdout, builds a custom prompt for analyzing test run failures,
calls the language model to analyze the prompt, and loads the response into a dictionary.
Returns the error summary from the loaded YAML data or a default error message if unable to summarize.
Logs errors encountered during the process.
Parameters:
stderr (str): The standard error output from the test run.
stdout (str): The standard output from the test run.
Returns:
str: The error summary extracted from the response or a default error message if extraction fails.
"""
try:
# Update the PromptBuilder object with stderr and stdout
self.prompt_builder.stderr_from_run = stderr
self.prompt_builder.stdout_from_run = stdout

# Build the prompt
prompt_headers_indentation = self.prompt_builder.build_prompt_custom(
file="analyze_test_run_failure"
)

# Run the analysis via LLM
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=prompt_headers_indentation, stream=False)
)
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
tests_dict = load_yaml(response)

return tests_dict.get("error_summary", f"ERROR: Unable to summarize error message from inputs. STDERR: {stderr}\nSTDOUT: {stdout}.")
except Exception as e:
logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}\n\n{response}")
logging.error(f"Error extracting error message: {e}")
return ""
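In miniature, the new flow replaces the pytest-specific regex with an LLM round trip: the raw stdout/stderr go into the `analyze_test_run_failure` prompt and the YAML reply is reduced to a one-sentence summary. A hedged sketch of the response handling, using PyYAML as a stand-in for the project's `load_yaml` helper and an invented model reply:

```python
import yaml  # PyYAML stand-in for the project's load_yaml helper

# Invented example of the YAML object the prompt asks the model to return.
llm_response = (
    "error_summary: test_divide failed at line 14 ('assert divide(1, 0) == 0') "
    "because divide() raised ZeroDivisionError at app.py line 7."
)

tests_dict = yaml.safe_load(llm_response)
summary = tests_dict.get(
    "error_summary",
    "ERROR: Unable to summarize error message from inputs.",
)
print(summary)  # one-sentence summary that gets attached to the failed test run
```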
41 changes: 41 additions & 0 deletions cover_agent/settings/analyze_test_run_failure.toml
@@ -0,0 +1,41 @@
[analyze_test_run_failure]
system="""\
"""

user="""\
## Overview
You are a code assistant that accepts both the stdout and stderr from a test run, specifically for unit test regression testing.
Your goal is to analyze the output, and summarize the failure for further analysis.
Please provide a one-sentence summary of the error, including the following details:
- The offending line of code (if available).
- The line number where the error occurred.
- Any other relevant details or information gleaned from the stdout and stderr.
Here is the stdout and stderr from the test run:
=========
stdout:
{{ stdout|trim }}
=========
stderr:
=========
{{ stderr|trim }}
=========
Now, you need to analyze the output and provide a YAML object equivalent to type $TestFailureAnalysis, according to the following Pydantic definitions:
=====
class TestFailureAnalysis(BaseModel):
error_summary: str = Field(description="A one-sentence summary of the failure, including the offending line of code, line number, and other relevant information from the stdout/stderr.")
=====
Example output:
```yaml
error_summary: ...
```
The Response should be only a valid YAML object, without any introduction text or follow-up text.
Answer:
```yaml
"""
1 change: 1 addition & 0 deletions cover_agent/settings/config_loader.py
@@ -7,6 +7,7 @@
"language_extensions.toml",
"analyze_suite_test_headers_indentation.toml",
"analyze_suite_test_insert_line.toml",
"analyze_test_run_failure.toml",
]


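For context, a hedged sketch of what registering the new TOML here typically amounts to, assuming a Dynaconf-style loader behind `get_settings()` (the real `config_loader.py` may differ in detail):

```python
from pathlib import Path

from dynaconf import Dynaconf  # assumption: the project loads its TOML settings via Dynaconf

SETTINGS_FILES = [
    # ...earlier entries from the real list omitted...
    "language_extensions.toml",
    "analyze_suite_test_headers_indentation.toml",
    "analyze_suite_test_insert_line.toml",
    "analyze_test_run_failure.toml",  # registered by this commit
]

SETTINGS_DIR = Path(__file__).parent


def get_settings() -> Dynaconf:
    # Each TOML file contributes one top-level table, e.g. [analyze_test_run_failure],
    # which build_prompt_custom later retrieves with get_settings().get(file).
    return Dynaconf(settings_files=[str(SETTINGS_DIR / name) for name in SETTINGS_FILES])
```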
2 changes: 2 additions & 0 deletions cover_agent/settings/test_generation_prompt.toml
@@ -28,6 +28,8 @@ Here is the file that contains the existing tests, called `{{ test_file_name }}`
{{ test_file| trim }}
=========
### Test Framework
The test framework used for running tests is `{{ testing_framework }}`.
{%- if additional_includes_section|trim %}
2 changes: 1 addition & 1 deletion cover_agent/version.txt
@@ -1 +1 @@
0.1.50
0.1.51
33 changes: 33 additions & 0 deletions tests/test_PromptBuilder.py
@@ -1,4 +1,6 @@
import os
import pytest
import tempfile
from unittest.mock import patch, mock_open
from cover_agent.PromptBuilder import PromptBuilder

@@ -172,3 +174,34 @@ def mock_render(*args, **kwargs):
)
result = builder.build_prompt()
assert result == {"system": "", "user": ""}

class TestPromptBuilderEndToEnd:
def test_custom_analyze_test_run_failure(self):
# Create fake source and test files and tmp files and pass in the paths
source_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
source_file.write("def foo():\n pass")
source_file.close()
test_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
test_file.write("def test_foo():\n pass")
test_file.close()
tmp_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
tmp_file.write("tmp file content")
tmp_file.close()

builder = PromptBuilder(
source_file_path=source_file.name,
test_file_path=test_file.name,
code_coverage_report=tmp_file.name,
)

builder.stderr_from_run = "stderr content"
builder.stdout_from_run = "stdout content"

result = builder.build_prompt_custom("analyze_test_run_failure")
assert "stderr content" in result["user"]
assert "stdout content" in result["user"]

# Clean up
os.remove(source_file.name)
os.remove(test_file.name)
os.remove(tmp_file.name)
19 changes: 1 addition & 18 deletions tests/test_UnitTestGenerator.py
@@ -1,8 +1,5 @@
import pytest
from cover_agent.UnitTestGenerator import (
UnitTestGenerator,
extract_error_message_python,
)
from cover_agent.UnitTestGenerator import UnitTestGenerator
from cover_agent.ReportGenerator import ReportGenerator
import os

Expand Down Expand Up @@ -31,17 +28,3 @@ def test_get_included_files_valid_paths(self):
result
== "file_path: `file1.txt`\ncontent:\n```\nfile content\n```\nfile_path: `file2.txt`\ncontent:\n```\nfile content\n```"
)


class TestExtractErrorMessage:
def test_extract_single_match(self):
fail_message = "=== FAILURES ===\\nError occurred here\\n=== END ==="
expected = "\\nError occurred here\\n"
result = extract_error_message_python(fail_message)
assert result == expected, f"Expected '{expected}', got '{result}'"

def test_extract_bad_match(self):
fail_message = 33
expected = ""
result = extract_error_message_python(fail_message)
assert result == expected, f"Expected '{expected}', got '{result}'"