Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checks for object in type function #772

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions pyupgrade/_plugins/type_bases_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

import ast
from typing import Iterable

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import find_open_paren
from pyupgrade._token_helpers import parse_call_args


def is_last_comma(tokens: list[Token], names: list[str]) -> bool:
last_arg = names[-1]
idx = [x.src for x in tokens].index(last_arg)
return tokens[idx + 1].src == ','


def remove_all(the_list: list[str], item: str) -> list[str]:
return [x for x in the_list if x != item]


def remove_line(
the_list: list[Token], sub_list: list[str], item: str, last_is_comma: bool,
) -> None:
is_last = sub_list[-1] == item
idx = [x.src for x in the_list].index(item)
line = the_list[idx].line
idxs = [i for i, x in enumerate(the_list) if x.line == line]
del the_list[min(idxs): max(idxs) + 1]
if is_last and not last_is_comma:
del the_list[min(idxs) - 2]


def remove_base_class_from_type_call(_: int, tokens: list[Token]) -> None:
type_start = find_open_paren(tokens, 0)
bases_start = find_open_paren(tokens, type_start + 1)
bases, end = parse_call_args(tokens, bases_start)
inner_tokens = tokens[bases_start + 1: end - 1]
new_lines = [x.src for x in inner_tokens if x.name == 'NL']
names = [x.src for x in inner_tokens if x.name == 'NAME']
last_is_comma = is_last_comma(tokens, names)
multi_line = len(new_lines) >= len(names)
targets = ['NAME', 'NL']
if multi_line:
targets.remove('NL')
inner_tokens = [x.src for x in inner_tokens if x.name in targets]
# This gets run if the function arguments are on over multiple lines
if multi_line:
remove_line(tokens, inner_tokens, 'object', last_is_comma)
return
inner_tokens = remove_all(inner_tokens, 'object')
# start by deleting all tokens, we will selectively add back
del tokens[bases_start + 1: end - 1]
count = 1
for i, token in enumerate(inner_tokens):
# Boolean value to see if the current item is the last
last = i == len(inner_tokens) - 1
tokens.insert(bases_start + count, Token('NAME', token))
count += 1
# adds a comma and a space if the current item is not the last
if not last and token != '\n':
tokens.insert(bases_start + count, Token('UNIMPORTANT_WS', ' '))
tokens.insert(bases_start + count, Token('OP', ','))
count += 2
# If the lenght is only one, or the last one had a comma, add a comma
elif (last and last_is_comma) or len(inner_tokens) == 1:
tokens.insert(bases_start + count, Token('OP', ','))


@register(ast.Call)
def visit_Call(
state: State,
node: ast.Call,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
isinstance(node.func, ast.Name) and
node.func.id == 'type' and
len(node.args) > 1 and
isinstance(node.args[1], ast.Tuple) and
any(
isinstance(elt, ast.Name) and elt.id == 'object'
for elt in node.args[1].elts
)
):
for base in node.args[1].elts:
if isinstance(base, ast.Name) and base.id == 'object':
yield ast_to_offset(base), remove_base_class_from_type_call
142 changes: 142 additions & 0 deletions tests/features/type_bases_object_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins


@pytest.mark.parametrize(
'src',
['A = type("A", (), {})', 'B = type("B", (int,), {}'],
)
def test_fix_type_bases_object_noop(src):
ret = _fix_plugins(src, settings=Settings())
assert ret == src


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'A = type("A", (object,), {})',
'A = type("A", (), {})',
id='only object base class',
),
pytest.param(
'B = type("B", (object, tuple), {})',
'B = type("B", (tuple,), {})',
id='two base classes, object first',
),
pytest.param(
'C = type("C", (object, foo, bar), {})',
'C = type("C", (foo, bar), {})',
id='three base classes, object first',
),
pytest.param(
'D = type("D", (tuple, object), {})',
'D = type("D", (tuple,), {})',
id='two base classes, object last',
),
pytest.param(
'E = type("E", (foo, bar, object), {})',
'E = type("E", (foo, bar), {})',
id='three base classes, object last',
),
pytest.param(
'F = type(\n "F",\n (object, tuple),\n {}\n)',
'F = type(\n "F",\n (tuple,),\n {}\n)',
id='newline and indent, two base classes',
),
pytest.param(
'G = type(\n "G",\n (\n object,\n class1,\n'
' class2,\n class3,\n class4,\n class5'
',\n class6,\n class7,\n class8,\n '
'class9,\n classA,\n classB\n ),\n {}\n)',
'G = type(\n "G",\n (\n class1,\n class2,\n'
' class3,\n class4,\n class5,\n class6'
',\n class7,\n class8,\n class9,\n '
'classA,\n classB\n ),\n {}\n)',
id='newline and also inside classes tuple',
),
pytest.param(
'H = type(\n "H",\n (tuple, object),\n {}\n)',
'H = type(\n "H",\n (tuple,),\n {}\n)',
id='newline and indent, two base classes, object last',
),
pytest.param(
'I = type(\n "I",\n (\n class1,\n'
' class2,\n class3,\n class4,\n class5'
',\n class6,\n class7,\n class8,\n '
'class9,\n classA,\n object\n ),\n {}\n)',
'I = type(\n "I",\n (\n class1,\n class2,\n'
' class3,\n class4,\n class5,\n class6'
',\n class7,\n class8,\n class9,\n '
'classA\n ),\n {}\n)',
id='newline and also inside classes tuple, object last',
),
pytest.param(
'J = type("J", (object, foo, bar,), {})',
'J = type("J", (foo, bar,), {})',
id='trailing comma, object first',
),
pytest.param(
'K = type("K", (foo, bar, object,), {})',
'K = type("K", (foo, bar,), {})',
id='trailing comma, object last',
),
pytest.param(
'L = type(\n "L",\n (foo, bar, object,),\n {}\n)',
'L = type(\n "L",\n (foo, bar,),\n {}\n)',
id='trailing comma, newline and indent, object last',
),
pytest.param(
'M = type(\n "M",\n (\n class1,\n'
' class2,\n class3,\n class4,\n class5'
',\n class6,\n class7,\n class8,\n '
'class9,\n classA,\n object,\n ),\n {}\n)',
'M = type(\n "M",\n (\n class1,\n class2,\n'
' class3,\n class4,\n class5,\n class6'
',\n class7,\n class8,\n class9,\n '
'classA,\n ),\n {}\n)',
id='trailing comma, '
'newline and also inside classes tuple, '
'object last',
),
pytest.param(
'O = type("O", (foo, object, bar), {})',
'O = type("O", (foo, bar), {})',
id='object in the middle',
),
pytest.param(
'R = type("R", (object,tuple), {})',
'R = type("R", (tuple,), {})',
id='no spaces, object first',
),
pytest.param(
'S = type("S", (tuple,object), {})',
'S = type("S", (tuple,), {})',
id='no spaces, object last',
),
pytest.param(
'U = type("U", (tuple, object,), {})',
'U = type("U", (tuple,), {})',
id='trailing comma, object last, two classes',
),
pytest.param(
'P = type( \n"P",\n (\n foo,\n object,'
'\n bar\n ),\n {}\n)',
'P = type( \n"P",\n (\n foo,\n bar\n '
'),\n {}\n)',
id='newline and also inside classes tuple, object in the middle',
),
pytest.param(
'Q = type(\n "Q",\n (foo, object, bar),\n {}\n)',
'Q = type(\n "Q",\n (foo, bar),\n {}\n)',
id='newline and indent, object in the middle',
),
),
)
def test_fix_type_bases_object(s, expected):
ret = _fix_plugins(s, settings=Settings())
assert ret == expected
Loading