diff --git a/README.md b/README.md index 05d1efc..9593050 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,33 @@ This has the following benefits: pass ``` + +### trailling comma for PEP-695 type aliases + +```diff + def f[ +- T ++ T, + ](x: T) -> T: + return x +``` + +```diff + class A[ +- K ++ K, + ]: + def __init__(self, x: T) -> None: + self.x = x +``` + +```diff + type ListOrSet[ +- T ++ T, +] = list[T] | set[T] +``` + ### unhug trailing paren ```diff diff --git a/add_trailing_comma/_plugins/pep695.py b/add_trailing_comma/_plugins/pep695.py new file mode 100644 index 0000000..46fc639 --- /dev/null +++ b/add_trailing_comma/_plugins/pep695.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import ast +import sys +from typing import Iterable + +from tokenize_rt import Offset +from tokenize_rt import Token + +from add_trailing_comma._ast_helpers import ast_to_offset +from add_trailing_comma._data import register +from add_trailing_comma._data import State +from add_trailing_comma._data import TokenFunc +from add_trailing_comma._token_helpers import find_simple +from add_trailing_comma._token_helpers import fix_brace + + +if sys.version_info >= (3, 12): # pragma: >=3.12 cover + def _fix_pep695( + i: int, + tokens: list[Token], + ) -> None: + for n in range(i, len(tokens)): + token = tokens[n] + if token.name == 'OP' and token.src == '[': + return fix_brace( + tokens, + find_simple(n, tokens), + add_comma=True, + remove_comma=True, + ) + else: + raise AssertionError('Past end?') + + def visit_pep695( + state: State, + node: ast.TypeAlias | ast.ClassDef | ast.FunctionDef, + ) -> Iterable[tuple[Offset, TokenFunc]]: + if node.type_params: + yield ast_to_offset(node), _fix_pep695 + + register(ast.TypeAlias)(visit_pep695) + register(ast.ClassDef)(visit_pep695) + register(ast.FunctionDef)(visit_pep695) diff --git a/tests/features/pep695_test.py b/tests/features/pep695_test.py new file mode 100644 index 0000000..2dced5f --- /dev/null +++ b/tests/features/pep695_test.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import sys + +import pytest + +from add_trailing_comma._main import _fix_src + + +@pytest.mark.parametrize( + 's', + ( + pytest.param( + 'class A[K]:\n' + ' ...\n', + id='single line classdef', + ), + pytest.param( + 'def not_none[K](v: K) -> K:\n' + ' ...\n', + id='single line functiondef', + ), + pytest.param( + 'type ListOrSet[T] = list[T] | set[T]', + id='single line generic type alias', + ), + pytest.param( + 'type ListOrSet = list[str] | set[int]', + id='no type-param type alias', + ), + ), +) +def test_noop(s): + assert _fix_src(s) == s + + +@pytest.mark.xfail(sys.version_info < (3, 12), reason='py312+') +@pytest.mark.parametrize( + ('s', 'e'), + ( + pytest.param( + 'class ClassA[\n' + ' T: str\n' + ']:\n' + ' ...', + + 'class ClassA[\n' + ' T: str,\n' + ']:\n' + ' ...', + id='multiline classdef', + ), + pytest.param( + 'def f[\n' + ' T\n' + '](x: T) -> T:\n' + ' ...', + + 'def f[\n' + ' T,\n' + '](x: T) -> T:\n' + ' ...', + id='multiline functiondef', + ), + pytest.param( + 'type ListOrSet[\n' + ' T,\n' + ' K\n' + '] = list[T] | set[K]', + 'type ListOrSet[\n' + ' T,\n' + ' K,\n' + '] = list[T] | set[K]', + id='multiline generic type alias', + ), + pytest.param( + 'def f[\n' + ' T: (\n' + ' "ForwardReference",\n' + ' bytes\n' + ' )\n' + '](x: T) -> T:\n' + ' ...', + + 'def f[\n' + ' T: (\n' + ' "ForwardReference",\n' + ' bytes,\n' + ' ),\n' + '](x: T) -> T:\n' + ' ...', + id='multiline function constrained types', + ), + pytest.param( + 'class ClassB[\n' + ' T: (\n' + ' "ForwardReference",\n' + ' bytes\n' + ' )\n' + ']:\n' + ' ...\n', + + 'class ClassB[\n' + ' T: (\n' + ' "ForwardReference",\n' + ' bytes,\n' + ' ),\n' + ']:\n' + ' ...\n', + id='multiline class constrained types', + ), + ), +) +def test_fix(s, e): + assert _fix_src(s) == e