diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx index 9b2114900d3..2a3a5682380 100644 --- a/spacy/ml/tb_framework.pyx +++ b/spacy/ml/tb_framework.pyx @@ -338,9 +338,9 @@ def _forward_fallback( all_ids.append(ids) all_statevecs.append(statevecs) all_which.append(which) + n_moves += 1 if n_moves >= max_moves >= 1: break - n_moves += 1 def backprop_parser(d_states_d_scores): ids = ops.xp.vstack(all_ids) diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index 9c358475a70..68015bb175f 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -2,6 +2,8 @@ from cymem.cymem cimport Pool, Address from libc.stdint cimport int32_t from libcpp.vector cimport vector +import numpy +cimport numpy as np from collections import defaultdict, Counter @@ -16,6 +18,7 @@ from .stateclass cimport StateClass from ._state cimport StateC, ArcC from ...errors import Errors from .search cimport Beam +from .transition_system import OracleSequence cdef weight_t MIN_SCORE = -90000 cdef attr_t SUBTOK_LABEL = hash_string('subtok') @@ -834,19 +837,22 @@ cdef class ArcEager(TransitionSystem): cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 - costs = mem.alloc(self.n_moves, sizeof(float)) + cdef np.ndarray costs is_valid = mem.alloc(self.n_moves, sizeof(int)) history = [] + cost_matrix = [] debug_log = [] failed = False while not state.is_final(): + costs = numpy.zeros((self.n_moves,), dtype="f") try: - self.set_costs(is_valid, costs, state.c, gold) + self.set_costs(is_valid, costs.data, state.c, gold) except ValueError: failed = True break - min_cost = min(costs[i] for i in range(self.n_moves)) + cost_matrix.append(costs) + min_cost = costs.min() for i in range(self.n_moves): if is_valid[i] and costs[i] <= min_cost: action = self.c[i] @@ -901,4 +907,4 @@ cdef class 
ArcEager(TransitionSystem): print("Stack", [example.x[i] for i in state.stack]) print("Buffer", [example.x[i] for i in state.queue]) raise ValueError(Errors.E024) - return history + return OracleSequence(history, numpy.array(cost_matrix)) diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index 89f9e8ae820..c1850a54245 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -1,8 +1,11 @@ # cython: infer_types=True from __future__ import print_function +from typing import List, Optional from cymem.cymem cimport Pool from libc.stdlib cimport calloc, free from libcpp.vector cimport vector +import numpy +cimport numpy as np from collections import Counter import srsly @@ -25,6 +28,22 @@ class OracleError(Exception): pass +class OracleSequence: + actions: List[int] + cost_matrix: numpy.ndarray + + def __init__(self, actions: List[int], cost_matrix: numpy.ndarray): + self.actions = actions + self.cost_matrix = cost_matrix + + __slots__ = ["actions", "cost_matrix"] # fixed attribute set for instances + + def has_cost(self, begin: int=0, end: Optional[int]=None) -> bool: + if end is None: + end = self.cost_matrix.shape[0] + return bool(numpy.count_nonzero(self.cost_matrix[begin:end])) # True iff any cost in rows [begin, end) is non-zero + + cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef StateC* st = new StateC(tokens, length) return st @@ -87,10 +106,10 @@ cdef class TransitionSystem: def get_oracle_sequence(self, Example example, _debug=False): if not self.has_gold(example): - return [] + return OracleSequence([], numpy.zeros((0, self.n_moves), dtype="f")) states, golds, _ = self.init_gold_batch([example]) if not states: - return [] + return OracleSequence([], numpy.zeros((0, self.n_moves), dtype="f")) state = states[0] gold = golds[0] if _debug: @@ -100,17 +119,20 @@ cdef class TransitionSystem: def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None): if state.is_final(): - return [] + return 
OracleSequence([], numpy.zeros((0, self.n_moves), dtype="f")) cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 - costs = mem.alloc(self.n_moves, sizeof(float)) + cdef np.ndarray costs is_valid = mem.alloc(self.n_moves, sizeof(int)) history = [] + cost_matrix = [] debug_log = [] while not state.is_final(): - self.set_costs(is_valid, costs, state.c, gold) + costs = numpy.zeros((self.n_moves,), dtype="f") + self.set_costs(is_valid, costs.data, state.c, gold) + cost_matrix.append(costs) for i in range(self.n_moves): if is_valid[i] and costs[i] <= 0: action = self.c[i] @@ -147,7 +169,7 @@ cdef class TransitionSystem: ))) print("\n".join(debug_log)) raise ValueError(Errors.E024) - return history + return OracleSequence(history, numpy.array(cost_matrix)) def apply_transition(self, StateClass state, name): if not self.is_valid(state, name): diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 2d2a3625287..06a5c218382 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -258,7 +258,7 @@ class Parser(TrainablePipe): # batch uniform length. Since we do not have a gold standard # sequence, we use the teacher's predictions as the gold # standard. - max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) + max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2) states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves) else: states = self.moves.init_batch(student_docs) @@ -425,7 +425,7 @@ class Parser(TrainablePipe): if max_moves >= 1: # Chop sequences into lengths of this many words, to make the # batch uniform length. 
- max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) + max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2) init_states, gold_states, _ = self._init_gold_batch( examples, max_length=max_moves @@ -715,21 +715,24 @@ class Parser(TrainablePipe): states.append(state) golds.append(gold) else: - oracle_actions = moves.get_oracle_sequence_from_state( + oracle_seq = moves.get_oracle_sequence_from_state( state.copy(), gold) - to_cut.append((eg, state, gold, oracle_actions)) + to_cut.append((eg, state, gold, oracle_seq)) if not to_cut: return states, golds, 0 cdef int clas - for eg, state, gold, oracle_actions in to_cut: - for i in range(0, len(oracle_actions), max_length): + for eg, state, gold, oracle_seq in to_cut: + for i in range(0, len(oracle_seq.actions), max_length): start_state = state.copy() - for clas in oracle_actions[i:i+max_length]: + for clas in oracle_seq.actions[i:i+max_length]: action = moves.c[clas] action.do(state.c, action.label) if state.is_final(): break - if moves.has_gold(eg, start_state.B(0), state.B(0)): + # If all actions along the history are zero-cost actions, there + # is nothing to learn from this state in max_length steps, so + # we skip it. 
+ if oracle_seq.has_cost(i, i+max_length): states.append(start_state) golds.append(gold) if state.is_final(): diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index bb226f9c557..1dd01bc5232 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -168,7 +168,7 @@ def test_get_oracle_actions(): example = Example.from_dict( doc, {"words": words, "tags": tags, "heads": heads, "deps": deps} ) - parser.moves.get_oracle_sequence(example) + parser.moves.get_oracle_sequence(example).actions def test_oracle_dev_sentence(vocab, arc_eager): @@ -254,7 +254,7 @@ def test_oracle_dev_sentence(vocab, arc_eager): arc_eager.add_action(3, dep) # Right doc = Doc(Vocab(), words=gold_words) example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps}) - ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) + ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] assert ae_oracle_actions == expected_transitions @@ -288,6 +288,6 @@ def test_oracle_bad_tokenization(vocab, arc_eager): reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"] ) example = Example(predicted=predicted, reference=reference) - ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) + ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] assert ae_oracle_actions diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index 62b8f97047c..ce1bdfce0bc 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -231,7 +231,7 @@ def test_issue4313(): def test_get_oracle_moves(tsys, doc, entity_annots): example = Example.from_dict(doc, {"entities": entity_annots}) - act_classes = 
tsys.get_oracle_sequence(example, _debug=False) + act_classes = tsys.get_oracle_sequence(example, _debug=False).actions names = [tsys.get_class_name(act) for act in act_classes] assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"] @@ -250,7 +250,7 @@ def test_negative_samples_two_word_input(tsys, vocab, neg_key): Span(example.y, 0, 1, label="O"), Span(example.y, 0, 2, label="PERSON"), ] - act_classes = tsys.get_oracle_sequence(example) + act_classes = tsys.get_oracle_sequence(example).actions names = [tsys.get_class_name(act) for act in act_classes] assert names assert names[0] != "O" @@ -270,7 +270,7 @@ def test_negative_samples_three_word_input(tsys, vocab, neg_key): Span(example.y, 0, 1, label="O"), Span(example.y, 0, 2, label="PERSON"), ] - act_classes = tsys.get_oracle_sequence(example) + act_classes = tsys.get_oracle_sequence(example).actions names = [tsys.get_class_name(act) for act in act_classes] assert names assert names[0] != "O" @@ -289,7 +289,7 @@ def test_negative_samples_U_entity(tsys, vocab, neg_key): Span(example.y, 0, 1, label="O"), Span(example.y, 0, 1, label="PERSON"), ] - act_classes = tsys.get_oracle_sequence(example) + act_classes = tsys.get_oracle_sequence(example).actions names = [tsys.get_class_name(act) for act in act_classes] assert names assert names[0] != "O" @@ -540,11 +540,13 @@ def test_block_ner(): assert [token.ent_type_ for token in doc] == expected_types -def test_overfitting_IO(): +@pytest.mark.parametrize("max_moves", [0, 1, 5, 100]) +def test_overfitting_IO(max_moves): fix_random_seed(1) # Simple test to try and quickly overfit the NER component nlp = English() ner = nlp.add_pipe("ner", config={"model": {}}) + ner.cfg["update_with_oracle_cut_size"] = max_moves train_examples = [] for text, annotations in TRAIN_DATA: train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))