Skip to content

Commit

Permalink
- added looking up for any imports in code segment, indicated by
Browse files Browse the repository at this point in the history
  its starting line number;

- trimmed codeinfo response when the additional information about
  the path was unnecessary, to cut down on noise and tokens;
  • Loading branch information
jaltmayerpizzorno committed Aug 28, 2024
1 parent 2030a23 commit 8b548f2
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 93 deletions.
143 changes: 97 additions & 46 deletions src/coverup/codeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,49 @@ def helper(*args):
return helper


def _handle_import(module: ast.Module, node: ast.Import | ast.ImportFrom, name: T.List[str],
*, paths_seen: T.Set[Path] = None) -> T.Optional[T.List[ast.AST]]:

def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: ast.Module) -> T.List:
imp = copy.copy(node)
imp.names = [alias]
return [imp, mod]

if isinstance(node, ast.Import):
# import N
# import N.x imports N and N.x
# import a.b as N 'a.b' is renamed 'N'
for alias in node.names:
if alias.asname:
if alias.asname == name[0]:
mod = _load_module(alias.name)
if path := _find_name_path(mod, name[1:], paths_seen=paths_seen):
return transition(node, alias, mod) + path

elif (import_name := alias.name.split('.'))[0] == name[0]:
common_prefix = _common_prefix_len(import_name, name)
mod = _load_module('.'.join(import_name[:common_prefix]))
if path := _find_name_path(mod, name[common_prefix:], paths_seen=paths_seen):
return transition(node, alias, mod) + path

elif isinstance(node, ast.ImportFrom):
# from a.b import N either gets symbol N out of a.b, or imports a.b.N as N
# from a.b import c as N

for alias in node.names:
if (alias.asname if alias.asname else alias.name) == name[0]:
modname = _resolve_from_import(module.path, node)
_debug(f"looking for symbol ({[alias.name, *name[1:]]} in {modname})")
mod = _load_module(modname)
if path := _find_name_path(mod, [alias.name, *name[1:]], paths_seen=paths_seen):
return transition(node, alias, mod) + path

_debug(f"looking for module ({name[1:]} in {modname}.{alias.name})")
if (mod := _load_module(f"{modname}.{alias.name}")) and \
(path := _find_name_path(mod, name[1:], paths_seen=paths_seen)):
return transition(node, alias, mod) + path


