Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding PEP-695 type alias support #246

Merged
merged 1 commit into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,33 @@ This has the following benefits:
pass
```


### trailling comma for PEP-695 type aliases

```diff
def f[
- T
+ T,
](x: T) -> T:
return x
```

```diff
class A[
- K
+ K,
]:
def __init__(self, x: T) -> None:
self.x = x
```

```diff
type ListOrSet[
- T
+ T,
] = list[T] | set[T]
```

### unhug trailing paren

```diff
Expand Down
44 changes: 44 additions & 0 deletions add_trailing_comma/_plugins/pep695.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

import ast
import sys
from typing import Iterable

from tokenize_rt import Offset
from tokenize_rt import Token

from add_trailing_comma._ast_helpers import ast_to_offset
from add_trailing_comma._data import register
from add_trailing_comma._data import State
from add_trailing_comma._data import TokenFunc
from add_trailing_comma._token_helpers import find_simple
from add_trailing_comma._token_helpers import fix_brace


if sys.version_info >= (3, 12): # pragma: >=3.12 cover
def _fix_pep695(
i: int,
tokens: list[Token],
) -> None:
for n in range(i, len(tokens)):
token = tokens[n]
if token.name == 'OP' and token.src == '[':
return fix_brace(
tokens,
find_simple(n, tokens),
add_comma=True,
remove_comma=True,
)
else:
raise AssertionError('Past end?')

def visit_pep695(
state: State,
node: ast.TypeAlias | ast.ClassDef | ast.FunctionDef,
) -> Iterable[tuple[Offset, TokenFunc]]:
if node.type_params:
yield ast_to_offset(node), _fix_pep695

register(ast.TypeAlias)(visit_pep695)
register(ast.ClassDef)(visit_pep695)
register(ast.FunctionDef)(visit_pep695)
115 changes: 115 additions & 0 deletions tests/features/pep695_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

import sys

import pytest

from add_trailing_comma._main import _fix_src


@pytest.mark.parametrize(
's',
(
pytest.param(
'class A[K]:\n'
' ...\n',
id='single line classdef',
),
pytest.param(
'def not_none[K](v: K) -> K:\n'
' ...\n',
id='single line functiondef',
),
pytest.param(
'type ListOrSet[T] = list[T] | set[T]',
id='single line generic type alias',
),
pytest.param(
'type ListOrSet = list[str] | set[int]',
id='no type-param type alias',
),
),
)
def test_noop(s):
assert _fix_src(s) == s


@pytest.mark.xfail(sys.version_info < (3, 12), reason='py312+')
@pytest.mark.parametrize(
('s', 'e'),
(
pytest.param(
'class ClassA[\n'
' T: str\n'
']:\n'
' ...',

'class ClassA[\n'
' T: str,\n'
']:\n'
' ...',
id='multiline classdef',
),
pytest.param(
'def f[\n'
' T\n'
'](x: T) -> T:\n'
' ...',

'def f[\n'
' T,\n'
'](x: T) -> T:\n'
' ...',
id='multiline functiondef',
),
pytest.param(
'type ListOrSet[\n'
' T,\n'
' K\n'
'] = list[T] | set[K]',
'type ListOrSet[\n'
' T,\n'
' K,\n'
'] = list[T] | set[K]',
id='multiline generic type alias',
),
pytest.param(
'def f[\n'
' T: (\n'
' "ForwardReference",\n'
' bytes\n'
' )\n'
'](x: T) -> T:\n'
' ...',

'def f[\n'
' T: (\n'
' "ForwardReference",\n'
' bytes,\n'
' ),\n'
asottile marked this conversation as resolved.
Show resolved Hide resolved
'](x: T) -> T:\n'
' ...',
id='multiline function constrained types',
),
pytest.param(
'class ClassB[\n'
' T: (\n'
' "ForwardReference",\n'
' bytes\n'
' )\n'
']:\n'
' ...\n',

'class ClassB[\n'
' T: (\n'
' "ForwardReference",\n'
' bytes,\n'
' ),\n'
']:\n'
' ...\n',
id='multiline class constrained types',
),
),
)
def test_fix(s, e):
assert _fix_src(s) == e