Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubtest: get better signatures for __init__ of C classes #18259

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
198 changes: 198 additions & 0 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from __future__ import annotations

import argparse
import ast
import collections.abc
import copy
import enum
import functools
import importlib
import importlib.machinery
import inspect
import itertools
import os
import pkgutil
import re
Expand Down Expand Up @@ -1526,7 +1528,203 @@ def is_read_only_property(runtime: object) -> bool:
return isinstance(runtime, property) and runtime.fset is None


def _signature_fromstr(
cls: type[inspect.Signature], obj: Any, s: str, skip_bound_arg: bool = True
) -> inspect.Signature:
"""Private helper to parse content of '__text_signature__'
and return a Signature based on it.

This is a copy of inspect._signature_fromstr from 3.13, which we need
for python/cpython#115270, an important fix for working with
built-in instance methods.
"""
Parameter = cls._parameter_cls # type: ignore[attr-defined]

if sys.version_info >= (3, 12):
clean_signature, self_parameter = inspect._signature_strip_non_python_syntax(s) # type: ignore[attr-defined]
else:
clean_signature, self_parameter, last_positional_only = inspect._signature_strip_non_python_syntax(s) # type: ignore[attr-defined]

program = "def foo" + clean_signature + ": pass"

try:
module_ast = ast.parse(program)
except SyntaxError:
module_ast = None

if not isinstance(module_ast, ast.Module):
raise ValueError("{!r} builtin has invalid signature".format(obj))

f = module_ast.body[0]
assert isinstance(f, ast.FunctionDef)

parameters = []
empty = Parameter.empty

module = None
module_dict: dict[str, Any] = {}

module_name = getattr(obj, "__module__", None)
if not module_name:
objclass = getattr(obj, "__objclass__", None)
module_name = getattr(objclass, "__module__", None)

if module_name:
module = sys.modules.get(module_name, None)
if module:
module_dict = module.__dict__
sys_module_dict = sys.modules.copy()

def parse_name(node: ast.arg) -> str:
assert isinstance(node, ast.arg)
if node.annotation is not None:
raise ValueError("Annotations are not currently supported")
return node.arg

def wrap_value(s: str) -> ast.Constant:
try:
value = eval(s, module_dict)
except NameError:
try:
value = eval(s, sys_module_dict)
except NameError as err:
raise ValueError from err

if isinstance(value, (str, int, float, bytes, bool, type(None))):
return ast.Constant(value)
raise ValueError

class RewriteSymbolics(ast.NodeTransformer):
def visit_Attribute(self, node: ast.Attribute) -> Any: # noqa: N802
a = []
n: ast.expr = node
while isinstance(n, ast.Attribute):
a.append(n.attr)
n = n.value
if not isinstance(n, ast.Name):
raise ValueError
a.append(n.id)
value = ".".join(reversed(a))
return wrap_value(value)

def visit_Name(self, node: ast.Name) -> Any: # noqa: N802
if not isinstance(node.ctx, ast.Load):
raise ValueError()
return wrap_value(node.id)

def visit_BinOp(self, node: ast.BinOp) -> Any: # noqa: N802
# Support constant folding of a couple simple binary operations
# commonly used to define default values in text signatures
left = self.visit(node.left)
right = self.visit(node.right)
if not isinstance(left, ast.Constant) or not isinstance(right, ast.Constant):
raise ValueError
if isinstance(node.op, ast.Add):
return ast.Constant(left.value + right.value)
elif isinstance(node.op, ast.Sub):
return ast.Constant(left.value - right.value)
elif isinstance(node.op, ast.BitOr):
return ast.Constant(left.value | right.value)
raise ValueError

def p(name_node: ast.arg, default_node: Any, default: Any = empty) -> None:
name = parse_name(name_node)
if default_node and default_node is not inspect._empty:
try:
default_node = RewriteSymbolics().visit(default_node)
default = ast.literal_eval(default_node)
except ValueError:
raise ValueError("{!r} builtin has invalid signature".format(obj)) from None
parameters.append(Parameter(name, kind, default=default, annotation=empty))

# non-keyword-only parameters
if sys.version_info >= (3, 12):
total_non_kw_args = len(f.args.posonlyargs) + len(f.args.args)
required_non_kw_args = total_non_kw_args - len(f.args.defaults)
defaults = itertools.chain(itertools.repeat(None, required_non_kw_args), f.args.defaults)

kind = Parameter.POSITIONAL_ONLY
for name, default in zip(f.args.posonlyargs, defaults):
p(name, default)

kind = Parameter.POSITIONAL_OR_KEYWORD
for name, default in zip(f.args.args, defaults):
p(name, default)

else:
args = reversed(f.args.args)
defaults = reversed(f.args.defaults)
iter = itertools.zip_longest(args, defaults, fillvalue=None)
if last_positional_only is not None:
kind = Parameter.POSITIONAL_ONLY
else:
kind = Parameter.POSITIONAL_OR_KEYWORD
for i, (name, default) in enumerate(reversed(list(iter))):
assert name is not None
p(name, default)
if i == last_positional_only:
kind = Parameter.POSITIONAL_OR_KEYWORD

# *args
if f.args.vararg:
kind = Parameter.VAR_POSITIONAL
p(f.args.vararg, empty)

# keyword-only arguments
kind = Parameter.KEYWORD_ONLY
for name, default in zip(f.args.kwonlyargs, f.args.kw_defaults):
p(name, default)

# **kwargs
if f.args.kwarg:
kind = Parameter.VAR_KEYWORD
p(f.args.kwarg, empty)

if self_parameter is not None:
# Possibly strip the bound argument:
# - We *always* strip first bound argument if
# it is a module.
# - We don't strip first bound argument if
# skip_bound_arg is False.
assert parameters
_self = getattr(obj, "__self__", None)
self_isbound = _self is not None
self_ismodule = inspect.ismodule(_self)
if self_isbound and (self_ismodule or skip_bound_arg):
parameters.pop(0)
else:
# for builtins, self parameter is always positional-only!
p = parameters[0].replace(kind=Parameter.POSITIONAL_ONLY)
parameters[0] = p

return cls(parameters, return_annotation=cls.empty)


def safe_inspect_signature(runtime: Any) -> inspect.Signature | None:
if (
hasattr(runtime, "__name__")
and runtime.__name__ == "__init__"
tungol marked this conversation as resolved.
Show resolved Hide resolved
and hasattr(runtime, "__text_signature__")
and runtime.__text_signature__ == "($self, /, *args, **kwargs)"
and hasattr(runtime, "__objclass__")
and runtime.__objclass__ is not object
and hasattr(runtime.__objclass__, "__text_signature__")
and runtime.__objclass__.__text_signature__ is not None
):
# This is an __init__ method with the generic C-class signature.
# In this case, the underlying class usually has a better signature,
# which we can convert into an __init__ signature by adding $self
# at the start. If we hit an error, failover to the normal
# path without trying to recover.
if "/" in runtime.__objclass__.__text_signature__:
new_sig = f"($self, {runtime.__objclass__.__text_signature__[1:]}"
else:
new_sig = f"($self, /, {runtime.__objclass__.__text_signature__[1:]}"
try:
return _signature_fromstr(inspect.Signature, runtime, new_sig)
tungol marked this conversation as resolved.
Show resolved Hide resolved
except Exception:
pass

try:
try:
return inspect.signature(runtime)
Expand Down
Loading