From d7adc46a17ea83ec66271003cce435d7d52f92f2 Mon Sep 17 00:00:00 2001 From: Avasam Date: Thu, 7 Nov 2024 17:37:14 -0500 Subject: [PATCH] Fix type issues in non stubs (#324) --- pyproject.toml | 5 +- tests/run_tests.py | 2 +- tests/sklearn/preprocessing_tests.py | 3 +- utils/count_ids.py | 8 +-- utils/validate_stubs.py | 73 ++++++++++++++++++---------- 5 files changed, 54 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 341a5028..2d9610b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,6 @@ reportOverlappingOverload = false reportSelfClsParameterName = false # Error reports to fix in code -reportArgumentType = "none" # TODO reportAssertTypeFailure = "none" # TODO reportAttributeAccessIssue = "none" # TODO reportGeneralTypeIssues = "none" # TODO @@ -114,7 +113,6 @@ reportInvalidTypeArguments = "none" # TODO reportInvalidTypeForm = "none" # TODO reportMissingImports = "none" # TODO reportUndefinedVariable = "none" # TODO -reportUnusedVariable = "none" # TODO [tool.mypy] # Target oldest supported Python version @@ -140,7 +138,7 @@ disable_error_code = [ # Not all imports in these stubs are gonna be typed "import-untyped", # TODO - "arg-type", + "assert-type", "assignment", "attr-defined", "import-not-found", @@ -150,7 +148,6 @@ disable_error_code = [ "operator", "override", "return", - "truthy-function", "type-var", "valid-type", "var-annotated", diff --git a/tests/run_tests.py b/tests/run_tests.py index 3c6111be..2695f798 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -4,7 +4,7 @@ from pathlib import Path -def install_requirements(test_folder: str): +def install_requirements(test_folder: Path): print("\nInstalling requirements...") return subprocess.run( (sys.executable, "-m", "pip", "install", "--upgrade", "-r", os.path.join(test_folder, "requirements.txt")) diff --git a/tests/sklearn/preprocessing_tests.py b/tests/sklearn/preprocessing_tests.py index 2fbd376b..d164f646 100644 --- a/tests/sklearn/preprocessing_tests.py +++ b/tests/sklearn/preprocessing_tests.py @@ -1,7 +1,8 @@ # pyright: reportUnknownVariableType=false # pyright: reportMissingTypeStubs=false -from typing import Any, assert_type +from typing import Any +from typing_extensions import assert_type from numpy import ndarray from scipy.sparse._csr import csr_matrix diff --git a/utils/count_ids.py b/utils/count_ids.py index d0a03c7c..ef0a89ec 100644 --- a/utils/count_ids.py +++ b/utils/count_ids.py @@ -1,6 +1,6 @@ #!/bin/python -"""Count IDs. +__doc__ = """Count IDs. Usage: count_ids [--path=] [--suffix=] [--pat=] [--uniq] @@ -29,9 +29,9 @@ def count(root, suffix, regex, uniq): filepat = "*" if suffix is None else "*." + suffix[suffix.find(".") + 1 :] if regex is None: regex = "[A-Za-z_][A-Za-z0-9_]*" - data = {} - loc = {} - ctx = {} + data: dict[str, int] = {} + loc: dict[str, str] = {} + ctx: dict[str, str] = {} try: prog = re.compile(regex) except Exception as e: diff --git a/utils/validate_stubs.py b/utils/validate_stubs.py index 9dcbf72c..c0612ab4 100644 --- a/utils/validate_stubs.py +++ b/utils/validate_stubs.py @@ -1,6 +1,6 @@ #!/bin/python -"""Validate Stubs. +__doc__ = """Validate Stubs. Usage: validate_stubs [--path=] [--class=] [--function=] @@ -14,31 +14,40 @@ --function= Restrict to the named function (or method if used with --class). --class= Restrict to the named class. """ +from __future__ import annotations import importlib +import importlib.machinery import inspect -import os import sys -import types import typing as _typing -from collections import namedtuple from enum import Enum -from operator import attrgetter, itemgetter -from typing import Any, Callable, List, Literal, NoReturn, Optional, Set, Tuple, _overload_dummy +from operator import attrgetter +from typing import ( + Callable, + List, + Literal, + Optional, + Tuple, + _overload_dummy, # type: ignore[attr-defined] # _overload_dummy not exposed +) import docopt -overloads = {} +overloads: dict[str, Callable] = {} def my_overload(func): key = func.__module__ + "." + func.__name__ if key not in overloads: - fn = lambda *args, **kwds: _overload_dummy(args, kwds) + + def fn(*args, **kwds): + _overload_dummy(args, kwds) + overloads[key] = fn fn.__overloads__ = [func] # type: ignore[attr-defined] # __overloads__ not exposed else: - overloads[key].__overloads__.append(func) + overloads[key].__overloads__.append(func) # type: ignore[attr-defined] # __overloads__ not exposed return overloads[key] @@ -90,15 +99,17 @@ class ItemType(Enum): FUNCTION = 3 PROPERTY = 4 - def __init__(self, file: str, module: str, name: str, object_: object, type_: ItemType, children: dict = None): + def __init__( + self, file: str, module: str, name: str, object_: object, type_: ItemType, children: dict[str, Item] | None = None + ): self.file = file self.module = module self.name = name self.object_ = object_ self.type_ = type_ - self.children = children + self.children = children or {} self.done = False - self.analog = None + self.analog: Item | None = None def ismodule(self): return self.type_ == Item.ItemType.MODULE @@ -114,11 +125,11 @@ def make_function(file: str, module: str, name: str, object_: object): return Item(file, module, name, object_, Item.ItemType.FUNCTION) @staticmethod - def make_class(file: str, module: str, name: str, object_: object, children: dict): + def make_class(file: str, module: str, name: str, object_: object, children: dict[str, Item]): return Item(file, module, name, object_, Item.ItemType.CLASS, children) @staticmethod - def make_module(file: str, module: str, name: str, object_: object, children: dict): + def make_module(file: str, module: str, name: str, object_: object, children: dict[str, Item]): return Item(file, module, name, object_, Item.ItemType.MODULE, children) @@ -126,7 +137,7 @@ def isfrompackage(v: object, path: str) -> bool: # Try to ensure the object lives below the root path and is not # imported from elsewhere. try: - f = inspect.getfile(v) + f = inspect.getfile(v) # type: ignore[arg-type] # Catching TypeError return f.startswith(path) except TypeError: # builtins or non-modules; for the latter we return True for now return not inspect.ismodule(v) @@ -173,7 +184,7 @@ def _gather(mpath: str, m: object, root: str, fpath: str, completed: set, items: mfpath = inspect.getfile(v) if mfpath.startswith(root): mfpath = mfpath[len(root) + 1 :] - members = dict() + members: dict[str, Item] = {} items[k] = Item.make_module(mfpath, mpath, k, v, members) _gather(mpath + "." + k, v, root, mfpath, completed, members) elif inspect.isfunction(v): @@ -181,7 +192,7 @@ def _gather(mpath: str, m: object, root: str, fpath: str, completed: set, items: print(f"{name} already has a function {k}") items[k] = Item.make_function(fpath, mpath, k, v) elif inspect.isclass(v): - members = dict() + members = {} items[k] = Item.make_class(fpath, mpath, k, v, members) for kc, vc in inspect.getmembers(v): if kc[0] != "_" and (inspect.isfunction(vc) or str(type(vc)) == ""): @@ -191,13 +202,13 @@ def _gather(mpath: str, m: object, root: str, fpath: str, completed: set, items: fpath = m.__dict__["__file__"] root = fpath[: fpath.rfind("/")] # fix for windows - members = dict() + members: dict[str, Item] = {} package = Item.make_module(fpath, "", name, m, members) _gather(name, m, root, fpath, set(), members) return package -def walk(tree: dict, fn: Callable, *args, postproc: Callable = None, path=None): +def walk(tree: dict, fn: Callable, *args, postproc: Callable | None = None, path=None): """ Walk the object tree and apply a function. If the function returns True, do not walk its children, @@ -220,15 +231,15 @@ def walk(tree: dict, fn: Callable, *args, postproc: Callable = None, path=None): def collect_items(root: Item) -> Tuple[List[Item], List[Item]]: - def _collect(path, name, node, functions, classes): + def _collect(path, name, node: Item, functions: List[Item], classes: List[Item]): if node.isclass(): classes.append(node) return True # Don't recurse elif node.isfunction(): functions.append(node) - functions = [] - classes = [] + functions: List[Item] = [] + classes: List[Item] = [] walk(root.children, _collect, functions, classes) functions = sorted(functions, key=attrgetter("name")) classes = sorted(classes, key=attrgetter("name")) @@ -259,6 +270,11 @@ def compare_args(real: Item, stub: Item, owner: Optional[str] = None): """ owner - name of owner class, if a member; else None if a top-level function """ + # Note that this isinstance check currently doesn't work for mypy: https://github.com/python/mypy/issues/11071 + if not (isinstance(stub.object_, Callable) and isinstance(real.object_, Callable)): # type: ignore[arg-type] + print(f"Can't compare args for non-callables. real: {real.object_}; stub: {stub.object_}") + return + if owner is None: owner = "" elif owner and owner[-1] != ".": @@ -267,15 +283,17 @@ def compare_args(real: Item, stub: Item, owner: Optional[str] = None): name = stub.name # if stub.object_ == _overload_dummy: if hasattr(stub.object_, "__overloads__"): - print(f"Can't validate @overloaded function {module}.{owner}{name} with {len(stub.object_.__overloads__)} overloads") + print( + f"Can't validate @overloaded function {module}.{owner}{name} with {len(stub.object_.__overloads__)} overloads" # pyright: ignore[reportFunctionMemberAccess] # __overloads__ not exposed + ) return try: - sc = stub.object_.__code__.co_argcount - ac = real.object_.__code__.co_argcount - sa = inspect.signature(stub.object_) + sc = stub.object_.__code__.co_argcount # type: ignore[attr-defined] # https://github.com/python/mypy/issues/11071 + ac = real.object_.__code__.co_argcount # type: ignore[attr-defined] # https://github.com/python/mypy/issues/11071 + sa = inspect.signature(stub.object_) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11071 sn = list(sa.parameters.keys()) - aa = inspect.signature(real.object_) + aa = inspect.signature(real.object_) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11071 an = list(aa.parameters.keys()) diff = "" @@ -355,6 +373,7 @@ def find_item( break i += 1 print(f"No {which} {type_} found with name {name}") + return None def compare_class(real: List[Item], stub: List[Item], class_: str):