Skip to content

Commit 8943547

Browse files
authored
Merge pull request #898 from asottile/constant-fold-types
constant fold isinstance / issubclass / except
2 parents 0d46cba + f926532 commit 8943547

File tree

7 files changed

+206
-36
lines changed

7 files changed

+206
-36
lines changed

README.md

+16
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,22 @@ A fix for [python-modernize/python-modernize#178]
154154

155155
[python-modernize/python-modernize#178]: https://github.com/python-modernize/python-modernize/issues/178
156156

157+
### constant fold `isinstance` / `issubclass` / `except`
158+
159+
```diff
160+
-isinstance(x, (int, int))
161+
+isinstance(x, int)
162+
163+
-issubclass(y, (str, str))
164+
+issubclass(y, str)
165+
166+
try:
167+
raises()
168+
-except (Error1, Error1, Error2):
169+
+except (Error1, Error2):
170+
pass
171+
```
172+
157173
### unittest deprecated aliases
158174

159175
Rewrites [deprecated unittest method aliases](https://docs.python.org/3/library/unittest.html#deprecated-aliases) to their non-deprecated forms.

pyupgrade/_ast_helpers.py

+10
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,13 @@ def is_async_listcomp(node: ast.ListComp) -> bool:
5656
any(gen.is_async for gen in node.generators) or
5757
contains_await(node)
5858
)
59+
60+
61+
def is_type_check(node: ast.AST) -> bool:
62+
return (
63+
isinstance(node, ast.Call) and
64+
isinstance(node.func, ast.Name) and
65+
node.func.id in {'isinstance', 'issubclass'} and
66+
len(node.args) == 2 and
67+
not has_starargs(node)
68+
)

pyupgrade/_plugins/constant_fold.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from typing import Iterable
5+
6+
from tokenize_rt import Offset
7+
8+
from pyupgrade._ast_helpers import ast_to_offset
9+
from pyupgrade._ast_helpers import is_type_check
10+
from pyupgrade._data import register
11+
from pyupgrade._data import State
12+
from pyupgrade._data import TokenFunc
13+
from pyupgrade._token_helpers import constant_fold_tuple
14+
15+
16+
def _to_name(node: ast.AST) -> str | None:
17+
if isinstance(node, ast.Name):
18+
return node.id
19+
elif isinstance(node, ast.Attribute):
20+
base = _to_name(node.value)
21+
if base is None:
22+
return None
23+
else:
24+
return f'{base}.{node.attr}'
25+
else:
26+
return None
27+
28+
29+
def _can_constant_fold(node: ast.Tuple) -> bool:
30+
seen = set()
31+
for el in node.elts:
32+
name = _to_name(el)
33+
if name is not None:
34+
if name in seen:
35+
return True
36+
else:
37+
seen.add(name)
38+
else:
39+
return False
40+
41+
42+
def _cbs(node: ast.AST | None) -> Iterable[tuple[Offset, TokenFunc]]:
43+
if isinstance(node, ast.Tuple) and _can_constant_fold(node):
44+
yield ast_to_offset(node), constant_fold_tuple
45+
46+
47+
@register(ast.Call)
48+
def visit_Call(
49+
state: State,
50+
node: ast.Call,
51+
parent: ast.AST,
52+
) -> Iterable[tuple[Offset, TokenFunc]]:
53+
if is_type_check(node):
54+
yield from _cbs(node.args[1])
55+
56+
57+
@register(ast.Try)
58+
def visit_Try(
59+
state: State,
60+
node: ast.Try,
61+
parent: ast.AST,
62+
) -> Iterable[tuple[Offset, TokenFunc]]:
63+
for handler in node.handlers:
64+
yield from _cbs(handler.type)

pyupgrade/_plugins/exceptions.py

+5-26
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pyupgrade._data import State
1414
from pyupgrade._data import TokenFunc
1515
from pyupgrade._data import Version
16-
from pyupgrade._token_helpers import arg_str
16+
from pyupgrade._token_helpers import constant_fold_tuple
1717
from pyupgrade._token_helpers import find_op
1818
from pyupgrade._token_helpers import parse_call_args
1919
from pyupgrade._token_helpers import replace_name
@@ -45,34 +45,13 @@ def _fix_except(
4545
*,
4646
at_idx: dict[int, _Target],
4747
) -> None:
48-
# find all the arg strs in the tuple
49-
except_index = i
50-
while tokens[except_index].src != 'except':
51-
except_index -= 1
52-
start = find_op(tokens, except_index, '(')
48+
start = find_op(tokens, i, '(')
5349
func_args, end = parse_call_args(tokens, start)
5450

55-
arg_strs = [arg_str(tokens, *arg) for arg in func_args]
51+
for i, target in reversed(at_idx.items()):
52+
tokens[slice(*func_args[i])] = [Token('NAME', target.target)]
5653

57-
# rewrite the block without dupes
58-
args = []
59-
for i, arg in enumerate(arg_strs):
60-
target = at_idx.get(i)
61-
if target is not None:
62-
args.append(target.target)
63-
else:
64-
args.append(arg)
65-
66-
unique_args = tuple(dict.fromkeys(args))
67-
68-
if len(unique_args) > 1:
69-
joined = '({})'.format(', '.join(unique_args))
70-
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
71-
joined = f' {unique_args[0]}'
72-
else:
73-
joined = unique_args[0]
74-
75-
tokens[start:end] = [Token('CODE', joined)]
54+
constant_fold_tuple(start, tokens)
7655

7756

7857
def _get_rewrite(

pyupgrade/_plugins/six_simple.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tokenize_rt import Offset
88

99
from pyupgrade._ast_helpers import ast_to_offset
10+
from pyupgrade._ast_helpers import is_type_check
1011
from pyupgrade._data import register
1112
from pyupgrade._data import State
1213
from pyupgrade._data import TokenFunc
@@ -36,14 +37,6 @@
3637
}
3738

3839

39-
def _is_type_check(node: ast.AST | None) -> bool:
40-
return (
41-
isinstance(node, ast.Call) and
42-
isinstance(node.func, ast.Name) and
43-
node.func.id in {'isinstance', 'issubclass'}
44-
)
45-
46-
4740
@register(ast.Attribute)
4841
def visit_Attribute(
4942
state: State,
@@ -62,7 +55,7 @@ def visit_Attribute(
6255
):
6356
return
6457

65-
if node.attr in NAMES_TYPE_CTX and _is_type_check(parent):
58+
if node.attr in NAMES_TYPE_CTX and is_type_check(parent):
6659
new = NAMES_TYPE_CTX[node.attr]
6760
else:
6861
new = NAMES[node.attr]
@@ -106,7 +99,7 @@ def visit_Name(
10699
):
107100
return
108101

109-
if node.id in NAMES_TYPE_CTX and _is_type_check(parent):
102+
if node.id in NAMES_TYPE_CTX and is_type_check(parent):
110103
new = NAMES_TYPE_CTX[node.id]
111104
else:
112105
new = NAMES[node.id]

pyupgrade/_token_helpers.py

+17
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,23 @@ def replace_argument(
470470
tokens[start_idx:end_idx] = [Token('SRC', new)]
471471

472472

473+
def constant_fold_tuple(i: int, tokens: list[Token]) -> None:
474+
start = find_op(tokens, i, '(')
475+
func_args, end = parse_call_args(tokens, start)
476+
arg_strs = [arg_str(tokens, *arg) for arg in func_args]
477+
478+
unique_args = tuple(dict.fromkeys(arg_strs))
479+
480+
if len(unique_args) > 1:
481+
joined = '({})'.format(', '.join(unique_args))
482+
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
483+
joined = f' {unique_args[0]}'
484+
else:
485+
joined = unique_args[0]
486+
487+
tokens[start:end] = [Token('CODE', joined)]
488+
489+
473490
def has_space_before(i: int, tokens: list[Token]) -> bool:
474491
return i >= 1 and tokens[i - 1].name in {UNIMPORTANT_WS, 'INDENT'}
475492

tests/features/constant_fold_test.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from pyupgrade._data import Settings
6+
from pyupgrade._main import _fix_plugins
7+
8+
9+
@pytest.mark.parametrize(
10+
's',
11+
(
12+
pytest.param(
13+
'isinstance(x, str)',
14+
id='isinstance nothing duplicated',
15+
),
16+
pytest.param(
17+
'issubclass(x, str)',
18+
id='issubclass nothing duplicated',
19+
),
20+
pytest.param(
21+
'try: ...\n'
22+
'except Exception: ...\n',
23+
id='try-except nothing duplicated',
24+
),
25+
pytest.param(
26+
'isinstance(x, (str, (str,)))',
27+
id='only consider flat tuples',
28+
),
29+
pytest.param(
30+
'isinstance(x, (f(), a().g))',
31+
id='only consider names and dotted names',
32+
),
33+
),
34+
)
35+
def test_constant_fold_noop(s):
36+
assert _fix_plugins(s, settings=Settings()) == s
37+
38+
39+
@pytest.mark.parametrize(
40+
('s', 'expected'),
41+
(
42+
pytest.param(
43+
'isinstance(x, (str, str, int))',
44+
45+
'isinstance(x, (str, int))',
46+
47+
id='isinstance',
48+
),
49+
pytest.param(
50+
'issubclass(x, (str, str, int))',
51+
52+
'issubclass(x, (str, int))',
53+
54+
id='issubclass',
55+
),
56+
pytest.param(
57+
'try: ...\n'
58+
'except (Exception, Exception, TypeError): ...\n',
59+
60+
'try: ...\n'
61+
'except (Exception, TypeError): ...\n',
62+
63+
id='except',
64+
),
65+
66+
pytest.param(
67+
'isinstance(x, (str, str))',
68+
69+
'isinstance(x, str)',
70+
71+
id='folds to 1',
72+
),
73+
74+
pytest.param(
75+
'isinstance(x, (a.b, a.b, a.c))',
76+
'isinstance(x, (a.b, a.c))',
77+
id='folds dotted names',
78+
),
79+
pytest.param(
80+
'try: ...\n'
81+
'except(a, a): ...\n',
82+
83+
'try: ...\n'
84+
'except a: ...\n',
85+
86+
id='deduplication to 1 does not cause syntax error with except',
87+
),
88+
),
89+
)
90+
def test_constant_fold(s, expected):
91+
assert _fix_plugins(s, settings=Settings()) == expected

0 commit comments

Comments
 (0)