def _find_name_path(module: ast.Module, name: T.List[str], *, paths_seen: T.Set[Path] = None) -> T.List[ast.AST]:
"""Looks for a symbol's definition by its name, returning the "path" of ast.ClassDef, ast.Import, etc.,
crossed to find it.
Expand All @@ -87,11 +130,6 @@ def _find_name_path(module: ast.Module, name: T.List[str], *, paths_seen: T.Set[
if module.path in paths_seen: return None
paths_seen.add(module.path)

def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: ast.Module) -> T.List:
imp = copy.copy(node)
imp.names = [alias]
return [imp, mod]

@_auto_stack
def find_name(node: ast.AST, name: T.List[str]) -> T.List[ast.AST]:
_debug(f"_find_name {name} in {ast.dump(node)}")
Expand All @@ -115,45 +153,22 @@ def find_name(node: ast.AST, name: T.List[str]) -> T.List[ast.AST]:
if (path := find_name(module, [base.id, *name[1:]])):
return path

elif isinstance(module, (ast.Function, ast.AsyncFunction)):
# searching within a function in the excerpt
for c in node.body:
_debug(f"{node.name} checking {ast.dump(c)}")
if (path := find_name(c, name[1:])):
return [node, *path]

return []

if (isinstance(node, ast.Assign) and
any(isinstance(n, ast.Name) and n.id == name[0] for t in node.targets for n in ast.walk(t))):
return [node] if len(name) == 1 else []

if isinstance(node, ast.Import):
# import N
# import N.x imports N and N.x
# import a.b as N 'a.b' is renamed 'N'
for alias in node.names:
if alias.asname:
if alias.asname == name[0]:
mod = _load_module(alias.name)
if path := _find_name_path(mod, name[1:], paths_seen=paths_seen):
return transition(node, alias, mod) + path

elif (import_name := alias.name.split('.'))[0] == name[0]:
common_prefix = _common_prefix_len(import_name, name)
mod = _load_module('.'.join(import_name[:common_prefix]))
if path := _find_name_path(mod, name[common_prefix:], paths_seen=paths_seen):
return transition(node, alias, mod) + path

elif isinstance(node, ast.ImportFrom):
# from a.b import N either gets symbol N out of a.b, or imports a.b.N as N
# from a.b import c as N

for alias in node.names:
if (alias.asname if alias.asname else alias.name) == name[0]:
modname = _resolve_from_import(module.path, node)
_debug(f"looking for symbol ({[alias.name, *name[1:]]} in {modname})")
mod = _load_module(modname)
if path := _find_name_path(mod, [alias.name, *name[1:]], paths_seen=paths_seen):
return transition(node, alias, mod) + path

_debug(f"looking for module ({name[1:]} in {modname}.{alias.name})")
if (mod := _load_module(f"{modname}.{alias.name}")) and \
(path := _find_name_path(mod, name[1:], paths_seen=paths_seen)):
return transition(node, alias, mod) + path
if isinstance(node, (ast.Import, ast.ImportFrom)):
if (path := _handle_import(module, node, name, paths_seen=paths_seen)):
return path

elif not isinstance(node, (ast.Expression, ast.Expr, ast.Name)):
for c in ast.iter_child_nodes(node):
Expand Down Expand Up @@ -250,7 +265,15 @@ def get_imports(n: ast.AST):
return imports


def get_info(module: ast.Module, name: str) -> T.Optional[str]:
def _find_excerpt(module: ast.Module, line: int) -> ast.AST:
for node in ast.walk(module):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
begin = min([node.lineno] + [d.lineno for d in node.decorator_list])
if begin <= line <= node.end_lineno:
return node


def get_info(module: ast.Module, name: str, *, line: int = 0) -> T.Optional[str]:
"""Returns summarized information on a class or function, following imports if necessary."""

key = name.split('.')
Expand All @@ -259,24 +282,52 @@ def get_info(module: ast.Module, name: str) -> T.Optional[str]:
(common_prefix := _common_prefix_len(module_fqn, key))):
key = key[common_prefix:]

if not (path := _find_name_path(module, key)):
# Couldn't find the name in the context of the given module;
# try to interpret it as an absolute fqn (GPT asks for that sometimes)
path = None
# first look in the excerpt node, such as the focal function, if specified
if (excerpt_node := _find_excerpt(module, line)):
for c in ast.walk(excerpt_node):
if (path := _handle_import(module, c, key)):
break

if not path:
# Try looking among globals and classes
path = _find_name_path(module, key)

if not path:
# Couldn't find the name in the context of the given module;
# try to interpret it as an absolute FQN (GPT asks for that sometimes)
key = name.split('.')
for i in range(len(key)-1, 0, -1):
if (mod := _load_module('.'.join(key[:i]))) and (path := _find_name_path(mod, key[i:])):
break


def any_import_as_or_import_in_class() -> bool:
return any(
isinstance(n, (ast.Import, ast.ImportFrom)) and (
n.names[0].asname or (
i > 0 and isinstance(path[i-1], ast.ClassDef)
)
)
for i, n in enumerate(path)
)

if path:
_summarize(path)
if any(isinstance(n, ast.Module) for n in path):
path = [module] + path

for i in range(len(path)):
_debug(f"path[{i}]={ast.dump(path[i])}")
for i in range(len(path)):
_debug(f"path[{i}]={ast.dump(path[i])}")

if any_import_as_or_import_in_class():
# include the full path for best context
path = [module] + path
else:
# just include the last module
modules = [i for i in range(len(path)) if isinstance(path[i], ast.Module)]
if modules:
path = path[modules[-1]:]

if any(isinstance(n, ast.Module) for n in path):
result = ""
for i in range(len(path)):
if isinstance(path[i], ast.Module):
Expand Down
2 changes: 1 addition & 1 deletion src/coverup/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def get_info(ctx: CodeSegment, name: str) -> str:

from .codeinfo import get_info, parse_file

if info := get_info(parse_file(ctx.path), name):
if info := get_info(parse_file(ctx.path), name, line=ctx.begin):
return "\"...\" below indicates omitted code.\n\n" + info

return f"Unable to obtain information on {name}."
Expand Down
81 changes: 35 additions & 46 deletions tests/test_codeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,9 @@ def a(self):
"""
))

