diff --git a/src/coverup/segment.py b/src/coverup/segment.py index 595e4ba..e4dd6b5 100644 --- a/src/coverup/segment.py +++ b/src/coverup/segment.py @@ -1,6 +1,7 @@ import typing as T from pathlib import Path from .utils import * +import ast class CodeSegment: @@ -11,7 +12,8 @@ def __init__(self, filename: Path, name: str, begin: int, end: int, missing_lines: T.Set[int], executed_lines: T.Set[int], missing_branches: T.Set[T.Tuple[int, int]], - context: T.List[T.Tuple[int, int]]): + context: T.List[T.Tuple[int, int]], + imports: T.List[str]): self.path = Path(filename).resolve() self.filename = filename self.name = name @@ -22,12 +24,12 @@ def __init__(self, filename: Path, name: str, begin: int, end: int, self.executed_lines = executed_lines self.missing_branches = missing_branches self.context = context + self.imports = imports def __repr__(self): return f"CodeSegment(\"{self.filename}\", \"{self.name}\", {self.begin}, {self.end}, " + \ f"{self.missing_lines}, {self.executed_lines}, {self.missing_branches}, {self.context})" - def identify(self) -> str: return f"{self.filename}:{self.begin}-{self.end-1}" @@ -39,6 +41,9 @@ def get_excerpt(self, tag_lines=True): with open(self.filename, "r") as src: code = src.readlines() + for imp in self.imports: + excerpt.extend([f"{'':10} {imp}\n"]) + for b, e in self.context: for i in range(b, e): excerpt.extend([f"{'':10} ", code[i-1]]) @@ -60,12 +65,46 @@ def missing_count(self) -> int: return len(self.missing_lines)+len(self.missing_branches) +def get_global_imports(tree, node): + def get_names(node): + # TODO this ignores numerous ways in which a global import might not be visible, + # such as when local symbols are created, etc. In such cases, showing the + # import in the excerpt is extraneous, but not incorrect. + for n in ast.walk(node): + if isinstance(n, ast.Name): + yield n.id + + def get_imports(n): + # TODO imports inside Class defines name in the class' namespace; they are uncommon + # Imports within functions are included in the excerpt, so there's no need for us + # to find them. + if not isinstance(n, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)): + for child in ast.iter_child_nodes(n): + yield from get_imports(child) + + if isinstance(n, (ast.Import, ast.ImportFrom)): + names = set(alias.asname if alias.asname else alias.name for alias in n.names) + yield names, n + + # Process imports reversed so that the 1st import to define a name wins in the set; + # it's the programmer's first choice (if within a `try`, for example) + imap = {name: imp for imp in reversed(list(get_imports(tree))) for name in imp[0]} + names = set(get_names(node)) + + imports = [] + while names: + name = names.pop() + if imp := imap.get(name): + imports.append(ast.unparse(imp[1])) + names -= imp[0] + + return imports + def get_missing_coverage(coverage, line_limit: int = 100) -> T.List[CodeSegment]: """Processes a JSON SlipCover output and generates a list of Python code segments, such as functions or classes, which have less than 100% coverage. """ - import ast code_segs = [] @@ -128,17 +167,20 @@ def find_enclosing(root, line): assert line < end assert (begin, end) not in line_ranges - line_ranges[(begin, end)] = (node, context) + line_ranges[(begin, end)] = (node, context, get_global_imports(tree, node)) lines_in_segments.update({*range(begin, end)}) if line_ranges: - for (begin, end), (node, context) in line_ranges.items(): + for (begin, end), (node, context, imports) in line_ranges.items(): line_range_set = {*range(begin, end)} - code_segs.append(CodeSegment(fname, node.name, begin, end, - lines_of_interest=lines_of_interest.intersection(line_range_set), - missing_lines=missing_lines.intersection(line_range_set), - executed_lines=executed_lines.intersection(line_range_set), - missing_branches={tuple(b) for b in missing_branches if b[0] in line_range_set}, - context=context)) + code_segs.append(CodeSegment( + fname, node.name, begin, end, + lines_of_interest=lines_of_interest.intersection(line_range_set), + missing_lines=missing_lines.intersection(line_range_set), + executed_lines=executed_lines.intersection(line_range_set), + missing_branches={tuple(b) for b in missing_branches if b[0] in line_range_set}, + context=context, + imports=imports) + ) return code_segs diff --git a/tests/test_coverup_22.py b/tests/test_coverup_22.py index 6ff09f9..6632783 100644 --- a/tests/test_coverup_22.py +++ b/tests/test_coverup_22.py @@ -17,7 +17,7 @@ def code_segment(): missing_lines=set(), executed_lines=set(), missing_branches=set(), - context={} + context=[], imports=[] ) @pytest.fixture diff --git a/tests/test_coverup_26.py b/tests/test_coverup_26.py index a772689..e1d353c 100644 --- a/tests/test_coverup_26.py +++ b/tests/test_coverup_26.py @@ -19,7 +19,7 @@ def test_missing_count(): missing_lines={1, 2, 3}, executed_lines=set(range(4, 11)), missing_branches={(5, True), (6, False)}, - context=None + context=[], imports=[] ) assert segment.missing_count() == 5 @@ -34,7 +34,7 @@ def test_missing_count_zero(): missing_lines=set(), executed_lines=set(range(1, 11)), missing_branches=set(), - context=None + context=[], imports=[] ) assert segment.missing_count() == 0 @@ -49,7 +49,7 @@ def test_missing_count_only_lines(): missing_lines={1, 2, 3}, executed_lines=set(range(4, 11)), missing_branches=set(), - context=None + context=[], imports=[] ) assert segment.missing_count() == 3 @@ -64,6 +64,6 @@ def test_missing_count_only_branches(): missing_lines=set(), executed_lines=set(range(1, 11)), missing_branches={(5, True), (6, False)}, - context=None + context=[], imports=[] ) assert segment.missing_count() == 2 diff --git a/tests/test_openai_coverup_12.py b/tests/test_openai_coverup_12.py index 7a183dd..c0f3913 100644 --- a/tests/test_openai_coverup_12.py +++ b/tests/test_openai_coverup_12.py @@ -16,7 +16,7 @@ def code_segment(): missing_lines=set(), executed_lines=set(), missing_branches=set(), - context={} + context=[], imports=[] ) return segment diff --git a/tests/test_openai_coverup_13.py b/tests/test_openai_coverup_13.py index e8f059f..4b0435f 100644 --- a/tests/test_openai_coverup_13.py +++ b/tests/test_openai_coverup_13.py @@ -17,7 +17,7 @@ def code_segment(): missing_lines=set(), executed_lines=set(), missing_branches=set(), - context=[] + context=[], imports=[] ) return cs diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 189167c..c54e473 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -27,7 +27,7 @@ def test_gpt4_v1_relative_file_name(tmp_path, monkeypatch): name="foo", begin=10, end=20, lines_of_interest=set(), missing_lines=set(), executed_lines=set(), missing_branches=set(), - context=[] + context=[], imports=[] ) monkeypatch.setattr(CodeSegment, "get_excerpt", lambda self, tag_lines=True: '') @@ -55,7 +55,7 @@ def test_claude_relative_file_name(tmp_path, monkeypatch): name="foo", begin=10, end=20, lines_of_interest=set(), missing_lines=set(), executed_lines=set(), missing_branches=set(), - context=[] + context=[], imports=[] ) monkeypatch.setattr(CodeSegment, "get_excerpt", lambda self, tag_lines=True: '') diff --git a/tests/test_segment.py b/tests/test_segment.py index 07160da..438f48a 100644 --- a/tests/test_segment.py +++ b/tests/test_segment.py @@ -2,6 +2,45 @@ import coverup.segment as segment import textwrap import json +import ast + + +def test_get_global_imports(): + code = textwrap.dedent("""\ + import a, b + from c import d as e + import os + + try: + from hashlib import sha1 + except ImportError: + from sha import sha as sha1 + + class Foo: + import abc as os + + def f(): + if os.path.exists("foo"): + sha1(a.x, b.x, e.x) + """) + + print(code) + + def find_node(tree, name): + for n in ast.walk(tree): + if isinstance(n, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)) and n.name == name: + return n + + tree = ast.parse(code) +# print(ast.dump(tree, indent=2)) + + f = find_node(tree, 'f') + assert set(segment.get_global_imports(tree, f)) == { + 'import a, b', + 'from hashlib import sha1', + 'from c import d as e', + 'import os' + } class mockfs: @@ -27,14 +66,14 @@ def __exit__(self, *args): somecode_py = textwrap.dedent("""\ # Sample Python code used to create some tests. - import sys + import os class Foo: '''docstring...''' @staticmethod def foo(): - pass + return os.path.exists('here') def __init__(self): '''initializes...''' @@ -103,12 +142,13 @@ def test_large_limit_whole_class(): assert all(seg.begin < seg.end for seg in segs) assert textwrap.dedent(segs[0].get_excerpt(tag_lines=False)) == textwrap.dedent("""\ + import os class Foo: '''docstring...''' @staticmethod def foo(): - pass + return os.path.exists('here') def __init__(self): '''initializes...''' @@ -153,10 +193,11 @@ def test_small_limit(): assert textwrap.dedent(segs[0].get_excerpt(tag_lines=False)) == textwrap.dedent("""\ + import os class Foo: @staticmethod def foo(): - pass + return os.path.exists('here') """) assert textwrap.dedent(segs[2].get_excerpt(tag_lines=False)) == textwrap.dedent("""\ @@ -262,3 +303,4 @@ def __init__(self, x): self.x = x 7: self.y = 0 """) +