Skip to content

Commit d1518e8

Browse files
authored
Merge pull request #901 from UnknownPlatypus/defaultdict-lambdas
Defaultdict lambdas
2 parents 72041b1 + f001635 commit d1518e8

File tree

4 files changed

+328
-0
lines changed

4 files changed

+328
-0
lines changed

README.md

+26
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ Sample `.pre-commit-config.yaml`:
5656
+{a: b for a, b in y}
5757
```
5858

59+
### Replace unnecessary lambdas in `collections.defaultdict` calls
60+
61+
```diff
62+
-defaultdict(lambda: [])
63+
+defaultdict(list)
64+
-defaultdict(lambda: list())
65+
+defaultdict(list)
66+
-defaultdict(lambda: {})
67+
+defaultdict(dict)
68+
-defaultdict(lambda: dict())
69+
+defaultdict(dict)
70+
-defaultdict(lambda: ())
71+
+defaultdict(tuple)
72+
-defaultdict(lambda: tuple())
73+
+defaultdict(tuple)
74+
-defaultdict(lambda: set())
75+
+defaultdict(set)
76+
-defaultdict(lambda: 0)
77+
+defaultdict(int)
78+
-defaultdict(lambda: 0.0)
79+
+defaultdict(float)
80+
-defaultdict(lambda: 0j)
81+
+defaultdict(complex)
82+
-defaultdict(lambda: '')
83+
+defaultdict(str)
84+
```
5985

6086
### Format Specifiers
6187

pyupgrade/_data.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class State(NamedTuple):
3939
RECORD_FROM_IMPORTS = frozenset((
4040
'__future__',
4141
'asyncio',
42+
'collections',
4243
'functools',
4344
'mmap',
4445
'os',
+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import functools
5+
from typing import Iterable
6+
7+
from tokenize_rt import Offset
8+
from tokenize_rt import Token
9+
10+
from pyupgrade._ast_helpers import ast_to_offset
11+
from pyupgrade._ast_helpers import is_name_attr
12+
from pyupgrade._data import register
13+
from pyupgrade._data import State
14+
from pyupgrade._data import TokenFunc
15+
from pyupgrade._token_helpers import find_op
16+
from pyupgrade._token_helpers import parse_call_args
17+
18+
19+
def _eligible_lambda_replacement(lambda_expr: ast.Lambda) -> str | None:
20+
if isinstance(lambda_expr.body, ast.Constant):
21+
if lambda_expr.body.value == 0:
22+
return type(lambda_expr.body.value).__name__
23+
elif lambda_expr.body.value == '':
24+
return 'str'
25+
else:
26+
return None
27+
elif isinstance(lambda_expr.body, ast.List) and not lambda_expr.body.elts:
28+
return 'list'
29+
elif isinstance(lambda_expr.body, ast.Tuple) and not lambda_expr.body.elts:
30+
return 'tuple'
31+
elif isinstance(lambda_expr.body, ast.Dict) and not lambda_expr.body.keys:
32+
return 'dict'
33+
elif (
34+
isinstance(lambda_expr.body, ast.Call) and
35+
isinstance(lambda_expr.body.func, ast.Name) and
36+
not lambda_expr.body.args and
37+
not lambda_expr.body.keywords and
38+
lambda_expr.body.func.id in {'dict', 'list', 'set', 'tuple'}
39+
):
40+
return lambda_expr.body.func.id
41+
else:
42+
return None
43+
44+
45+
def _fix_defaultdict_first_arg(
46+
i: int,
47+
tokens: list[Token],
48+
*,
49+
replacement: str,
50+
) -> None:
51+
start = find_op(tokens, i, '(')
52+
func_args, end = parse_call_args(tokens, start)
53+
54+
tokens[slice(*func_args[0])] = [Token('CODE', replacement)]
55+
56+
57+
@register(ast.Call)
58+
def visit_Call(
59+
state: State,
60+
node: ast.Call,
61+
parent: ast.AST,
62+
) -> Iterable[tuple[Offset, TokenFunc]]:
63+
if (
64+
is_name_attr(
65+
node.func,
66+
state.from_imports,
67+
('collections',),
68+
('defaultdict',),
69+
) and
70+
node.args and
71+
isinstance(node.args[0], ast.Lambda)
72+
):
73+
replacement = _eligible_lambda_replacement(node.args[0])
74+
if replacement is None:
75+
return
76+
77+
func = functools.partial(
78+
_fix_defaultdict_first_arg,
79+
replacement=replacement,
80+
)
81+
yield ast_to_offset(node), func
+220
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from pyupgrade._data import Settings
6+
from pyupgrade._main import _fix_plugins
7+
8+
9+
@pytest.mark.parametrize(
10+
('s',),
11+
(
12+
pytest.param(
13+
'from collections import defaultdict as dd\n\n'
14+
'dd(lambda: set())\n',
15+
id='not following as imports',
16+
),
17+
pytest.param(
18+
'from collections2 import defaultdict\n\n'
19+
'dd(lambda: dict())\n',
20+
id='not following unknown import',
21+
),
22+
pytest.param(
23+
'from .collections import defaultdict\n'
24+
'defaultdict(lambda: list())\n',
25+
id='relative imports',
26+
),
27+
pytest.param(
28+
'from collections import defaultdict\n\n'
29+
'defaultdict(lambda: {1}))\n',
30+
id='non empty set',
31+
),
32+
pytest.param(
33+
'from collections import defaultdict\n\n'
34+
'defaultdict(lambda: [1]))\n'
35+
'defaultdict(lambda: list([1])))\n',
36+
id='non empty list',
37+
),
38+
pytest.param(
39+
'from collections import defaultdict\n\n'
40+
'defaultdict(lambda: {1: 2})\n',
41+
id='non empty dict, literal',
42+
),
43+
pytest.param(
44+
'from collections import defaultdict\n\n'
45+
'defaultdict(lambda: dict([(1,2),])))\n',
46+
id='non empty dict, call with args',
47+
),
48+
pytest.param(
49+
'from collections import defaultdict\n\n'
50+
'defaultdict(lambda: dict(a=[1]))\n',
51+
id='non empty dict, call with kwargs',
52+
),
53+
pytest.param(
54+
'from collections import defaultdict\n\n'
55+
'defaultdict(lambda: (1,))\n',
56+
id='non empty tuple, literal',
57+
),
58+
pytest.param(
59+
'from collections import defaultdict\n\n'
60+
'defaultdict(lambda: tuple([1]))\n',
61+
id='non empty tuple, calls with arg',
62+
),
63+
pytest.param(
64+
'from collections import defaultdict\n\n'
65+
'defaultdict(lambda: "AAA")\n'
66+
'defaultdict(lambda: \'BBB\')\n',
67+
id='non empty string',
68+
),
69+
pytest.param(
70+
'from collections import defaultdict\n\n'
71+
'defaultdict(lambda: 10)\n'
72+
'defaultdict(lambda: -2)\n',
73+
id='non zero integer',
74+
),
75+
pytest.param(
76+
'from collections import defaultdict\n\n'
77+
'defaultdict(lambda: 0.2)\n'
78+
'defaultdict(lambda: 0.00000001)\n'
79+
'defaultdict(lambda: -2.3)\n',
80+
id='non zero float',
81+
),
82+
pytest.param(
83+
'import collections\n'
84+
'collections.defaultdict(lambda: None)\n',
85+
id='lambda: None is not equivalent to defaultdict()',
86+
),
87+
),
88+
)
89+
def test_fix_noop(s):
90+
assert _fix_plugins(s, settings=Settings()) == s
91+
92+
93+
@pytest.mark.parametrize(
94+
('s', 'expected'),
95+
(
96+
pytest.param(
97+
'from collections import defaultdict\n\n'
98+
'defaultdict(lambda: set())\n',
99+
'from collections import defaultdict\n\n'
100+
'defaultdict(set)\n',
101+
id='call with attr, set()',
102+
),
103+
pytest.param(
104+
'from collections import defaultdict\n\n'
105+
'defaultdict(lambda: list())\n',
106+
'from collections import defaultdict\n\n'
107+
'defaultdict(list)\n',
108+
id='call with attr, list()',
109+
),
110+
pytest.param(
111+
'from collections import defaultdict\n\n'
112+
'defaultdict(lambda: dict())\n',
113+
'from collections import defaultdict\n\n'
114+
'defaultdict(dict)\n',
115+
id='call with attr, dict()',
116+
),
117+
pytest.param(
118+
'from collections import defaultdict\n\n'
119+
'defaultdict(lambda: tuple())\n',
120+
'from collections import defaultdict\n\n'
121+
'defaultdict(tuple)\n',
122+
id='call with attr, tuple()',
123+
),
124+
pytest.param(
125+
'from collections import defaultdict\n\n'
126+
'defaultdict(lambda: [])\n',
127+
'from collections import defaultdict\n\n'
128+
'defaultdict(list)\n',
129+
id='call with attr, []',
130+
),
131+
pytest.param(
132+
'from collections import defaultdict\n\n'
133+
'defaultdict(lambda: {})\n',
134+
'from collections import defaultdict\n\n'
135+
'defaultdict(dict)\n',
136+
id='call with attr, {}',
137+
),
138+
pytest.param(
139+
'from collections import defaultdict\n\n'
140+
'defaultdict(lambda: ())\n',
141+
'from collections import defaultdict\n\n'
142+
'defaultdict(tuple)\n',
143+
id='call with attr, ()',
144+
),
145+
pytest.param(
146+
'from collections import defaultdict\n\n'
147+
'defaultdict(lambda: "")\n',
148+
'from collections import defaultdict\n\n'
149+
'defaultdict(str)\n',
150+
id='call with attr, empty string (double quote)',
151+
),
152+
pytest.param(
153+
'from collections import defaultdict\n\n'
154+
'defaultdict(lambda: \'\')\n',
155+
'from collections import defaultdict\n\n'
156+
'defaultdict(str)\n',
157+
id='call with attr, empty string (single quote)',
158+
),
159+
pytest.param(
160+
'from collections import defaultdict\n\n'
161+
'defaultdict(lambda: 0)\n',
162+
'from collections import defaultdict\n\n'
163+
'defaultdict(int)\n',
164+
id='call with attr, int',
165+
),
166+
pytest.param(
167+
'from collections import defaultdict\n\n'
168+
'defaultdict(lambda: 0.0)\n',
169+
'from collections import defaultdict\n\n'
170+
'defaultdict(float)\n',
171+
id='call with attr, float',
172+
),
173+
pytest.param(
174+
'from collections import defaultdict\n\n'
175+
'defaultdict(lambda: 0.0000)\n',
176+
'from collections import defaultdict\n\n'
177+
'defaultdict(float)\n',
178+
id='call with attr, long float',
179+
),
180+
pytest.param(
181+
'from collections import defaultdict\n\n'
182+
'defaultdict(lambda: [], {1: []})\n',
183+
'from collections import defaultdict\n\n'
184+
'defaultdict(list, {1: []})\n',
185+
id='defauldict with kwargs',
186+
),
187+
pytest.param(
188+
'import collections\n\n'
189+
'collections.defaultdict(lambda: set())\n'
190+
'collections.defaultdict(lambda: list())\n'
191+
'collections.defaultdict(lambda: dict())\n'
192+
'collections.defaultdict(lambda: tuple())\n'
193+
'collections.defaultdict(lambda: [])\n'
194+
'collections.defaultdict(lambda: {})\n'
195+
'collections.defaultdict(lambda: "")\n'
196+
'collections.defaultdict(lambda: \'\')\n'
197+
'collections.defaultdict(lambda: 0)\n'
198+
'collections.defaultdict(lambda: 0.0)\n'
199+
'collections.defaultdict(lambda: 0.00000)\n'
200+
'collections.defaultdict(lambda: 0j)\n',
201+
'import collections\n\n'
202+
'collections.defaultdict(set)\n'
203+
'collections.defaultdict(list)\n'
204+
'collections.defaultdict(dict)\n'
205+
'collections.defaultdict(tuple)\n'
206+
'collections.defaultdict(list)\n'
207+
'collections.defaultdict(dict)\n'
208+
'collections.defaultdict(str)\n'
209+
'collections.defaultdict(str)\n'
210+
'collections.defaultdict(int)\n'
211+
'collections.defaultdict(float)\n'
212+
'collections.defaultdict(float)\n'
213+
'collections.defaultdict(complex)\n',
214+
id='call with attr',
215+
),
216+
),
217+
)
218+
def test_fix_defaultdict(s, expected):
219+
ret = _fix_plugins(s, settings=Settings())
220+
assert ret == expected

0 commit comments

Comments
 (0)