# TODO it would be better if it showed class B(A)
tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'B.a') == textwrap.dedent("""\
in foo.py:
```python
from bar import A
```
in bar.py:
```python
class A:
Expand Down Expand Up @@ -394,11 +390,6 @@ class Foo:

tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'foo.Foo') == textwrap.dedent('''\
in code.py:
```python
import foo
```
in foo/__init__.py:
```python
class Foo:
Expand Down Expand Up @@ -433,11 +424,6 @@ class Bar:

tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'foo.Foo') == textwrap.dedent('''\
in code.py:
```python
import foo.bar
```
in foo/__init__.py:
```python
class Foo:
Expand All @@ -446,11 +432,6 @@ class Foo:
)

assert codeinfo.get_info(tree, 'foo.bar.Bar') == textwrap.dedent('''\
in code.py:
```python
import foo.bar
```
in foo/bar.py:
```python
class Bar:
Expand Down Expand Up @@ -519,11 +500,6 @@ class Bar:

tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'bar.Bar') == textwrap.dedent('''\
in code.py:
```python
from foo import bar
```
in foo/__init__.py:
```python
class bar:
Expand Down Expand Up @@ -603,11 +579,6 @@ class Bar:

tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'bar.Bar') == textwrap.dedent('''\
in code.py:
```python
from foo import bar
```
in foo/bar.py:
```python
class Bar:
Expand Down Expand Up @@ -703,16 +674,17 @@ class Baz:
)


@pytest.mark.xfail
def test_get_info_import_in_function(import_fixture):
def test_get_info_import_in_excerpt_function(import_fixture):
tmp_path = import_fixture

code = tmp_path / "code.py"
code.write_text(textwrap.dedent("""\
import os
def something():
from foo import Foo
def something_else():
from foo import Foo
Foo()
"""
))

Expand All @@ -723,18 +695,40 @@ class Foo:
"""
))

# XXX pass context here
tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'Foo') == textwrap.dedent('''\
in code.py:
assert codeinfo.get_info(tree, 'Foo', line=3) == textwrap.dedent('''\
in foo/__init__.py:
```python
def something():
...
class Foo:
pass
```'''
)


def test_get_info_import_in_excerpt_class(import_fixture):
tmp_path = import_fixture

code = tmp_path / "code.py"
code.write_text(textwrap.dedent("""\
import os
class X:
from foo import Foo
...
```
in foo/__init__.py:
def __init__(self, x: Foo):
self.x = x
"""
))

(tmp_path / "foo.py").write_text(textwrap.dedent("""\
class Foo:
pass
"""
))

tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'Foo', line=3) == textwrap.dedent('''\
in foo.py:
```python
class Foo:
pass
Expand Down Expand Up @@ -847,11 +841,6 @@ class Bar:
tree = codeinfo.parse_file(code)

assert codeinfo.get_info(tree, 'foo.Foo') == textwrap.dedent('''\
in code.py:
```python
import foo
```
in foo/__init__.py:
```python
class Foo:
Expand Down

0 comments on commit 8b548f2

Please sign in to comment.