From 10f5e9413d8a1a706764b67d677b1b5a31f4c2fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 21 Feb 2023 15:51:29 +0100 Subject: [PATCH] Small `update_with_oracle_cut_size` fixes Fix an off-by-one in `TransitionModel.forward`, where we always did one move more than the maximum number of moves. This explosed another issue: when creating cut states, we skipped states where the (maximum number of) moves from that state only applied transitions that did not modify the buffer. Replace uses of `random.uniform` by `random.randrange`. --- spacy/ml/tb_framework.pyx | 2 +- spacy/pipeline/transition_parser.pyx | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx index 9b2114900d3..2a3a5682380 100644 --- a/spacy/ml/tb_framework.pyx +++ b/spacy/ml/tb_framework.pyx @@ -338,9 +338,9 @@ def _forward_fallback( all_ids.append(ids) all_statevecs.append(statevecs) all_which.append(which) + n_moves += 1 if n_moves >= max_moves >= 1: break - n_moves += 1 def backprop_parser(d_states_d_scores): ids = ops.xp.vstack(all_ids) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 2d2a3625287..7b49402a2b3 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -258,7 +258,7 @@ class Parser(TrainablePipe): # batch uniform length. Since we do not have a gold standard # sequence, we use the teacher's predictions as the gold # standard. - max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) + max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2) states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves) else: states = self.moves.init_batch(student_docs) @@ -425,7 +425,7 @@ class Parser(TrainablePipe): if max_moves >= 1: # Chop sequences into lengths of this many words, to make the # batch uniform length. - max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2)) + max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2) init_states, gold_states, _ = self._init_gold_batch( examples, max_length=max_moves @@ -729,9 +729,8 @@ class Parser(TrainablePipe): action.do(state.c, action.label) if state.is_final(): break - if moves.has_gold(eg, start_state.B(0), state.B(0)): - states.append(start_state) - golds.append(gold) + states.append(start_state) + golds.append(gold) if state.is_final(): break return states, golds, max_length