|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import ast |
| 4 | +import functools |
| 5 | +from typing import Iterable |
| 6 | +from typing import NamedTuple |
| 7 | + |
| 8 | +from tokenize_rt import Offset |
| 9 | +from tokenize_rt import Token |
| 10 | + |
| 11 | +from pyupgrade._ast_helpers import ast_to_offset |
| 12 | +from pyupgrade._data import register |
| 13 | +from pyupgrade._data import State |
| 14 | +from pyupgrade._data import TokenFunc |
| 15 | +from pyupgrade._data import Version |
| 16 | +from pyupgrade._token_helpers import arg_str |
| 17 | +from pyupgrade._token_helpers import find_op |
| 18 | +from pyupgrade._token_helpers import parse_call_args |
| 19 | +from pyupgrade._token_helpers import replace_name |
| 20 | + |
| 21 | + |
| 22 | +class _Target(NamedTuple): |
| 23 | + target: str |
| 24 | + module: str | None |
| 25 | + name: str |
| 26 | + min_version: Version |
| 27 | + |
| 28 | + |
| 29 | +_TARGETS = ( |
| 30 | + _Target('OSError', 'mmap', 'error', (3,)), |
| 31 | + _Target('OSError', 'os', 'error', (3,)), |
| 32 | + _Target('OSError', 'select', 'error', (3,)), |
| 33 | + _Target('OSError', 'socket', 'error', (3,)), |
| 34 | + _Target('OSError', None, 'IOError', (3,)), |
| 35 | + _Target('OSError', None, 'EnvironmentError', (3,)), |
| 36 | + _Target('OSError', None, 'WindowsError', (3,)), |
| 37 | + _Target('TimeoutError', 'socket', 'timeout', (3, 10)), |
| 38 | + _Target('TimeoutError', 'asyncio', 'TimeoutError', (3, 11)), |
| 39 | +) |
| 40 | + |
| 41 | + |
| 42 | +def _fix_except( |
| 43 | + i: int, |
| 44 | + tokens: list[Token], |
| 45 | + *, |
| 46 | + at_idx: dict[int, _Target], |
| 47 | +) -> None: |
| 48 | + # find all the arg strs in the tuple |
| 49 | + except_index = i |
| 50 | + while tokens[except_index].src != 'except': |
| 51 | + except_index -= 1 |
| 52 | + start = find_op(tokens, except_index, '(') |
| 53 | + func_args, end = parse_call_args(tokens, start) |
| 54 | + |
| 55 | + # save the exceptions and remove the block |
| 56 | + arg_strs = [arg_str(tokens, *arg) for arg in func_args] |
| 57 | + del tokens[start:end] |
| 58 | + |
| 59 | + # rewrite the block without dupes |
| 60 | + args = [] |
| 61 | + for i, arg in enumerate(arg_strs): |
| 62 | + target = at_idx.get(i) |
| 63 | + if target is not None: |
| 64 | + args.append(target.target) |
| 65 | + else: |
| 66 | + args.append(arg) |
| 67 | + |
| 68 | + unique_args = tuple(dict.fromkeys(args)) |
| 69 | + |
| 70 | + if len(unique_args) > 1: |
| 71 | + joined = '({})'.format(', '.join(unique_args)) |
| 72 | + elif tokens[start - 1].name != 'UNIMPORTANT_WS': |
| 73 | + joined = f' {unique_args[0]}' |
| 74 | + else: |
| 75 | + joined = unique_args[0] |
| 76 | + |
| 77 | + new = Token('CODE', joined) |
| 78 | + tokens.insert(start, new) |
| 79 | + |
| 80 | + |
| 81 | +def _get_rewrite( |
| 82 | + node: ast.AST, |
| 83 | + state: State, |
| 84 | + targets: list[_Target], |
| 85 | +) -> _Target | None: |
| 86 | + for target in targets: |
| 87 | + if ( |
| 88 | + target.module is None and |
| 89 | + isinstance(node, ast.Name) and |
| 90 | + node.id == target.name |
| 91 | + ): |
| 92 | + return target |
| 93 | + elif ( |
| 94 | + target.module is not None and |
| 95 | + isinstance(node, ast.Name) and |
| 96 | + node.id == target.name and |
| 97 | + node.id in state.from_imports[target.module] |
| 98 | + ): |
| 99 | + return target |
| 100 | + elif ( |
| 101 | + target.module is not None and |
| 102 | + isinstance(node, ast.Attribute) and |
| 103 | + isinstance(node.value, ast.Name) and |
| 104 | + node.attr == target.name and |
| 105 | + node.value.id == target.module |
| 106 | + ): |
| 107 | + return target |
| 108 | + else: |
| 109 | + return None |
| 110 | + |
| 111 | + |
| 112 | +def _alias_cbs( |
| 113 | + node: ast.expr, |
| 114 | + state: State, |
| 115 | + targets: list[_Target], |
| 116 | +) -> Iterable[tuple[Offset, TokenFunc]]: |
| 117 | + target = _get_rewrite(node, state, targets) |
| 118 | + if target is not None: |
| 119 | + func = functools.partial( |
| 120 | + replace_name, |
| 121 | + name=target.name, |
| 122 | + new=target.target, |
| 123 | + ) |
| 124 | + yield ast_to_offset(node), func |
| 125 | + |
| 126 | + |
| 127 | +@register(ast.Raise) |
| 128 | +def visit_Raise( |
| 129 | + state: State, |
| 130 | + node: ast.Raise, |
| 131 | + parent: ast.AST, |
| 132 | +) -> Iterable[tuple[Offset, TokenFunc]]: |
| 133 | + targets = [ |
| 134 | + target for target in _TARGETS |
| 135 | + if state.settings.min_version >= target.min_version |
| 136 | + ] |
| 137 | + if node.exc is not None: |
| 138 | + yield from _alias_cbs(node.exc, state, targets) |
| 139 | + if isinstance(node.exc, ast.Call): |
| 140 | + yield from _alias_cbs(node.exc.func, state, targets) |
| 141 | + |
| 142 | + |
| 143 | +@register(ast.Try) |
| 144 | +def visit_Try( |
| 145 | + state: State, |
| 146 | + node: ast.Try, |
| 147 | + parent: ast.AST, |
| 148 | +) -> Iterable[tuple[Offset, TokenFunc]]: |
| 149 | + targets = [ |
| 150 | + target for target in _TARGETS |
| 151 | + if state.settings.min_version >= target.min_version |
| 152 | + ] |
| 153 | + for handler in node.handlers: |
| 154 | + if isinstance(handler.type, ast.Tuple): |
| 155 | + at_idx = {} |
| 156 | + for i, elt in enumerate(handler.type.elts): |
| 157 | + target = _get_rewrite(elt, state, targets) |
| 158 | + if target is not None: |
| 159 | + at_idx[i] = target |
| 160 | + |
| 161 | + if at_idx: |
| 162 | + func = functools.partial(_fix_except, at_idx=at_idx) |
| 163 | + yield ast_to_offset(handler.type), func |
| 164 | + elif handler.type is not None: |
| 165 | + yield from _alias_cbs(handler.type, state, targets) |
0 commit comments