Skip to content

Commit 679c45d

Browse files
authored
Support dict unpacking with conditional expression (#199)
* Resolve conditional dict expansions when test constant * Fix up type and flake8 errors * Fix up flake8 tests * Update agents for better testing * Improve conditional dict error msg and add nested ifexp test (#200) * Remove unneeded recursion
1 parent 2606c3c commit 679c45d

File tree

5 files changed

+136
-9
lines changed

5 files changed

+136
-9
lines changed

AGENTS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22

33
* Prefer using `Optional[int]` rather than `int | None`.
44
* If there is an error the user will see (e.g. `ValueError`), make sure there is enough context in the message for the user. For example, if it is during an expression parse, include the `ast.unparse(a)` in the error message.
5+
* Before finishing, make sure `flake8` runs without errors on source and test files.
6+
* Before finishing, also make sure `black` runs without modifying files.

func_adl/ast/syntatic_sugar.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,60 @@ def convert_call_to_dict(
173173
values=arg_values,
174174
)
175175

176+
def _merge_into(self, target: ast.expr, add: ast.Dict) -> ast.expr:
177+
"""Merge ``add`` dictionary into ``target`` which may itself be a
178+
dictionary or an if-expression containing dictionaries."""
179+
180+
if isinstance(target, ast.Dict):
181+
return ast.Dict(keys=target.keys + add.keys, values=target.values + add.values)
182+
else:
183+
return target
184+
185+
def visit_Dict(self, node: ast.Dict) -> Any:
186+
"""Flatten ``**`` expansions in dictionary literals.
187+
188+
If the starred expression is a dictionary it is merged directly. If
189+
the expression is an ``if`` with both branches being dictionaries and
190+
the test is a constant, it is resolved at transformation time. If
191+
the test is not resolvable, an error is raised as the back end cannot
192+
translate it."""
193+
194+
a = self.generic_visit(node)
195+
assert isinstance(a, ast.Dict)
196+
197+
base_keys: List[Optional[ast.expr]] = []
198+
base_values: List[ast.expr] = []
199+
expansions: List[ast.expr] = []
200+
for k, v in zip(a.keys, a.values):
201+
if k is None:
202+
expansions.append(v)
203+
else:
204+
base_keys.append(k)
205+
base_values.append(v)
206+
207+
result: ast.AST = ast.Dict(keys=base_keys, values=base_values)
208+
209+
for e in expansions:
210+
if isinstance(e, ast.Dict):
211+
result = self._merge_into(result, e)
212+
elif (
213+
isinstance(e, ast.IfExp)
214+
and isinstance(e.body, ast.Dict)
215+
and isinstance(e.orelse, ast.Dict)
216+
):
217+
if isinstance(e.test, ast.Constant):
218+
branch = e.body if bool(e.test.value) else e.orelse
219+
result = self._merge_into(result, branch)
220+
else:
221+
raise ValueError(
222+
"Conditional dictionary expansion requires a constant test"
223+
f" - {ast.unparse(e)}"
224+
)
225+
else:
226+
return a
227+
228+
return result
229+
176230
def visit_Call(self, node: ast.Call) -> Any:
177231
"""
178232
This method checks if the call is to a dataclass or a named tuple and converts

tests/ast/test_meta_data.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,9 @@ def test_query_metadata_burried():
8989

9090

9191
def test_query_metadata_updated():
92-
'''This is testing code in QMetaData, but we need lookup_query_metadata which we are
93-
testing in this file'''
94-
r = (
95-
my_event()
96-
.QMetaData({"one": "two"})
97-
.QMetaData({"one": "three"})
98-
)
92+
"""This is testing code in QMetaData, but we need lookup_query_metadata which we are
93+
testing in this file"""
94+
r = my_event().QMetaData({"one": "two"}).QMetaData({"one": "three"})
9995

10096
assert lookup_query_metadata(r, "one") == "three"
10197

tests/ast/test_syntatic_sugar.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,78 @@ def test_resolve_compare_list_wrong_order():
282282
a = ast.parse("[31, 51] in p.absPdgId()")
283283
with pytest.raises(ValueError, match="Right side"):
284284
resolve_syntatic_sugar(a)
285+
286+
287+
def test_resolve_dict_star_merge():
288+
"""Dictionary unpacking should be flattened"""
289+
290+
a = ast.parse("{'n': e.EventNumber(), **{'m': e.EventNumber()}}").body[0].value # type: ignore
291+
a_resolved = resolve_syntatic_sugar(a)
292+
293+
expected = (
294+
ast.parse("{'n': e.EventNumber(), 'm': e.EventNumber()}").body[0].value # type: ignore
295+
)
296+
assert ast.unparse(a_resolved) == ast.unparse(expected)
297+
298+
299+
def test_resolve_dict_star_ifexp_true():
300+
"""Conditional dictionary unpacking should resolve when condition is True"""
301+
302+
a = (
303+
ast.parse("{'n': e.EventNumber(), **({'m': e.EventNumber()} if True else {})}")
304+
.body[0]
305+
.value # type: ignore
306+
)
307+
a_resolved = resolve_syntatic_sugar(a)
308+
309+
expected = (
310+
ast.parse("{'n': e.EventNumber(), 'm': e.EventNumber()}").body[0].value # type: ignore
311+
)
312+
assert ast.unparse(a_resolved) == ast.unparse(expected)
313+
314+
315+
def test_resolve_dict_star_ifexp_false():
316+
"""Conditional dictionary unpacking should resolve when condition is False"""
317+
318+
a = (
319+
ast.parse("{'n': e.EventNumber(), **({'m': e.EventNumber()} if False else {})}")
320+
.body[0]
321+
.value # type: ignore
322+
)
323+
a_resolved = resolve_syntatic_sugar(a)
324+
325+
expected = ast.parse("{'n': e.EventNumber()}").body[0].value # type: ignore
326+
assert ast.unparse(a_resolved) == ast.unparse(expected)
327+
328+
329+
def test_resolve_dict_star_ifexp_unknown():
330+
"""Unresolvable conditions should result in an error"""
331+
332+
a = (
333+
ast.parse("{'n': e.EventNumber(), **({'m': e.EventNumber()} if cond else {})}")
334+
.body[0]
335+
.value # type: ignore
336+
)
337+
with pytest.raises(ValueError, match="Conditional dictionary"):
338+
resolve_syntatic_sugar(a)
339+
340+
341+
def test_resolve_dict_star_ifexp_nested():
342+
"""Nested conditional dictionary unpacking should resolve correctly"""
343+
344+
a = (
345+
ast.parse(
346+
"{'n': e.EventNumber(), **({'m': e.EventNumber(), **({'o': e.EventNumber()} "
347+
"if True else {})} if True else {})}"
348+
)
349+
.body[0]
350+
.value # type: ignore
351+
)
352+
a_resolved = resolve_syntatic_sugar(a)
353+
354+
expected = (
355+
ast.parse("{'n': e.EventNumber(), 'm': e.EventNumber(), 'o': e.EventNumber()}")
356+
.body[0]
357+
.value # type: ignore
358+
)
359+
assert ast.unparse(a_resolved) == ast.unparse(expected)

tests/test_util_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def test_get_method_and_class_inherrited_template():
314314

315315
class bogus_1(Generic[T]):
316316
def fork(self) -> T:
317-
...
317+
pass
318318

319319
class bogus_2(bogus_1[int]):
320320
pass
@@ -325,7 +325,7 @@ class bogus_2(bogus_1[int]):
325325
def test_get_method_and_class_iterable():
326326
class bogus:
327327
def fork(self):
328-
...
328+
pass
329329

330330
assert get_method_and_class(Iterable[bogus], "fork") is None
331331

0 commit comments

Comments
 (0)