Version 1.0: Minor code corrections and improvements
Alex authored Sep 22, 2023
1 parent cb13613 commit 962d124
Showing 1 changed file with 20 additions and 3 deletions.
tegridy-tools/nanoGPT/nanoGPT.py (23 changes: 20 additions & 3 deletions)
@@ -212,7 +212,7 @@ def compute_accuracy(self, logits, labels):

         return acc

-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, compute_acc=False):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
@@ -229,7 +229,10 @@ def forward(self, idx, targets=None):
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.lm_head(x)
-            acc = self.compute_accuracy(logits, targets)
+            if compute_acc:
+                acc = self.compute_accuracy(logits, targets)
+            else:
+                acc = torch.LongTensor([0])
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=self.config.ignore_idx)
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
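
Taken together, the first two hunks make the accuracy pass opt-in: forward() only runs compute_accuracy() when compute_acc=True, and otherwise hands back a torch.LongTensor([0]) placeholder, so the common training path skips the extra work. A minimal usage sketch, assuming `model` is an instance of this file's GPT class, that forward() returns logits, loss, and acc (the hunks compute all three but do not show the return statement), and with X and Y as hypothetical (batch, block_size) LongTensors of inputs and targets:

import torch

model.eval()
with torch.no_grad():
    # default call: accuracy is skipped, acc is the LongTensor([0]) placeholder
    logits, loss, acc = model(X, targets=Y)

    # opt-in call: explicitly pay for the extra accuracy computation
    logits, loss, acc = model(X, targets=Y, compute_acc=True)
    print('val loss:', loss.item(), 'val acc:', acc.item())
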
@@ -293,7 +296,7 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
         return mfu

     @torch.no_grad()
-    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, verbose=True, return_prime=True):
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, min_stop_token=-1, return_prime=True, verbose=True):
         """
         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
         the sequence max_new_tokens times, feeding the predictions back into the model each time.
@@ -325,6 +328,20 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, verbose=True, return_prime=True):
             probs = F.softmax(logits, dim=-1)
             # sample from the distribution
             idx_next = torch.multinomial(probs, num_samples=1)
+
+            # stop token code
+            if min_stop_token >= 0:
+                for sa in idx_next:
+                    if sa >= min_stop_token:
+                        stop = True
+                        break
+                    else:
+                        stop = False
+                if stop:
+                    if verbose:
+                        print('Model called the end of sequence at:', s, '/', max_new_tokens)
+                    break
+
             # append sampled index to the running sequence and continue
             idx = torch.cat((idx, idx_next), dim=1)

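
The last hunk adds early stopping to generate(): when min_stop_token is non-negative, every sampled batch entry is checked, and generation breaks off as soon as any sampled id is greater than or equal to min_stop_token (the default of -1 disables the check; the `s` in the verbose print is presumably the counter of the enclosing sampling loop, which sits above the visible context). The convention suits vocabularies that place end-of-sequence and other control tokens at the top of the id range. A hedged calling sketch, where `model` and every concrete token id are illustrative assumptions rather than values from this commit:

import torch

# hypothetical vocabulary: ids >= 3072 are control/end-of-sequence tokens
prime = torch.LongTensor([[3072, 60, 72]])  # (b, t) conditioning sequence
device = next(model.parameters()).device

out = model.generate(prime.to(device),
                     max_new_tokens=512,
                     temperature=0.9,
                     top_k=None,
                     min_stop_token=3072,  # stop once any sampled id reaches this range
                     return_prime=True,    # keep the conditioning tokens in the output
                     verbose=True)         # report where generation stopped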