Skip to content

Commit 962ce03

Browse files
author
Amit Beka
committed
core: add iteration to :class:Node
Iteration can be in DFS/BFS order, with/out duplicates, and on Nodes/Edges. Signed-off-by: Amit Beka <[email protected]>
1 parent 56caf1d commit 962ce03

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

core.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,49 @@ def equals(self, other, *, recursive=True, ordered=False):
523523
return False
524524
return True
525525

526+
def iter(self, obj="nodes", method="dfs", duplicates=False, key=None):
527+
"""Iterates the :class:Node objects in the subtree of self.
528+
529+
Args:
530+
obj: yield Node objects (use value "nodes", default) or Edge
531+
objects (use values "edges")
532+
method: do breadth-first iteration (use value "bfs") or depth-dirst
533+
iteration (value "dfs", default).
534+
duplicates: If True, may return the same object twice if it is
535+
encountered twice, because of the DAG structure which isn't
536+
necessarily a tree. If it is False, all objects will be yielded
537+
only the first time they are encountered. Defaults to False.
538+
key: boolean function that filters the iterable items. key function
539+
takes one argument (the item) and returns True if it should be
540+
returned to the user. If an item isn't returned, its subtree
541+
is still iterated. Defaults to None (returns all items).
542+
543+
Yields:
544+
a :class:Node or :class:Edge object according to the iteration
545+
parameters.
546+
547+
"""
548+
if method not in ("dfs", "bfs"):
549+
raise ValueError("method can be either 'dfs' or 'bfs'")
550+
if obj not in ("nodes", "edges"):
551+
raise ValueError("obj can be either 'nodes' or 'edges'")
552+
processed = set()
553+
if obj == 'nodes':
554+
waiting = [self]
555+
else:
556+
waiting = self._outgoing[:]
557+
while len(waiting):
558+
curr = waiting.pop(0)
559+
if key is None or key(curr):
560+
yield curr
561+
processed.add(curr)
562+
to_add = curr.children if obj == 'nodes' else list(curr.child)
563+
to_add = [x for x in to_add if duplicates or x not in processed]
564+
if method == "bfs":
565+
waiting.extend(to_add)
566+
else:
567+
waiting = to_add + waiting
568+
526569

527570
class Layer:
528571
"""Group of similar :class:Node objects in UCCA annotation graph.

tests/test_ucca_ut.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,28 @@ def test_copying(self):
204204
p2 = p1.copy([l0id])
205205
self.assertTrue(p1.layer(l0id).equals(p2.layer(l0id)))
206206

207+
def test_iteration(self):
208+
p = self._create_basic_passage()
209+
l1, l2 = p.layer('1'), p.layer('2')
210+
node11, node12, node13 = l1.all
211+
node22, node21 = l2.all
212+
213+
self.assertSequenceEqual(list(node11.iter()), [node11])
214+
self.assertSequenceEqual(list(node11.iter(obj='edges')), [])
215+
self.assertSequenceEqual(list(node13.iter(key=lambda x: x.tag != '3')),
216+
[])
217+
self.assertSequenceEqual(list(node12.iter()), [node12, node13, node11])
218+
self.assertSequenceEqual(list(x.ID for x in node12.iter(obj='edges')),
219+
['1.2->1.3', '1.2->1.1'])
220+
self.assertSequenceEqual(list(node21.iter(duplicates=True)),
221+
[node21, node11, node12, node13, node11])
222+
self.assertSequenceEqual(list(node21.iter()),
223+
[node21, node11, node12, node13])
224+
self.assertSequenceEqual(list(node22.iter(method='bfs',
225+
duplicates=True)),
226+
[node22, node11, node12, node13, node13,
227+
node11])
228+
207229

208230
class Layer0Tests(unittest.TestCase):
209231
"""Tests module layer0 functionality."""

0 commit comments

Comments
 (0)