Skip to content

Commit

Permalink
Dw/split run (#233)
Browse files Browse the repository at this point in the history
* split run method into init and run_test_gen methods

* update version
  • Loading branch information
qododavid authored Nov 22, 2024
1 parent f768030 commit d40cbed
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 22 deletions.
52 changes: 33 additions & 19 deletions cover_agent/CoverAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
import wandb

from typing import List

from cover_agent.CustomLogger import CustomLogger
from cover_agent.PromptBuilder import adapt_test_command_for_a_single_test_via_ai
from cover_agent.ReportGenerator import ReportGenerator
Expand Down Expand Up @@ -130,26 +132,14 @@ def _duplicate_test_file(self):
# Otherwise, set the test file output path to the current test file
self.args.test_file_output_path = self.args.test_file_path

def run(self):
def init(self):
"""
Run the test generation process.
This method performs the following steps:
Prepare for test generation process
1. Initialize the Weights & Biases run if the WANDS_API_KEY environment variable is set.
2. Initialize variables to track progress.
3. Run the initial test suite analysis.
4. Loop until desired coverage is reached or maximum iterations are met.
5. Generate new tests.
6. Loop through each new test and validate it.
7. Insert the test result into the database.
8. Increment the iteration count.
9. Check if the desired coverage has been reached.
10. If the desired coverage has been reached, log the final coverage.
11. If the maximum iteration limit is reached, log a failure message if strict coverage is specified.
12. Provide metrics on total token usage.
13. Generate a report.
14. Finish the Weights & Biases run if it was initialized.
"""
# Check if user has exported the WANDS_API_KEY environment variable
if "WANDB_API_KEY" in os.environ:
Expand All @@ -159,15 +149,35 @@ def run(self):
run_name = f"{self.args.model}_" + time_and_date
wandb.init(project="cover-agent", name=run_name)

# Initialize variables to track progress
iteration_count = 0
test_results_list = []

# Run initial test suite analysis
self.test_validator.initial_test_suite_analysis()
failed_test_runs, language, test_framework, coverage_report = self.test_validator.get_coverage()
self.test_gen.build_prompt(failed_test_runs, language, test_framework, coverage_report)

return failed_test_runs, language, test_framework, coverage_report

def run_test_gen(self, failed_test_runs: List, language: str, test_framework: str, coverage_report: str):
"""
Run the test generation process.
This method performs the following steps:
1. Loop until desired coverage is reached or maximum iterations are met.
2. Generate new tests.
3. Loop through each new test and validate it.
4. Insert the test result into the database.
5. Increment the iteration count.
6. Check if the desired coverage has been reached.
7. If the desired coverage has been reached, log the final coverage.
8. If the maximum iteration limit is reached, log a failure message if strict coverage is specified.
9. Provide metrics on total token usage.
10. Generate a report.
11. Finish the Weights & Biases run if it was initialized.
"""
# Initialize variables to track progress
iteration_count = 0
test_results_list = []

# Loop until desired coverage is reached or maximum iterations are met
while (
self.test_validator.current_coverage < (self.test_validator.desired_coverage / 100)
Expand Down Expand Up @@ -240,3 +250,7 @@ def run(self):
# Finish the Weights & Biases run if it was initialized
if "WANDB_API_KEY" in os.environ:
wandb.finish()

def run(self):
failed_test_runs, language, test_framework, coverage_report = self.init()
self.run_test_gen(failed_test_runs, language, test_framework, coverage_report)
4 changes: 2 additions & 2 deletions cover_agent/UnitTestValidator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_coverage(self):
self.run_coverage()
return self.failed_test_runs, self.language, self.testing_framework, self.code_coverage_report

def get_code_language(self, source_file_path):
def get_code_language(self, source_file_path: str) -> str:
"""
Get the programming language based on the file extension of the provided source file path.
Expand Down Expand Up @@ -251,7 +251,7 @@ def initial_test_suite_analysis(self):
relevant_line_number_to_insert_imports_after = tests_dict.get(
"relevant_line_number_to_insert_imports_after", None
)
self.testing_framework = tests_dict.get("testing_framework", "Unknown")
self.testing_framework: str = tests_dict.get("testing_framework", "Unknown")
counter_attempts += 1

if not relevant_line_number_to_insert_tests_after:
Expand Down
2 changes: 1 addition & 1 deletion cover_agent/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.7
0.2.8

0 comments on commit d40cbed

Please sign in to comment.