Skip to content

Commit

Permalink
Fix type issues in non stubs (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
Avasam authored Nov 7, 2024
1 parent 97e16f0 commit d7adc46
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 37 deletions.
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,13 @@ reportOverlappingOverload = false
reportSelfClsParameterName = false

# Error reports to fix in code
reportArgumentType = "none" # TODO
reportAssertTypeFailure = "none" # TODO
reportAttributeAccessIssue = "none" # TODO
reportGeneralTypeIssues = "none" # TODO
reportInvalidTypeArguments = "none" # TODO
reportInvalidTypeForm = "none" # TODO
reportMissingImports = "none" # TODO
reportUndefinedVariable = "none" # TODO
reportUnusedVariable = "none" # TODO

[tool.mypy]
# Target oldest supported Python version
Expand All @@ -140,7 +138,7 @@ disable_error_code = [
# Not all imports in these stubs are gonna be typed
"import-untyped",
# TODO
"arg-type",
"assert-type",
"assignment",
"attr-defined",
"import-not-found",
Expand All @@ -150,7 +148,6 @@ disable_error_code = [
"operator",
"override",
"return",
"truthy-function",
"type-var",
"valid-type",
"var-annotated",
Expand Down
2 changes: 1 addition & 1 deletion tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path


def install_requirements(test_folder: str):
def install_requirements(test_folder: Path):
print("\nInstalling requirements...")
return subprocess.run(
(sys.executable, "-m", "pip", "install", "--upgrade", "-r", os.path.join(test_folder, "requirements.txt"))
Expand Down
3 changes: 2 additions & 1 deletion tests/sklearn/preprocessing_tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# pyright: reportUnknownVariableType=false
# pyright: reportMissingTypeStubs=false

from typing import Any, assert_type
from typing import Any
from typing_extensions import assert_type

from numpy import ndarray
from scipy.sparse._csr import csr_matrix
Expand Down
8 changes: 4 additions & 4 deletions utils/count_ids.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/python

"""Count IDs.
__doc__ = """Count IDs.
Usage:
count_ids [--path=<root>] [--suffix=<filesuffix>] [--pat=<pat>] [--uniq]
Expand Down Expand Up @@ -29,9 +29,9 @@ def count(root, suffix, regex, uniq):
filepat = "*" if suffix is None else "*." + suffix[suffix.find(".") + 1 :]
if regex is None:
regex = "[A-Za-z_][A-Za-z0-9_]*"
data = {}
loc = {}
ctx = {}
data: dict[str, int] = {}
loc: dict[str, str] = {}
ctx: dict[str, str] = {}
try:
prog = re.compile(regex)
except Exception as e:
Expand Down
73 changes: 46 additions & 27 deletions utils/validate_stubs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/python

"""Validate Stubs.
__doc__ = """Validate Stubs.
Usage:
validate_stubs <package> [--path=<stubpath>] [--class=<c>] [--function=<f>]
Expand All @@ -14,31 +14,40 @@
--function=<f> Restrict to the named function (or method if used with --class).
--class=<c> Restrict to the named class.
"""
from __future__ import annotations

import importlib
import importlib.machinery
import inspect
import os
import sys
import types
import typing as _typing
from collections import namedtuple
from enum import Enum
from operator import attrgetter, itemgetter
from typing import Any, Callable, List, Literal, NoReturn, Optional, Set, Tuple, _overload_dummy
from operator import attrgetter
from typing import (
Callable,
List,
Literal,
Optional,
Tuple,
_overload_dummy, # type: ignore[attr-defined] # _overload_dummy not exposed
)

import docopt

overloads = {}
overloads: dict[str, Callable] = {}


def my_overload(func):
key = func.__module__ + "." + func.__name__
if key not in overloads:
fn = lambda *args, **kwds: _overload_dummy(args, kwds)

def fn(*args, **kwds):
_overload_dummy(args, kwds)

overloads[key] = fn
fn.__overloads__ = [func] # type: ignore[attr-defined] # __overloads__ not exposed
else:
overloads[key].__overloads__.append(func)
overloads[key].__overloads__.append(func) # type: ignore[attr-defined] # __overloads__ not exposed
return overloads[key]


Expand Down Expand Up @@ -90,15 +99,17 @@ class ItemType(Enum):
FUNCTION = 3
PROPERTY = 4

def __init__(self, file: str, module: str, name: str, object_: object, type_: ItemType, children: dict = None):
def __init__(
self, file: str, module: str, name: str, object_: object, type_: ItemType, children: dict[str, Item] | None = None
):
self.file = file
self.module = module
self.name = name
self.object_ = object_
self.type_ = type_
self.children = children
self.children = children or {}
self.done = False
self.analog = None
self.analog: Item | None = None

def ismodule(self):
return self.type_ == Item.ItemType.MODULE
Expand All @@ -114,19 +125,19 @@ def make_function(file: str, module: str, name: str, object_: object):
return Item(file, module, name, object_, Item.ItemType.FUNCTION)

@staticmethod
def make_class(file: str, module: str, name: str, object_: object, children: dict):
def make_class(file: str, module: str, name: str, object_: object, children: dict[str, Item]):
return Item(file, module, name, object_, Item.ItemType.CLASS, children)

@staticmethod
def make_module(file: str, module: str, name: str, object_: object, children: dict):
def make_module(file: str, module: str, name: str, object_: object, children: dict[str, Item]):
return Item(file, module, name, object_, Item.ItemType.MODULE, children)


def isfrompackage(v: object, path: str) -> bool:
# Try to ensure the object lives below the root path and is not
# imported from elsewhere.
try:
f = inspect.getfile(v)
f = inspect.getfile(v) # type: ignore[arg-type] # Catching TypeError
return f.startswith(path)
except TypeError: # builtins or non-modules; for the latter we return True for now
return not inspect.ismodule(v)
Expand Down Expand Up @@ -173,15 +184,15 @@ def _gather(mpath: str, m: object, root: str, fpath: str, completed: set, items:
mfpath = inspect.getfile(v)
if mfpath.startswith(root):
mfpath = mfpath[len(root) + 1 :]
members = dict()
members: dict[str, Item] = {}
items[k] = Item.make_module(mfpath, mpath, k, v, members)
_gather(mpath + "." + k, v, root, mfpath, completed, members)
elif inspect.isfunction(v):
if k in items:
print(f"{name} already has a function {k}")
items[k] = Item.make_function(fpath, mpath, k, v)
elif inspect.isclass(v):
members = dict()
members = {}
items[k] = Item.make_class(fpath, mpath, k, v, members)
for kc, vc in inspect.getmembers(v):
if kc[0] != "_" and (inspect.isfunction(vc) or str(type(vc)) == "<class 'property'>"):
Expand All @@ -191,13 +202,13 @@ def _gather(mpath: str, m: object, root: str, fpath: str, completed: set, items:

fpath = m.__dict__["__file__"]
root = fpath[: fpath.rfind("/")] # fix for windows
members = dict()
members: dict[str, Item] = {}
package = Item.make_module(fpath, "", name, m, members)
_gather(name, m, root, fpath, set(), members)
return package


def walk(tree: dict, fn: Callable, *args, postproc: Callable = None, path=None):
def walk(tree: dict, fn: Callable, *args, postproc: Callable | None = None, path=None):
"""
Walk the object tree and apply a function.
If the function returns True, do not walk its children,
Expand All @@ -220,15 +231,15 @@ def walk(tree: dict, fn: Callable, *args, postproc: Callable = None, path=None):


def collect_items(root: Item) -> Tuple[List[Item], List[Item]]:
def _collect(path, name, node, functions, classes):
def _collect(path, name, node: Item, functions: List[Item], classes: List[Item]):
if node.isclass():
classes.append(node)
return True # Don't recurse
elif node.isfunction():
functions.append(node)

functions = []
classes = []
functions: List[Item] = []
classes: List[Item] = []
walk(root.children, _collect, functions, classes)
functions = sorted(functions, key=attrgetter("name"))
classes = sorted(classes, key=attrgetter("name"))
Expand Down Expand Up @@ -259,6 +270,11 @@ def compare_args(real: Item, stub: Item, owner: Optional[str] = None):
"""
owner - name of owner class, if a member; else None if a top-level function
"""
# Note that this isinstance check currently doesn't work for mypy: https://github.com/python/mypy/issues/11071
if not (isinstance(stub.object_, Callable) and isinstance(real.object_, Callable)): # type: ignore[arg-type]
print(f"Can't compare args for non-callables. real: {real.object_}; stub: {stub.object_}")
return

if owner is None:
owner = ""
elif owner and owner[-1] != ".":
Expand All @@ -267,15 +283,17 @@ def compare_args(real: Item, stub: Item, owner: Optional[str] = None):
name = stub.name
# if stub.object_ == _overload_dummy:
if hasattr(stub.object_, "__overloads__"):
print(f"Can't validate @overloaded function {module}.{owner}{name} with {len(stub.object_.__overloads__)} overloads")
print(
f"Can't validate @overloaded function {module}.{owner}{name} with {len(stub.object_.__overloads__)} overloads" # pyright: ignore[reportFunctionMemberAccess] # __overloads__ not exposed
)
return

try:
sc = stub.object_.__code__.co_argcount
ac = real.object_.__code__.co_argcount
sa = inspect.signature(stub.object_)
sc = stub.object_.__code__.co_argcount # type: ignore[attr-defined] # https://github.com/python/mypy/issues/11071
ac = real.object_.__code__.co_argcount # type: ignore[attr-defined] # https://github.com/python/mypy/issues/11071
sa = inspect.signature(stub.object_) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11071
sn = list(sa.parameters.keys())
aa = inspect.signature(real.object_)
aa = inspect.signature(real.object_) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11071
an = list(aa.parameters.keys())
diff = ""

Expand Down Expand Up @@ -355,6 +373,7 @@ def find_item(
break
i += 1
print(f"No {which} {type_} found with name {name}")
return None


def compare_class(real: List[Item], stub: List[Item], class_: str):
Expand Down

0 comments on commit d7adc46

Please sign in to comment.