Skip to content

Commit 4e28911

Browse files
authored
Merge pull request #948 from asottile/pep-696
rewrite TypeVar defaults for Generator / AsyncGenerator
2 parents bc45bf1 + d17f461 commit 4e28911

6 files changed

+219
-2
lines changed

README.md

+18
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,24 @@ Availability:
754754
...
755755
```
756756

757+
### pep 696 TypeVar defaults
758+
759+
Availability:
760+
- File imports `from __future__ import annotations`
761+
- Unless `--keep-runtime-typing` is passed on the commandline.
762+
- `--py313-plus` is passed on the commandline.
763+
764+
```diff
765+
-def f() -> Generator[int, None, None]:
766+
+def f() -> Generator[int]:
767+
yield 1
768+
```
769+
770+
```diff
771+
-async def f() -> AsyncGenerator[int, None]:
772+
+async def f() -> AsyncGenerator[int]:
773+
yield 1
774+
```
757775

758776
### remove quoted annotations
759777

pyupgrade/_data.py

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class State(NamedTuple):
4040
'__future__',
4141
'asyncio',
4242
'collections',
43+
'collections.abc',
4344
'functools',
4445
'mmap',
4546
'os',

pyupgrade/_plugins/legacy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(self) -> None:
9191
self.yield_offsets: set[Offset] = set()
9292

9393
@contextlib.contextmanager
94-
def _scope(self, node: ast.AST) -> Generator[None, None, None]:
94+
def _scope(self, node: ast.AST) -> Generator[None]:
9595
self._scopes.append(Scope(node))
9696
try:
9797
yield

pyupgrade/_plugins/percent_format.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _must_match(regex: Pattern[str], string: str, pos: int) -> Match[str]:
4646

4747

4848
def _parse_percent_format(s: str) -> tuple[PercentFormat, ...]:
49-
def _parse_inner() -> Generator[PercentFormat, None, None]:
49+
def _parse_inner() -> Generator[PercentFormat]:
5050
string_start = 0
5151
string_end = 0
5252
in_fmt = False
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from typing import Iterable
5+
6+
from tokenize_rt import Offset
7+
from tokenize_rt import Token
8+
9+
from pyupgrade._ast_helpers import ast_to_offset
10+
from pyupgrade._ast_helpers import is_name_attr
11+
from pyupgrade._data import register
12+
from pyupgrade._data import State
13+
from pyupgrade._data import TokenFunc
14+
from pyupgrade._token_helpers import find_op
15+
from pyupgrade._token_helpers import parse_call_args
16+
17+
18+
def _fix_typevar_default(i: int, tokens: list[Token]) -> None:
19+
j = find_op(tokens, i, '[')
20+
args, end = parse_call_args(tokens, j)
21+
# remove the trailing `None` arguments
22+
del tokens[args[0][1]:args[-1][1]]
23+
24+
25+
def _should_rewrite(state: State) -> bool:
26+
return (
27+
state.settings.min_version >= (3, 13) or (
28+
not state.settings.keep_runtime_typing and
29+
state.in_annotation and
30+
'annotations' in state.from_imports['__future__']
31+
)
32+
)
33+
34+
35+
def _is_none(node: ast.AST) -> bool:
36+
return isinstance(node, ast.Constant) and node.value is None
37+
38+
39+
@register(ast.Subscript)
40+
def visit_Subscript(
41+
state: State,
42+
node: ast.Subscript,
43+
parent: ast.AST,
44+
) -> Iterable[tuple[Offset, TokenFunc]]:
45+
if not _should_rewrite(state):
46+
return
47+
48+
if (
49+
is_name_attr(
50+
node.value,
51+
state.from_imports,
52+
('collections.abc', 'typing', 'typing_extensions'),
53+
('Generator',),
54+
) and
55+
isinstance(node.slice, ast.Tuple) and
56+
len(node.slice.elts) == 3 and
57+
_is_none(node.slice.elts[1]) and
58+
_is_none(node.slice.elts[2])
59+
):
60+
yield ast_to_offset(node), _fix_typevar_default
61+
elif (
62+
is_name_attr(
63+
node.value,
64+
state.from_imports,
65+
('collections.abc', 'typing', 'typing_extensions'),
66+
('AsyncGenerator',),
67+
) and
68+
isinstance(node.slice, ast.Tuple) and
69+
len(node.slice.elts) == 2 and
70+
_is_none(node.slice.elts[1])
71+
):
72+
yield ast_to_offset(node), _fix_typevar_default
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from pyupgrade._data import Settings
6+
from pyupgrade._main import _fix_plugins
7+
8+
9+
@pytest.mark.parametrize(
10+
('s', 'version'),
11+
(
12+
pytest.param(
13+
'from collections.abc import Generator\n'
14+
'def f() -> Generator[int, None, None]: yield 1\n',
15+
(3, 12),
16+
id='not 3.13+, no __future__.annotations',
17+
),
18+
pytest.param(
19+
'from __future__ import annotations\n'
20+
'from collections.abc import Generator\n'
21+
'def f() -> Generator[int]: yield 1\n',
22+
(3, 12),
23+
id='already converted!',
24+
),
25+
pytest.param(
26+
'from __future__ import annotations\n'
27+
'from collections.abc import Generator\n'
28+
'def f() -> Generator[int, int, None]: yield 1\n'
29+
'def g() -> Generator[int, int, int]: yield 1\n',
30+
(3, 12),
31+
id='non-None send/return type',
32+
),
33+
),
34+
)
35+
def test_fix_pep696_noop(s, version):
36+
assert _fix_plugins(s, settings=Settings(min_version=version)) == s
37+
38+
39+
def test_fix_pep696_noop_keep_runtime_typing():
40+
settings = Settings(min_version=(3, 12), keep_runtime_typing=True)
41+
s = '''\
42+
from __future__ import annotations
43+
from collections.abc import Generator
44+
def f() -> Generator[int, None, None]: yield 1
45+
'''
46+
assert _fix_plugins(s, settings=settings) == s
47+
48+
49+
@pytest.mark.parametrize(
50+
('s', 'expected'),
51+
(
52+
pytest.param(
53+
'from __future__ import annotations\n'
54+
'from typing import Generator\n'
55+
'def f() -> Generator[int, None, None]: yield 1\n',
56+
57+
'from __future__ import annotations\n'
58+
'from collections.abc import Generator\n'
59+
'def f() -> Generator[int]: yield 1\n',
60+
61+
id='typing.Generator',
62+
),
63+
pytest.param(
64+
'from __future__ import annotations\n'
65+
'from typing_extensions import Generator\n'
66+
'def f() -> Generator[int, None, None]: yield 1\n',
67+
68+
'from __future__ import annotations\n'
69+
'from typing_extensions import Generator\n'
70+
'def f() -> Generator[int]: yield 1\n',
71+
72+
id='typing_extensions.Generator',
73+
),
74+
pytest.param(
75+
'from __future__ import annotations\n'
76+
'from collections.abc import Generator\n'
77+
'def f() -> Generator[int, None, None]: yield 1\n',
78+
79+
'from __future__ import annotations\n'
80+
'from collections.abc import Generator\n'
81+
'def f() -> Generator[int]: yield 1\n',
82+
83+
id='collections.abc.Generator',
84+
),
85+
pytest.param(
86+
'from __future__ import annotations\n'
87+
'from collections.abc import AsyncGenerator\n'
88+
'async def f() -> AsyncGenerator[int, None]: yield 1\n',
89+
90+
'from __future__ import annotations\n'
91+
'from collections.abc import AsyncGenerator\n'
92+
'async def f() -> AsyncGenerator[int]: yield 1\n',
93+
94+
id='collections.abc.AsyncGenerator',
95+
),
96+
),
97+
)
98+
def test_fix_pep696_with_future_annotations(s, expected):
99+
assert _fix_plugins(s, settings=Settings(min_version=(3, 12))) == expected
100+
101+
102+
@pytest.mark.parametrize(
103+
('s', 'expected'),
104+
(
105+
pytest.param(
106+
'from collections.abc import Generator\n'
107+
'def f() -> Generator[int, None, None]: yield 1\n',
108+
109+
'from collections.abc import Generator\n'
110+
'def f() -> Generator[int]: yield 1\n',
111+
112+
id='Generator',
113+
),
114+
pytest.param(
115+
'from collections.abc import AsyncGenerator\n'
116+
'async def f() -> AsyncGenerator[int, None]: yield 1\n',
117+
118+
'from collections.abc import AsyncGenerator\n'
119+
'async def f() -> AsyncGenerator[int]: yield 1\n',
120+
121+
id='AsyncGenerator',
122+
),
123+
),
124+
)
125+
def test_fix_pep696_with_3_13(s, expected):
126+
assert _fix_plugins(s, settings=Settings(min_version=(3, 13))) == expected

0 commit comments

Comments
 (0)