Skip to content

Commit 73a0289

Browse files
committed
Fix bug when overriding metadata
1 parent cb0f648 commit 73a0289

File tree

4 files changed

+45
-6
lines changed

4 files changed

+45
-6
lines changed

func_adl/ast/meta_data.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def extract_metadata(a: ast.AST) -> Tuple[ast.AST, List[Dict[str, str]]]:
5959

6060
def remove_empty_metadata(a: ast.AST) -> ast.AST:
6161
"""Returns a new ast with any empty `MetaData` clauses removed.
62+
The old ast is not modified.
6263
6364
Args:
6465
a (ast.AST): The AST of the query to clean up.
@@ -75,7 +76,7 @@ def visit_Call(self, node: ast.Call):
7576
if len(n.args) == 2:
7677
d = ast.literal_eval(n.args[1])
7778
if isinstance(d, dict) and len(d) == 0:
78-
return self.visit(n.args[0])
79+
return n.args[0]
7980
return n
8081

8182
return _cleaner().visit(a)
@@ -103,10 +104,14 @@ def found(self) -> Optional[Any]:
103104

104105
def generic_visit(self, node: ast.AST):
105106
q_metadata = getattr(node, "_q_metadata", None)
107+
found = False
106108
if q_metadata is not None:
107109
if metadata_name in q_metadata:
110+
found = True
108111
self._found = q_metadata[metadata_name]
109-
return super().generic_visit(node)
112+
113+
if not found:
114+
super().generic_visit(node)
110115

111116
ds_f = _finder()
112117
ds_f.visit(q.query_ast)

func_adl/object_stream.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,16 +199,20 @@ def QMetaData(self, metadata: Dict[str, Any]) -> ObjectStream[T]:
199199
base_ast = self.query_ast
200200
for k, v in metadata.items():
201201
found_md = lookup_query_metadata(self, k)
202+
add_md = False
202203
if found_md is None:
204+
add_md = True
205+
elif found_md != v:
206+
logging.getLogger(__name__).info(
207+
f'Overwriting metadata "{k}" from its old value of "{found_md}" to "{v}"'
208+
)
209+
add_md = True
210+
if add_md:
203211
if first:
204212
first = False
205213
base_ast = self.MetaData({}).query_ast
206214
base_ast._q_metadata = {} # type: ignore
207215
base_ast._q_metadata[k] = v # type: ignore
208-
elif found_md != v:
209-
logging.getLogger(__name__).warning(
210-
f'Overwriting metadata "{k}" from its old value of "{found_md}" to "{v}"'
211-
)
212216

213217
return ObjectStream[T](base_ast, self.item_type)
214218

tests/ast/test_meta_data.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ def test_query_metadata_burried():
8888
assert lookup_query_metadata(r, "three") == "forks"
8989

9090

91+
def test_query_metadata_updated():
92+
'This is testing code in QMetaData, but we need lookup_query_metadata which we are testing in this file'
93+
r = (
94+
my_event()
95+
.QMetaData({"one": "two"})
96+
.QMetaData({"one": "three"})
97+
)
98+
99+
assert lookup_query_metadata(r, "one") == "three"
100+
101+
91102
def test_remove_empty_metadata_empty():
92103
r = remove_empty_metadata(my_event().MetaData({}).value())
93104
assert "MetaData" not in ast.dump(r)
@@ -96,3 +107,19 @@ def test_remove_empty_metadata_empty():
96107
def test_remove_empty_metadata_not_empty():
97108
r = remove_empty_metadata(my_event().MetaData({"hi": "there"}).value())
98109
assert "MetaData" in ast.dump(r)
110+
111+
112+
def test_remove_metadata_no_change():
113+
"Removing metadata should not alter original query"
114+
orig = my_event().MetaData({})
115+
remove_empty_metadata(orig.query_ast)
116+
117+
assert "MetaData" in ast.dump(orig.query_ast)
118+
119+
120+
def test_remove_metadata_no_change_2_levels():
121+
"Removing metadata should not alter original query"
122+
orig = my_event().MetaData({}).MetaData({})
123+
remove_empty_metadata(orig.query_ast)
124+
125+
assert "MetaData" in ast.dump(orig.query_ast)

tests/test_object_stream.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Test the object stream
22
import ast
33
import asyncio
4+
import logging
45
from typing import Any, Iterable, Optional, Tuple, TypeVar
56

67
import pytest
@@ -173,6 +174,7 @@ def test_query_metadata_dup(caplog):
173174

174175

175176
def test_query_metadata_dup_update(caplog):
177+
caplog.set_level(logging.INFO)
176178
r = (
177179
my_event()
178180
.QMetaData({"one": "two", "two": "three"})
@@ -189,6 +191,7 @@ def test_query_metadata_dup_update(caplog):
189191

190192

191193
def test_query_metadata_composable(caplog):
194+
caplog.set_level(logging.INFO)
192195
r_base = my_event().QMetaData({"one": "1"})
193196

194197
# Each of these is a different base and should not interfear.

0 commit comments

Comments
 (0)