Skip to content

Commit 19b1966

Browse files
authored
Merge pull request #885 from mxr/timeout-errors
Unify exception replacing and rewrite some TimeoutError cases
2 parents 3bbf781 + 719c224 commit 19b1966

File tree

5 files changed

+431
-652
lines changed

5 files changed

+431
-652
lines changed

pyupgrade/_data.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class State(NamedTuple):
3838

3939
RECORD_FROM_IMPORTS = frozenset((
4040
'__future__',
41+
'asyncio',
4142
'functools',
4243
'mmap',
4344
'os',

pyupgrade/_plugins/exceptions.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import functools
5+
from typing import Iterable
6+
from typing import NamedTuple
7+
8+
from tokenize_rt import Offset
9+
from tokenize_rt import Token
10+
11+
from pyupgrade._ast_helpers import ast_to_offset
12+
from pyupgrade._data import register
13+
from pyupgrade._data import State
14+
from pyupgrade._data import TokenFunc
15+
from pyupgrade._data import Version
16+
from pyupgrade._token_helpers import arg_str
17+
from pyupgrade._token_helpers import find_op
18+
from pyupgrade._token_helpers import parse_call_args
19+
from pyupgrade._token_helpers import replace_name
20+
21+
22+
class _Target(NamedTuple):
23+
target: str
24+
module: str | None
25+
name: str
26+
min_version: Version
27+
28+
29+
_TARGETS = (
30+
_Target('OSError', 'mmap', 'error', (3,)),
31+
_Target('OSError', 'os', 'error', (3,)),
32+
_Target('OSError', 'select', 'error', (3,)),
33+
_Target('OSError', 'socket', 'error', (3,)),
34+
_Target('OSError', None, 'IOError', (3,)),
35+
_Target('OSError', None, 'EnvironmentError', (3,)),
36+
_Target('OSError', None, 'WindowsError', (3,)),
37+
_Target('TimeoutError', 'socket', 'timeout', (3, 10)),
38+
_Target('TimeoutError', 'asyncio', 'TimeoutError', (3, 11)),
39+
)
40+
41+
42+
def _fix_except(
43+
i: int,
44+
tokens: list[Token],
45+
*,
46+
at_idx: dict[int, _Target],
47+
) -> None:
48+
# find all the arg strs in the tuple
49+
except_index = i
50+
while tokens[except_index].src != 'except':
51+
except_index -= 1
52+
start = find_op(tokens, except_index, '(')
53+
func_args, end = parse_call_args(tokens, start)
54+
55+
# save the exceptions and remove the block
56+
arg_strs = [arg_str(tokens, *arg) for arg in func_args]
57+
del tokens[start:end]
58+
59+
# rewrite the block without dupes
60+
args = []
61+
for i, arg in enumerate(arg_strs):
62+
target = at_idx.get(i)
63+
if target is not None:
64+
args.append(target.target)
65+
else:
66+
args.append(arg)
67+
68+
unique_args = tuple(dict.fromkeys(args))
69+
70+
if len(unique_args) > 1:
71+
joined = '({})'.format(', '.join(unique_args))
72+
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
73+
joined = f' {unique_args[0]}'
74+
else:
75+
joined = unique_args[0]
76+
77+
new = Token('CODE', joined)
78+
tokens.insert(start, new)
79+
80+
81+
def _get_rewrite(
82+
node: ast.AST,
83+
state: State,
84+
targets: list[_Target],
85+
) -> _Target | None:
86+
for target in targets:
87+
if (
88+
target.module is None and
89+
isinstance(node, ast.Name) and
90+
node.id == target.name
91+
):
92+
return target
93+
elif (
94+
target.module is not None and
95+
isinstance(node, ast.Name) and
96+
node.id == target.name and
97+
node.id in state.from_imports[target.module]
98+
):
99+
return target
100+
elif (
101+
target.module is not None and
102+
isinstance(node, ast.Attribute) and
103+
isinstance(node.value, ast.Name) and
104+
node.attr == target.name and
105+
node.value.id == target.module
106+
):
107+
return target
108+
else:
109+
return None
110+
111+
112+
def _alias_cbs(
113+
node: ast.expr,
114+
state: State,
115+
targets: list[_Target],
116+
) -> Iterable[tuple[Offset, TokenFunc]]:
117+
target = _get_rewrite(node, state, targets)
118+
if target is not None:
119+
func = functools.partial(
120+
replace_name,
121+
name=target.name,
122+
new=target.target,
123+
)
124+
yield ast_to_offset(node), func
125+
126+
127+
@register(ast.Raise)
128+
def visit_Raise(
129+
state: State,
130+
node: ast.Raise,
131+
parent: ast.AST,
132+
) -> Iterable[tuple[Offset, TokenFunc]]:
133+
targets = [
134+
target for target in _TARGETS
135+
if state.settings.min_version >= target.min_version
136+
]
137+
if node.exc is not None:
138+
yield from _alias_cbs(node.exc, state, targets)
139+
if isinstance(node.exc, ast.Call):
140+
yield from _alias_cbs(node.exc.func, state, targets)
141+
142+
143+
@register(ast.Try)
144+
def visit_Try(
145+
state: State,
146+
node: ast.Try,
147+
parent: ast.AST,
148+
) -> Iterable[tuple[Offset, TokenFunc]]:
149+
targets = [
150+
target for target in _TARGETS
151+
if state.settings.min_version >= target.min_version
152+
]
153+
for handler in node.handlers:
154+
if isinstance(handler.type, ast.Tuple):
155+
at_idx = {}
156+
for i, elt in enumerate(handler.type.elts):
157+
target = _get_rewrite(elt, state, targets)
158+
if target is not None:
159+
at_idx[i] = target
160+
161+
if at_idx:
162+
func = functools.partial(_fix_except, at_idx=at_idx)
163+
yield ast_to_offset(handler.type), func
164+
elif handler.type is not None:
165+
yield from _alias_cbs(handler.type, state, targets)

pyupgrade/_plugins/oserror_aliases.py

-136
This file was deleted.

0 commit comments

Comments
 (0)