Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contextual tokenizer #1415

Open
wants to merge 20 commits into
base: dev
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 37 additions & 54 deletions stanza/models/tokenization/model.py
Original file line number Diff line number Diff line change
@@ -20,13 +20,14 @@ def __init__(self, args, pretrain, hidden_dim, device=None):
self.lstm = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True,
batch_first=True, num_layers=args['rnn_layers'])

# standard up and down projection a la transformers
self.ffnn = nn.Sequential(
nn.Linear(hidden_dim*2, hidden_dim*4),
nn.ReLU(),
nn.Linear(hidden_dim*4, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)

@@ -36,7 +37,8 @@ def device(self):

def forward(self, x):
# map the vocab to pretrain IDs
embs = self.embeddings(torch.tensor([[self.vocab[j.strip()] for j in i] for i in x],
token_ids = [[self.vocab[j.strip()] for j in i] for i in x]
embs = self.embeddings(torch.tensor(token_ids,
device=self.device))
net = self.emb_proj(embs)
net = self.lstm(net)[0]
@@ -147,70 +149,51 @@ def forward(self, x, feats, text):
else:
draft_preds = torch.cat([nontok, tok], 2).argmax(dim=2)
draft_preds = (draft_preds > 0)
# we add a prefix zero
# TODO inefficient / how to parallelize this?
front_pad = torch.tensor([-1]).to(draft_preds.device)
back_pad = torch.tensor([len(text[0])-1]).to(draft_preds.device)
token_locations = [torch.cat([front_pad, i.nonzero().squeeze(1).detach(), back_pad])
for i in draft_preds]

# both: batch x seq x [variable: text token count]
batch_tokens = [] # str tokens
batch_tokenid_locations = [] # id locations for the *end* of each str token
# corresponding to char token
for location,chars, toks in zip(token_locations, text, x):
# we append len(chars)-1 to append the last token which wouldn't
# necessearily have been captured by the splits; though in theory
# the model should put a token at the end of each sentence so this
# should be less of a problem

a,b = tee(location)
tokens = []
tokenid_locations = []
next(b) # because we want to start iterating on the NEXT id to create pairs
j = -1
for i,j in zip(a,b):
split = chars[i+1:j+1]
# if the entire unit is UNK, leave as UNK into the predictor
is_unk = ((toks[i+1:j+1]) == UNK_ID).all()
if set(split) == set([PAD]):
continue
tokenid_locations.append(j)

if not is_unk:
tokens.append("".join(split).replace(PAD, ""))
else:
tokens.append(UNK)

batch_tokens.append(tokens)
batch_tokenid_locations.append(tokenid_locations)
# these boolean indicies are *inclusive*, so predict it or not
# we need to split on the last token if we want to keep the
# final word
draft_preds[:,-1] = True

# both: batch x [variable: text token count]
extracted_tokens = []
partial = []
last = 0
last_batch = -1

nonzero = draft_preds.nonzero().cpu().tolist()
for i,j in nonzero:
if i != last_batch:
last_batch = i
last = 0
if i != 0:
extracted_tokens.append(partial)
partial = []

substring = text[i][last:j+1]
last = j+1

partial.append("".join(substring))
extracted_tokens.append(partial)

# dynamically pad the batch tokens to size
# why max 5? our
max_size = max(max([len(i) for i in batch_tokens]),
# why to at least a fix size? it must be wider
# than our kernel
max_size = max(max([len(i) for i in extracted_tokens]),
self.args["sentence_analyzer_kernel"])
batch_tokens_padded = []
batch_tokens_isntpad = []
for i in batch_tokens:
for i in extracted_tokens:
batch_tokens_padded.append(i + [PAD for _ in range(max_size-len(i))])
batch_tokens_isntpad.append([True for _ in range(len(i))] +
[False for _ in range(max_size-len(i))])

##### TODO EVERYTHING BELOW THIS LINE IS UNTESTED #####
second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded)

# we only add scores for slots for which we have a possible word ending
# i.e. its not padding and its also not a middle of rough score's resulting
# words
# # we only add scores for slots for which we have a possible word ending
# # i.e. its not padding and its also not a middle of rough score's resulting
# # words
second_pass_chars_align = torch.zeros_like(sent0)
token_location_selectors = torch.tensor([[i,k] for i,j in
enumerate(batch_tokenid_locations)
for k in j])

second_pass_chars_align[
token_location_selectors[:,0],
token_location_selectors[:,1]
] = second_pass_scores[torch.tensor(batch_tokens_isntpad)]
second_pass_chars_align[draft_preds] = second_pass_scores[torch.tensor(batch_tokens_isntpad)]

sent0 += second_pass_chars_align