Skip to content

Commit

Permalink
Merge pull request #62 from iris-hep/feature_48_concat
Browse files Browse the repository at this point in the history
add `Concat` operator
  • Loading branch information
masonproffitt authored Oct 18, 2021
2 parents af83e4b + 46a086e commit 8a7e9d9
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ All defined s-expressions are listed here, though this specification will be exp
- `predicate` must be a `lambda` with one argument
- Any: `(Any <source> <predicate>)`
- `predicate` must be a `lambda` with one argument
- Concat: `(Concat <first> <second>)`
- Zip: `(Zip <source>)`
- OrderBy: `(OrderBy <source> <key_selector>)`
- `key_selector` must be a `lambda` with one argument
Expand Down
9 changes: 9 additions & 0 deletions qastle/linq_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class Any(ast.AST):
_fields = ['source', 'predicate']


class Concat(ast.AST):
_fields = ['first', 'second']


class Zip(ast.AST):
_fields = ['source']

Expand Down Expand Up @@ -84,6 +88,7 @@ class Choose(ast.AST):
'Sum',
'All',
'Any',
'Concat',
'Zip',
'OrderBy',
'OrderByDescending',
Expand Down Expand Up @@ -193,6 +198,10 @@ def visit_Call(self, node):
raise SyntaxError('Any() call argument must be a lambda')
return Any(source=self.visit(source),
predicate=self.visit(args[0]))
elif function_name == 'Concat':
if len(args) != 1:
raise SyntaxError('Concat() call must have exactly one argument')
return Concat(first=self.visit(source), second=self.visit(args[0]))
elif function_name == 'Zip':
if len(args) != 0:
raise SyntaxError('Zip() call must have zero arguments')
Expand Down
12 changes: 11 additions & 1 deletion qastle/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .linq_util import (Where, Select, SelectMany, First, Last, ElementAt, Aggregate, Count, Max,
Min, Sum, All, Any, Zip, OrderBy, OrderByDescending, Choose)
Min, Sum, All, Any, Concat, Zip, OrderBy, OrderByDescending, Choose)
from .ast_util import wrap_ast, unwrap_ast

import lark
Expand Down Expand Up @@ -219,6 +219,11 @@ def visit_Any(self, node):
self.visit(node.source),
self.visit(node.predicate))

def visit_Concat(self, node):
return self.make_composite_node_string('Concat',
self.visit(node.first),
self.visit(node.second))

def visit_Zip(self, node):
return self.make_composite_node_string('Zip', self.visit(node.source))

Expand Down Expand Up @@ -499,6 +504,11 @@ def composite(self, children):
+ str(len(fields[1].args.args)))
return Any(source=fields[0], predicate=fields[1])

elif node_type == 'Concat':
if len(fields) != 2:
raise SyntaxError('Concat node must have two fields; found ' + str(len(fields)))
return Concat(first=fields[0], second=fields[1])

elif node_type == 'Zip':
if len(fields) != 1:
raise SyntaxError('Zip node must have one field; found ' + str(len(fields)))
Expand Down
7 changes: 7 additions & 0 deletions tests/test_ast_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ def test_Any():
'(Any data_source (lambda (list e) e))')


def test_Concat():
first_node = Concat(first=unwrap_ast(ast.parse('sequence1')),
second=unwrap_ast(ast.parse('sequence2')))
assert_equivalent_python_ast_and_text_ast(wrap_ast(first_node),
'(Concat sequence1 sequence2)')


def test_Zip():
first_node = Zip(source=unwrap_ast(ast.parse('data_source')))
assert_equivalent_python_ast_and_text_ast(wrap_ast(first_node),
Expand Down
21 changes: 21 additions & 0 deletions tests/test_linq_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,27 @@ def test_any_bad():
insert_linq_nodes(ast.parse('the_source.Any(None)'))


def test_concat():
initial_ast = ast.parse("sequence1.Concat(sequence2)")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Concat(first=unwrap_ast(ast.parse('sequence1')),
second=unwrap_ast(ast.parse('sequence2'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_concat_composite():
initial_ast = ast.parse("the_source.First().Concat(sequence)")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Concat(first=First(source=unwrap_ast(ast.parse('the_source'))),
second=unwrap_ast(ast.parse('sequence'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_concat_bad():
with pytest.raises(SyntaxError):
insert_linq_nodes(ast.parse('the_source.Concat()'))


def test_zip():
initial_ast = ast.parse("the_source.Zip()")
final_ast = insert_linq_nodes(initial_ast)
Expand Down

0 comments on commit 8a7e9d9

Please sign in to comment.