Skip to content

Commit

Permalink
- added inclusion of global imports with segments, so that LLMs need
Browse files Browse the repository at this point in the history
  not make assumptions about where the names come from;
  • Loading branch information
jaltmayerpizzorno committed Aug 6, 2024
1 parent cc5acca commit 43a1bcf
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 24 deletions.
64 changes: 53 additions & 11 deletions src/coverup/segment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing as T
from pathlib import Path
from .utils import *
import ast


class CodeSegment:
Expand All @@ -11,7 +12,8 @@ def __init__(self, filename: Path, name: str, begin: int, end: int,
missing_lines: T.Set[int],
executed_lines: T.Set[int],
missing_branches: T.Set[T.Tuple[int, int]],
context: T.List[T.Tuple[int, int]]):
context: T.List[T.Tuple[int, int]],
imports: T.List[str]):
self.path = Path(filename).resolve()
self.filename = filename
self.name = name
Expand All @@ -22,12 +24,12 @@ def __init__(self, filename: Path, name: str, begin: int, end: int,
self.executed_lines = executed_lines
self.missing_branches = missing_branches
self.context = context
self.imports = imports

def __repr__(self):
return f"CodeSegment(\"{self.filename}\", \"{self.name}\", {self.begin}, {self.end}, " + \
f"{self.missing_lines}, {self.executed_lines}, {self.missing_branches}, {self.context})"


def identify(self) -> str:
return f"{self.filename}:{self.begin}-{self.end-1}"

Expand All @@ -39,6 +41,9 @@ def get_excerpt(self, tag_lines=True):
with open(self.filename, "r") as src:
code = src.readlines()

for imp in self.imports:
excerpt.extend([f"{'':10} {imp}\n"])

for b, e in self.context:
for i in range(b, e):
excerpt.extend([f"{'':10} ", code[i-1]])
Expand All @@ -60,12 +65,46 @@ def missing_count(self) -> int:
return len(self.missing_lines)+len(self.missing_branches)


def get_global_imports(tree, node):
def get_names(node):
# TODO this ignores numerous ways in which a global import might not be visible,
# such as when local symbols are created, etc. In such cases, showing the
# import in the excerpt is extraneous, but not incorrect.
for n in ast.walk(node):
if isinstance(n, ast.Name):
yield n.id

def get_imports(n):
# TODO imports inside Class defines name in the class' namespace; they are uncommon
# Imports within functions are included in the excerpt, so there's no need for us
# to find them.
if not isinstance(n, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
for child in ast.iter_child_nodes(n):
yield from get_imports(child)

if isinstance(n, (ast.Import, ast.ImportFrom)):
names = set(alias.asname if alias.asname else alias.name for alias in n.names)
yield names, n

# Process imports reversed so that the 1st import to define a name wins in the set;
# it's the programmer's first choice (if within a `try`, for example)
imap = {name: imp for imp in reversed(list(get_imports(tree))) for name in imp[0]}
names = set(get_names(node))

imports = []
while names:
name = names.pop()
if imp := imap.get(name):
imports.append(ast.unparse(imp[1]))
names -= imp[0]

return imports


def get_missing_coverage(coverage, line_limit: int = 100) -> T.List[CodeSegment]:
"""Processes a JSON SlipCover output and generates a list of Python code segments,
such as functions or classes, which have less than 100% coverage.
"""
import ast

code_segs = []

Expand Down Expand Up @@ -128,17 +167,20 @@ def find_enclosing(root, line):
assert line < end
assert (begin, end) not in line_ranges

line_ranges[(begin, end)] = (node, context)
line_ranges[(begin, end)] = (node, context, get_global_imports(tree, node))
lines_in_segments.update({*range(begin, end)})

if line_ranges:
for (begin, end), (node, context) in line_ranges.items():
for (begin, end), (node, context, imports) in line_ranges.items():
line_range_set = {*range(begin, end)}
code_segs.append(CodeSegment(fname, node.name, begin, end,
lines_of_interest=lines_of_interest.intersection(line_range_set),
missing_lines=missing_lines.intersection(line_range_set),
executed_lines=executed_lines.intersection(line_range_set),
missing_branches={tuple(b) for b in missing_branches if b[0] in line_range_set},
context=context))
code_segs.append(CodeSegment(
fname, node.name, begin, end,
lines_of_interest=lines_of_interest.intersection(line_range_set),
missing_lines=missing_lines.intersection(line_range_set),
executed_lines=executed_lines.intersection(line_range_set),
missing_branches={tuple(b) for b in missing_branches if b[0] in line_range_set},
context=context,
imports=imports)
)

return code_segs
2 changes: 1 addition & 1 deletion tests/test_coverup_22.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def code_segment():
missing_lines=set(),
executed_lines=set(),
missing_branches=set(),
context={}
context=[], imports=[]
)

@pytest.fixture
Expand Down
8 changes: 4 additions & 4 deletions tests/test_coverup_26.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_missing_count():
missing_lines={1, 2, 3},
executed_lines=set(range(4, 11)),
missing_branches={(5, True), (6, False)},
context=None
context=[], imports=[]
)
assert segment.missing_count() == 5

