Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions func_adl_xAOD/common/ast_to_cpp_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _extract_column_names(names_ast: ast.AST) -> List[str]:
return names


def find_fill_scope(a: ast.AST) -> ast.AST:
def find_fill_scope(a: ast.AST) -> ast.expr:
"""Returns the ast of the item where we would want to put the fill
statement for building an ntuple. Does this by walking the tree and finding
the lowest mainline term (e.g. event level scoping).
Expand All @@ -169,14 +169,21 @@ def fill_node(self) -> ast.AST:
return self._node

def visit_Call(self, node: ast.Call):
if isinstance(node.func, ast.Name):
if self._node is None and node.func.id in ["Where", "EventDataset", "SelectMany"]:
if self._node is None:
if isinstance(node.func, ast.Name):
if node.func.id in [
"Where",
"EventDataset",
"SelectMany",
]:
self._node = node
if isinstance(node.func, cpp_ast.CPPCodeValue):
self._node = node
self.generic_visit(node)
self.generic_visit(node)

finder = find_where()
finder.visit(a)
return finder.fill_node
return cast(ast.expr, finder.fill_node)


class query_ast_visitor(FuncADLNodeVisitor, ABC):
Expand Down Expand Up @@ -302,7 +309,7 @@ def get_as_ROOT(self, node: ast.AST) -> rh.cpp_ttree_rep:
col_values = values.value_dict.values()
col_names = ast.List(elts=list(values.value_dict.keys()))
s_tuple = crep.cpp_tuple(tuple(col_values), values.scope())
tuple_sequence = crep.cpp_sequence(s_tuple, r.iterator_value(), r.scope()) # type: ignore
tuple_sequence = crep.cpp_sequence(s_tuple, r.iterator_value(), r.scope(), node) # type: ignore
crep.set_rep(node, tuple_sequence)
assert isinstance(node, ast.expr) # making sure that types are correct.
ast_ttree = function_call(
Expand Down Expand Up @@ -383,7 +390,7 @@ def resolve_id(self, id: str) -> Optional[ast.AST]:
return None

def make_sequence_from_collection(
self, rep: crep.cpp_collection
self, rep: crep.cpp_collection, node: ast.expr
) -> crep.cpp_sequence:
"""
Take a collection and produce a sequence. Eventually this should likely be some sort of
Expand All @@ -404,10 +411,10 @@ def make_sequence_from_collection(

# For a new sequence like this the sequence and iterator value are the same
return crep.cpp_sequence(
iterator_value, iterator_value, self._gc.current_scope()
iterator_value, iterator_value, self._gc.current_scope(), node
)

def as_sequence(self, generation_ast: ast.AST):
def as_sequence(self, generation_ast: ast.expr):
r"""
We will convert the generation_ast into a sequence if we can. If we can't, that indicates
a likely programming error by this library or by the user.
Expand All @@ -434,7 +441,7 @@ def as_sequence(self, generation_ast: ast.AST):

# If this is a collection, then we need to turn it into a sequence.
if isinstance(rep, crep.cpp_collection):
r = self.make_sequence_from_collection(rep)
r = self.make_sequence_from_collection(rep, generation_ast)
self._gc.set_rep(rep, r)
return r

Expand Down Expand Up @@ -1139,7 +1146,8 @@ def fill_collection_levels(
inner = seq.sequence_value()
scope = seq.scope()
if isinstance(inner, crep.cpp_sequence):
scope = seq.iterator_value().scope()
scope = self.as_sequence(find_fill_scope(seq.node())).scope()
# scope = seq.iterator_value().scope()
storage = crep.cpp_variable(
unique_name("ntuple"), scope, cpp_type=inner.cpp_type()
)
Expand Down Expand Up @@ -1177,7 +1185,7 @@ def call_ResultTTree(self, node: ast.Call, args: List[ast.AST]):
"""This AST means we are taking an iterable and converting it to a ROOT file."""
# Unpack the variables.
assert len(args) == 4
source = args[0]
source = cast(ast.expr, args[0])
column_names = _extract_column_names(args[1])
tree_name = ast.literal_eval(args[2])
assert isinstance(tree_name, str)
Expand Down Expand Up @@ -1267,7 +1275,7 @@ def call_Select(self, node: ast.Call, args: List[ast.arg]):
assert isinstance(selection, ast.Lambda)

# Make sure we are in a loop
seq = self.as_sequence(source)
seq = self.as_sequence(source) # type: ignore

# Simulate this as a "call"
c = ast.Call(
Expand All @@ -1279,7 +1287,7 @@ def call_Select(self, node: ast.Call, args: List[ast.arg]):

# We need to build a new sequence.
rep = crep.cpp_sequence(
new_sequence_value, seq.iterator_value(), self._gc.current_scope()
new_sequence_value, seq.iterator_value(), self._gc.current_scope(), node
)

crep.set_rep(node, rep)
Expand All @@ -1291,7 +1299,7 @@ def call_SelectMany(self, node: ast.AST, args: List[ast.AST]):
loop over that collection.
"""
assert len(args) == 2
source = args[0]
source = cast(ast.expr, args[0])
selection = args[1]
assert isinstance(selection, ast.Lambda)

Expand All @@ -1316,11 +1324,11 @@ def call_SelectMany(self, node: ast.AST, args: List[ast.AST]):
crep.set_rep(node, seq)
return seq

def call_Where(self, node: ast.AST, args: List[ast.AST]):
def call_Where(self, node: ast.expr, args: List[ast.AST]):
"Apply a filtering to the current loop."

assert len(args) == 2
source = args[0]
source = cast(ast.expr, args[0])
filter = args[1]
assert isinstance(filter, ast.Lambda)

Expand Down Expand Up @@ -1350,7 +1358,7 @@ def call_Where(self, node: ast.AST, args: List[ast.AST]):
crep.set_rep(
node,
crep.cpp_sequence(
new_sequence_var, seq.iterator_value(), self._gc.current_scope()
new_sequence_var, seq.iterator_value(), self._gc.current_scope(), node
),
)

Expand Down Expand Up @@ -1412,7 +1420,7 @@ def call_Range(self, node: ast.Call, args: List[ast.AST]):

self._gc.add_statement(statement.arbitrary_statement(self.get_rep(c).as_cpp())) # type: ignore

seq = self.make_sequence_from_collection(vector_value)
seq = self.make_sequence_from_collection(vector_value, node)
crep.set_rep(node, seq)
return seq

Expand All @@ -1421,7 +1429,7 @@ def call_First(self, node: ast.AST, args: List[ast.AST]) -> Any:

# Unpack the source here
assert len(args) == 1
source = args[0]
source = cast(ast.expr, args[0])

# Make sure we are in a loop.
seq = self.as_sequence(source)
Expand Down
7 changes: 7 additions & 0 deletions func_adl_xAOD/common/cpp_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def __init__(
sequence_value: Union[cpp_value, cpp_sequence],
iterator_value: cpp_value,
scope: Union[gc_scope_top_level, gc_scope],
node: ast.expr,
):
"""
Create a sequence
Expand All @@ -320,12 +321,14 @@ def __init__(
inside the if statement of the Where, while the iterator will be at the scope of
the original declaration. If the `sequence_value` is an actual value, it will have
the same scope.
node: The AST node associated with this sequence.
"""
cpp_rep_base.__init__(self)
self._sequence = sequence_value
self._iterator = iterator_value
self._type: Optional[ctyp.collection] = None
self._scope = scope
self._node = node

def sequence_value(self):
return self._sequence
Expand All @@ -345,6 +348,10 @@ def scope(self) -> Union[gc_scope, gc_scope_top_level]:
"Return scope where this sequence was created/valid"
return self._scope

def node(self) -> ast.expr:
"Node in ast tree where this sequence is defined."
return self._node


def set_rep(node: ast.AST, value: cpp_rep_base, scope: Optional[Any] = None):
"""
Expand Down
4 changes: 3 additions & 1 deletion func_adl_xAOD/common/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,9 @@ def write_cpp_files(self, ast: ast.AST, output_path: Path) -> ExecutionInfo:
iterator = crep.cpp_variable(
"bogus-do-not-use", top_level_scope(), cpp_type=None
)
crep.set_rep(file, crep.cpp_sequence(iterator, iterator, top_level_scope()))
crep.set_rep(
file, crep.cpp_sequence(iterator, iterator, top_level_scope(), file)
)

# Visit the AST to generate the code structure and find out what the
# result is going to be.
Expand Down
6 changes: 3 additions & 3 deletions tests/atlas/r22_xaod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def exe_from_qastle(q: str):

file = find_EventDataset(a)
iterator = cpp_variable("bogus-do-not-use", top_level_scope(), cpp_type=None)
set_rep(file, cpp_sequence(iterator, iterator, top_level_scope()))
set_rep(file, cpp_sequence(iterator, iterator, top_level_scope(), file))

# Use the dummy executor to process this, and return it.
exe = atlas_xaodr22_dummy_executor()
Expand All @@ -83,7 +83,7 @@ def load_root_as_pandas(file: Path) -> pd.DataFrame:
assert isinstance(file, Path)
assert file.exists()

with uproot.open(file) as input:
with uproot.open(file) as input: # type: ignore
return input["atlas_xaod_tree"].arrays(library="pd") # type: ignore


Expand All @@ -100,7 +100,7 @@ def load_root_as_awkward(file: Path) -> ak.Array:
assert isinstance(file, Path)
assert file.exists()

with uproot.open(file) as input:
with uproot.open(file) as input: # type: ignore
return input["atlas_xaod_tree"].arrays() # type: ignore


Expand Down
6 changes: 4 additions & 2 deletions tests/atlas/xaod/test_first_last.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Code to do the testing starts here.
import func_adl_xAOD.common.cpp_types as ctyp
from tests.utils.locators import find_line_numbers_with, find_line_with, find_next_closing_bracket, find_open_blocks # type: ignore
from tests.utils.locators import find_line_numbers_with, find_line_with, find_next_closing_bracket # type: ignore
from tests.utils.general import get_lines_of_code, print_lines # type: ignore
from tests.atlas.xaod.utils import atlas_xaod_dataset # type: ignore
import re
Expand Down Expand Up @@ -90,7 +90,9 @@ def test_First_Of_Select_is_not_array():
l_first_test = l_first_tests[1]

# Ensure the indent columns in lines[l_fill] and lines[l_first_test] are the same
assert lines[l_fill].startswith(" " * (len(lines[l_first_test]) - len(lines[l_first_test].lstrip())))
assert lines[l_fill].startswith(
" " * (len(lines[l_first_test]) - len(lines[l_first_test].lstrip()))
)


def test_First_Of_Select_After_Where_is_in_right_place():
Expand Down
3 changes: 3 additions & 0 deletions tests/atlas/xaod/test_query_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_as_root_as_dict():
dict_obj, # type: ignore
crep.cpp_value("i", gc_scope_top_level(), ctyp.terminal("int")),
gc_scope_top_level(),
node,
)
node.rep = sequence # type: ignore
as_root = q.get_as_ROOT(node)
Expand All @@ -181,6 +182,7 @@ def test_as_root_as_single_column():
value_obj,
crep.cpp_value("i", gc_scope_top_level(), ctyp.terminal("int")),
gc_scope_top_level(),
node,
)
node.rep = sequence # type: ignore
as_root = q.get_as_ROOT(node)
Expand All @@ -200,6 +202,7 @@ def test_as_root_as_tuple():
value_obj, # type: ignore
crep.cpp_value("i", gc_scope_top_level(), ctyp.terminal("int")),
gc_scope_top_level(),
node,
)
node.rep = sequence # type: ignore
as_root = q.get_as_ROOT(node)
Expand Down
24 changes: 23 additions & 1 deletion tests/atlas/xaod/test_xaod_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ def test_where_at_top_level_First_and_count():
print_lines(lines)

# Make sure we are grabbing the jet container and the fill at the same indent level.
i_pt_test = find_line_with(">1001.0", lines)
i_fill = find_line_with("->Fill()", lines)
assert 4 == len(lines[i_fill]) - len(lines[i_fill].lstrip())

Expand Down Expand Up @@ -438,6 +437,29 @@ def test_per_jet_item_with_where():
assert "Fill()" in lines[l_jet_pt + 1]


def test_where_in_sub_select():
"We have an if statement buried in a loop - make sure push_back is done right"
r = (
atlas_xaod_dataset()
.Select(
lambda e: [
[t.pt for t in e.Tracks("hi")] for j in e.Jets("hi") if j.pt() > 10
]
)
.value()
)

lines = get_lines_of_code(r)
print_lines(lines)

# Make sure we are grabbing the jet container and the fill at the same indent level.
i_if = find_line_with("pt()>10", lines)
i_if_indent = len(lines[i_if]) - len(lines[i_if].lstrip())
i_fill = find_line_with(".push_back(ntuple", lines)
i_fill_indent = len(lines[i_fill]) - len(lines[i_fill].lstrip())
assert i_if_indent < i_fill_indent


def test_and_clause_in_where():
# The following statement should be a straight sequence, not an array.
r = (
Expand Down
2 changes: 1 addition & 1 deletion tests/atlas/xaod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def exe_from_qastle(q: str):

file = find_EventDataset(a)
iterator = cpp_variable("bogus-do-not-use", top_level_scope(), cpp_type=None)
set_rep(file, cpp_sequence(iterator, iterator, top_level_scope()))
set_rep(file, cpp_sequence(iterator, iterator, top_level_scope(), file))

# Use the dummy executor to process this, and return it.
exe = atlas_xaod_dummy_executor()
Expand Down
Loading
Loading