Skip to content

Commit

Permalink
Add a method which recursively looks through a tree for two nodes in …
Browse files Browse the repository at this point in the history
…one tree which combine to form one larger node in the other tree. This will potentially be a good source of augmentation for contrastive learning
  • Loading branch information
AngledLuffa committed Jan 28, 2025
1 parent 203a60c commit bb2b3f9
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
76 changes: 76 additions & 0 deletions stanza/models/constituency/parse_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,3 +605,79 @@ def _mark_spans(self, start_index):
start_index = child.end_index

self.end_index = start_index

@staticmethod
def single_missing_node_errors(original, predicted):
"""
Given the correct tree and the predicted tree, returns a list of single node missing / added errors.
The return format will be:
(outer label, missing/added label, left child label, right child label, T/F whether the node is supposed to be there)
Interestingly the nodes themselves aren't needed, at least not yet
Operates by recursively going through both trees at the same time.
If the span lengths of a pair of nodes are the same, it recurses on that pair of nodes.
Otherwise, if two consecutive spans from one tree add up to one span of the other tree, this is a candidate.
This candidate is accepted if the other tree's larger span is a node with exactly two children,
matching the labels of the original tree's children, with the spans the same as well.
Note that this does not guarantee the internal structure of those spans is the same.
"""
original.mark_spans()
predicted.mark_spans()

def check_missing_error(separate_tree, combined_tree, separate_idx, combined_idx, should_nest):
if separate_tree.label != combined_tree.label:
return None
if len(combined_tree.children) != 2:
return False
if (separate_tree.children[separate_idx].start_index == combined_tree.children[combined_idx].children[0].start_index and
separate_tree.children[separate_idx].end_index == combined_tree.children[combined_idx].children[0].end_index and
separate_tree.children[separate_idx].label == combined_tree.children[combined_idx].children[0].label and
separate_tree.children[separate_idx+1].start_index == combined_tree.children[combined_idx].children[-1].start_index and
separate_tree.children[separate_idx+1].end_index == combined_tree.children[combined_idx].children[-1].end_index and
separate_tree.children[separate_idx+1].label == combined_tree.children[combined_idx].children[-1].label):
return (combined_tree.label, combined_tree.children[combined_idx].label,
combined_tree.children[combined_idx].children[0].label, combined_tree.children[combined_idx].children[-1].label,
should_nest)

errors = []
def missing_node_helper(original, predicted):
#print("Checking: %s %s" % (original, predicted))
if original.is_preterminal() or predicted.is_preterminal():
return

orig_idx = 0
pred_idx = 0
while orig_idx < len(original.children) and pred_idx < len(predicted.children):
#print(original.children[orig_idx].start_index, original.children[orig_idx].end_index,
# predicted.children[pred_idx].start_index, predicted.children[pred_idx].end_index)
if original.children[orig_idx].start_index < predicted.children[pred_idx].start_index:
orig_idx += 1
continue
if original.children[orig_idx].start_index > predicted.children[pred_idx].start_index:
pred_idx += 1
continue
# the start indices are the same
# first thing to check: if the end indices are the same, can recurse
if original.children[orig_idx].end_index == predicted.children[pred_idx].end_index:
missing_node_helper(original.children[orig_idx], predicted.children[pred_idx])
orig_idx += 1
pred_idx += 1
continue
# in this case, one of the end indices is lower. there could potentially
# be an attachment error in that case
attachment = None
if original.children[orig_idx].end_index < predicted.children[pred_idx].end_index:
if orig_idx + 1 < len(original.children) and original.children[orig_idx+1].end_index == predicted.children[pred_idx].end_index:
attachment = check_missing_error(original, predicted, orig_idx, pred_idx, False)
elif original.children[orig_idx].end_index > predicted.children[pred_idx].end_index:
if pred_idx + 1 < len(predicted.children) and predicted.children[pred_idx+1].end_index == original.children[orig_idx].end_index:
attachment = check_missing_error(predicted, original, pred_idx, orig_idx, True)
orig_idx += 1
pred_idx += 1
if attachment:
errors.append(attachment)

missing_node_helper(original, predicted)
return errors

22 changes: 22 additions & 0 deletions stanza/tests/constituency/test_parse_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,25 @@ def test_mark_spans():
for idx, pt in enumerate(tree.yield_preterminals()):
assert pt.start_index == idx
assert pt.end_index == idx + 1

def read_single_tree(text):
trees = tree_reader.read_trees(text)
assert len(trees) == 1
tree = trees[0]
return tree


def test_missing_node_errors():
correct_attach = "(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB eat) (NP (NP (NN spaghetti)) (PP (IN with) (NP (NNS meatballs))))))))))"
wrong_attach = "(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB eat) (NP (NN spaghetti)) (PP (IN with) (NP (NNS meatballs)))))))))"
correct_attach = read_single_tree(correct_attach)
wrong_attach = read_single_tree(wrong_attach)

#print("{:P}".format(correct_attach))
#print("{:P}".format(wrong_attach))

errors = Tree.single_missing_node_errors(correct_attach, wrong_attach)
assert errors == [('VP', 'NP', 'NP', 'PP', True)]

errors = Tree.single_missing_node_errors(wrong_attach, correct_attach)
assert errors == [('VP', 'NP', 'NP', 'PP', False)]

0 comments on commit bb2b3f9

Please sign in to comment.