Small update_with_oracle_cut_size fixes #12314

Draft · wants to merge 3 commits into base: v5
2 changes: 1 addition & 1 deletion spacy/ml/tb_framework.pyx
@@ -338,9 +338,9 @@ def _forward_fallback(
all_ids.append(ids)
all_statevecs.append(statevecs)
all_which.append(which)
n_moves += 1
if n_moves >= max_moves >= 1:
break
n_moves += 1

def backprop_parser(d_states_d_scores):
ids = ops.xp.vstack(all_ids)
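Reviewer note: the only change in this file is where the `n_moves` counter is bumped relative to the `max_moves` cap, which shifts how many transitions each chopped sequence takes by one. A standalone sketch of that off-by-one (illustrative names, not the parser loop itself; assumes `max_moves >= 1` as the surrounding code does):

```python
def run_steps(max_moves: int, increment_before_check: bool) -> int:
    """Count loop iterations for the two possible placements of the counter bump."""
    n_moves = 0
    steps = 0
    while True:  # stand-in for "while not all states are final"
        steps += 1  # stand-in for processing one batch of transitions
        if increment_before_check:
            n_moves += 1
        if n_moves >= max_moves >= 1:
            break
        if not increment_before_check:
            n_moves += 1
    return steps

print(run_steps(5, increment_before_check=True))   # 5
print(run_steps(5, increment_before_check=False))  # 6
```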
14 changes: 10 additions & 4 deletions spacy/pipeline/_parser_internals/arc_eager.pyx
@@ -2,6 +2,8 @@
from cymem.cymem cimport Pool, Address
from libc.stdint cimport int32_t
from libcpp.vector cimport vector
import numpy
cimport numpy as np

from collections import defaultdict, Counter

@@ -16,6 +18,7 @@ from .stateclass cimport StateClass
from ._state cimport StateC, ArcC
from ...errors import Errors
from .search cimport Beam
from .transition_system import OracleSequence

cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string('subtok')
@@ -834,19 +837,22 @@ cdef class ArcEager(TransitionSystem):
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
cdef np.ndarray costs
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))

history = []
cost_matrix = []
debug_log = []
failed = False
while not state.is_final():
costs = numpy.zeros((self.n_moves,), dtype="f")
try:
self.set_costs(is_valid, costs, state.c, gold)
self.set_costs(is_valid, <float*>costs.data, state.c, gold)
except ValueError:
failed = True
break
min_cost = min(costs[i] for i in range(self.n_moves))
cost_matrix.append(costs)
min_cost = costs.min()
for i in range(self.n_moves):
if is_valid[i] and costs[i] <= min_cost:
action = self.c[i]
@@ -901,4 +907,4 @@ cdef class ArcEager(TransitionSystem):
print("Stack", [example.x[i] for i in state.stack])
print("Buffer", [example.x[i] for i in state.queue])
raise ValueError(Errors.E024)
return history
return OracleSequence(history, numpy.array(cost_matrix))
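For readers skimming this hunk: the oracle no longer returns a bare list of action ids. Each iteration now allocates a fresh `numpy.zeros((self.n_moves,), dtype="f")` cost vector (so the rows appended to `cost_matrix` don't alias each other), and the final result bundles the action history with the stacked costs as an `OracleSequence`. A rough sketch of the resulting shapes, with invented values:

```python
import numpy

history = [2, 0, 3]  # oracle action ids, one per step
cost_matrix = [
    numpy.array([1.0, 0.0, 0.0, 2.0], dtype="f"),  # costs over all moves at step 1
    numpy.array([0.0, 0.0, 1.0, 0.0], dtype="f"),  # step 2
    numpy.array([0.0, 3.0, 0.0, 0.0], dtype="f"),  # step 3
]
stacked = numpy.array(cost_matrix)  # shape (n_steps, n_moves)
assert stacked.shape == (len(history), 4)
```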
34 changes: 28 additions & 6 deletions spacy/pipeline/_parser_internals/transition_system.pyx
@@ -1,8 +1,11 @@
# cython: infer_types=True
from __future__ import print_function
from typing import List, Optional
from cymem.cymem cimport Pool
from libc.stdlib cimport calloc, free
from libcpp.vector cimport vector
import numpy
cimport numpy as np

from collections import Counter
import srsly
@@ -25,6 +28,22 @@ class OracleError(Exception):
pass


class OracleSequence:
actions: List[int]
cost_matrix: numpy.ndarray

def __init__(self, actions: List[int], cost_matrix: numpy.ndarray):
self.actions = actions
self.cost_matrix = cost_matrix

__slots__ = ["actions", "cost_matrix"]

def has_cost(self, begin: int=0, end: Optional[int]=None) -> bool:
if end is None:
end = self.cost_matrix.shape[0]
return numpy.count_nonzero(self.cost_matrix[begin:end])


cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateC* st = new StateC(<const TokenC*>tokens, length)
return <void*>st
@@ -87,10 +106,10 @@ cdef class TransitionSystem:

def get_oracle_sequence(self, Example example, _debug=False):
if not self.has_gold(example):
return []
return OracleSequence([], numpy.zeros((0, self.n_moves)))
states, golds, _ = self.init_gold_batch([example])
if not states:
return []
return OracleSequence([], numpy.zeros((0, self.n_moves)))
state = states[0]
gold = golds[0]
if _debug:
@@ -100,17 +119,20 @@

def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
if state.is_final():
return []
return OracleSequence([], numpy.zeros((0, self.n_moves)))
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
cdef np.ndarray costs
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))

history = []
cost_matrix = []
debug_log = []
while not state.is_final():
self.set_costs(is_valid, costs, state.c, gold)
costs = numpy.zeros((self.n_moves,), dtype="f")
self.set_costs(is_valid, <float*>costs.data, state.c, gold)
cost_matrix.append(costs)
for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0:
action = self.c[i]
@@ -147,7 +169,7 @@ cdef class TransitionSystem:
)))
print("\n".join(debug_log))
raise ValueError(Errors.E024)
return history
return OracleSequence(history, numpy.array(cost_matrix))

def apply_transition(self, StateClass state, name):
if not self.is_valid(state, name):
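The `OracleSequence` helper added above is small enough to mirror in plain Python; the stand-in below (not the class itself) shows how `has_cost()` is meant to be queried over a slice of oracle steps, where callers only need the truthiness of the non-zero count:

```python
from typing import List, Optional

import numpy


class OracleSequenceSketch:
    def __init__(self, actions: List[int], cost_matrix: numpy.ndarray):
        self.actions = actions
        self.cost_matrix = cost_matrix  # shape (n_steps, n_moves)

    def has_cost(self, begin: int = 0, end: Optional[int] = None) -> bool:
        if end is None:
            end = self.cost_matrix.shape[0]
        return bool(numpy.count_nonzero(self.cost_matrix[begin:end]))


seq = OracleSequenceSketch(
    actions=[0, 1, 2, 1],
    cost_matrix=numpy.array([[0, 0], [0, 2], [0, 0], [0, 0]], dtype="f"),
)
assert seq.has_cost()          # some step in the whole sequence has cost
assert seq.has_cost(0, 2)      # the costly step falls in the first window
assert not seq.has_cost(2, 4)  # the tail is entirely zero-cost
```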
19 changes: 11 additions & 8 deletions spacy/pipeline/transition_parser.pyx
@@ -258,7 +258,7 @@ class Parser(TrainablePipe):
# batch uniform length. Since we do not have a gold standard
# sequence, we use the teacher's predictions as the gold
# standard.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2)
states = self._init_batch_from_teacher(teacher_pipe, student_docs, max_moves)
else:
states = self.moves.init_batch(student_docs)
@@ -425,7 +425,7 @@ class Parser(TrainablePipe):
if max_moves >= 1:
# Chop sequences into lengths of this many words, to make the
# batch uniform length.
max_moves = int(random.uniform(max(max_moves // 2, 1), max_moves * 2))
max_moves = random.randrange(max(max_moves // 2, 1), max_moves * 2)
init_states, gold_states, _ = self._init_gold_batch(
examples,
max_length=max_moves
@@ -715,21 +715,24 @@ class Parser(TrainablePipe):
states.append(state)
golds.append(gold)
else:
oracle_actions = moves.get_oracle_sequence_from_state(
oracle_seq = moves.get_oracle_sequence_from_state(
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
to_cut.append((eg, state, gold, oracle_seq))
if not to_cut:
return states, golds, 0
cdef int clas
for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length):
for eg, state, gold, oracle_seq in to_cut:
for i in range(0, len(oracle_seq.actions), max_length):
start_state = state.copy()
for clas in oracle_actions[i:i+max_length]:
for clas in oracle_seq.actions[i:i+max_length]:
action = moves.c[clas]
action.do(state.c, action.label)
if state.is_final():
break
if moves.has_gold(eg, start_state.B(0), state.B(0)):
# If all actions along the history are zero-cost actions, there
# is nothing to learn from this state in max_length steps, so
# we skip it.
if oracle_seq.has_cost(i, i+max_length):
states.append(start_state)
golds.append(gold)
if state.is_final():
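Two things change in this file: the cut length is now sampled with `random.randrange` instead of truncating `random.uniform`, and a chopped window of the oracle history only becomes a training state when some step inside it still carries non-zero cost, replacing the earlier `moves.has_gold` check on the window boundaries. A sketch of that gating with invented data:

```python
import numpy

max_length = 2
actions = [0, 1, 2, 1, 0, 3]  # a made-up oracle history
cost_matrix = numpy.array(
    [[0, 1], [0, 0], [0, 0], [0, 0], [2, 0], [0, 0]], dtype="f"
)

kept_windows = []
for i in range(0, len(actions), max_length):
    # Keep the window only if any step in it has non-zero cost left to learn from.
    if numpy.count_nonzero(cost_matrix[i:i + max_length]):
        kept_windows.append((i, min(i + max_length, len(actions))))

print(kept_windows)  # [(0, 2), (4, 6)] -- the middle window is all zero-cost
```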
6 changes: 3 additions & 3 deletions spacy/tests/parser/test_arc_eager_oracle.py
@@ -168,7 +168,7 @@ def test_get_oracle_actions():
example = Example.from_dict(
doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}
)
parser.moves.get_oracle_sequence(example)
parser.moves.get_oracle_sequence(example).actions


def test_oracle_dev_sentence(vocab, arc_eager):
@@ -254,7 +254,7 @@ def test_oracle_dev_sentence(vocab, arc_eager):
arc_eager.add_action(3, dep) # Right
doc = Doc(Vocab(), words=gold_words)
example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions == expected_transitions

@@ -288,6 +288,6 @@ def test_oracle_bad_tokenization(vocab, arc_eager):
reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"]
)
example = Example(predicted=predicted, reference=reference)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions
12 changes: 7 additions & 5 deletions spacy/tests/parser/test_ner.py
@@ -231,7 +231,7 @@ def test_issue4313():

def test_get_oracle_moves(tsys, doc, entity_annots):
example = Example.from_dict(doc, {"entities": entity_annots})
act_classes = tsys.get_oracle_sequence(example, _debug=False)
act_classes = tsys.get_oracle_sequence(example, _debug=False).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]

@@ -250,7 +250,7 @@ def test_negative_samples_two_word_input(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 2, label="PERSON"),
]
act_classes = tsys.get_oracle_sequence(example)
act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names
assert names[0] != "O"
@@ -270,7 +270,7 @@ def test_negative_samples_three_word_input(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 2, label="PERSON"),
]
act_classes = tsys.get_oracle_sequence(example)
act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names
assert names[0] != "O"
@@ -289,7 +289,7 @@ def test_negative_samples_U_entity(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 1, label="PERSON"),
]
act_classes = tsys.get_oracle_sequence(example)
act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names
assert names[0] != "O"
@@ -540,11 +540,13 @@ def test_block_ner():
assert [token.ent_type_ for token in doc] == expected_types


def test_overfitting_IO():
@pytest.mark.parametrize("max_moves", [0, 1, 5, 100])
def test_overfitting_IO(max_moves):
fix_random_seed(1)
# Simple test to try and quickly overfit the NER component
nlp = English()
ner = nlp.add_pipe("ner", config={"model": {}})
ner.cfg["update_with_oracle_cut_size"] = max_moves
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
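The new parametrization drives `update_with_oracle_cut_size` through `ner.cfg` after the pipe is created. Assuming the factory exposes the same setting (as the default NER config suggests), it can equally be passed when the pipe is added; a minimal sketch:

```python
from spacy.lang.en import English

nlp = English()
# Cut oracle sequences into windows of at most 5 transitions during training.
ner = nlp.add_pipe("ner", config={"update_with_oracle_cut_size": 5})
```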