Expand All @@ -34,7 +34,7 @@ def test_missing_count_zero():
missing_lines=set(),
executed_lines=set(range(1, 11)),
missing_branches=set(),
context=None
context=[], imports=[]
)
assert segment.missing_count() == 0

Expand All @@ -49,7 +49,7 @@ def test_missing_count_only_lines():
missing_lines={1, 2, 3},
executed_lines=set(range(4, 11)),
missing_branches=set(),
context=None
context=[], imports=[]
)
assert segment.missing_count() == 3

Expand All @@ -64,6 +64,6 @@ def test_missing_count_only_branches():
missing_lines=set(),
executed_lines=set(range(1, 11)),
missing_branches={(5, True), (6, False)},
context=None
context=[], imports=[]
)
assert segment.missing_count() == 2
2 changes: 1 addition & 1 deletion tests/test_openai_coverup_12.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def code_segment():
missing_lines=set(),
executed_lines=set(),
missing_branches=set(),
context={}
context=[], imports=[]
)
return segment

Expand Down
2 changes: 1 addition & 1 deletion tests/test_openai_coverup_13.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def code_segment():
missing_lines=set(),
executed_lines=set(),
missing_branches=set(),
context=[]
context=[], imports=[]
)
return cs

Expand Down
4 changes: 2 additions & 2 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_gpt4_v1_relative_file_name(tmp_path, monkeypatch):
name="foo",
begin=10, end=20,
lines_of_interest=set(), missing_lines=set(), executed_lines=set(), missing_branches=set(),
context=[]
context=[], imports=[]
)

monkeypatch.setattr(CodeSegment, "get_excerpt", lambda self, tag_lines=True: '<excerpt>')
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_claude_relative_file_name(tmp_path, monkeypatch):
name="foo",
begin=10, end=20,
lines_of_interest=set(), missing_lines=set(), executed_lines=set(), missing_branches=set(),
context=[]
context=[], imports=[]
)

monkeypatch.setattr(CodeSegment, "get_excerpt", lambda self, tag_lines=True: '<excerpt>')
Expand Down
50 changes: 46 additions & 4 deletions tests/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,45 @@
import coverup.segment as segment
import textwrap
import json
import ast


def test_get_global_imports():
code = textwrap.dedent("""\
import a, b
from c import d as e
import os
try:
from hashlib import sha1
except ImportError:
from sha import sha as sha1
class Foo:
import abc as os
def f():
if os.path.exists("foo"):
sha1(a.x, b.x, e.x)
""")

print(code)

def find_node(tree, name):
for n in ast.walk(tree):
if isinstance(n, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)) and n.name == name:
return n

tree = ast.parse(code)
# print(ast.dump(tree, indent=2))

f = find_node(tree, 'f')
assert set(segment.get_global_imports(tree, f)) == {
'import a, b',
'from hashlib import sha1',
'from c import d as e',
'import os'
}


class mockfs:
Expand All @@ -27,14 +66,14 @@ def __exit__(self, *args):

somecode_py = textwrap.dedent("""\
# Sample Python code used to create some tests.
import sys
import os
class Foo:
'''docstring...'''
@staticmethod
def foo():
pass
return os.path.exists('here')
def __init__(self):
'''initializes...'''
Expand Down Expand Up @@ -103,12 +142,13 @@ def test_large_limit_whole_class():
assert all(seg.begin < seg.end for seg in segs)

assert textwrap.dedent(segs[0].get_excerpt(tag_lines=False)) == textwrap.dedent("""\
import os
class Foo:
'''docstring...'''
@staticmethod
def foo():
pass
return os.path.exists('here')
def __init__(self):
'''initializes...'''
Expand Down Expand Up @@ -153,10 +193,11 @@ def test_small_limit():


assert textwrap.dedent(segs[0].get_excerpt(tag_lines=False)) == textwrap.dedent("""\
import os
class Foo:
@staticmethod
def foo():
pass
return os.path.exists('here')
""")

assert textwrap.dedent(segs[2].get_excerpt(tag_lines=False)) == textwrap.dedent("""\
Expand Down Expand Up @@ -262,3 +303,4 @@ def __init__(self, x):
self.x = x
7: self.y = 0
""")

0 comments on commit 43a1bcf

Please sign in to comment.