diff --git a/tegridy-tools/nanoGPT/nanoGPT.py b/tegridy-tools/nanoGPT/nanoGPT.py
index 3143952..e22ae21 100644
--- a/tegridy-tools/nanoGPT/nanoGPT.py
+++ b/tegridy-tools/nanoGPT/nanoGPT.py
@@ -212,7 +212,7 @@ def compute_accuracy(self, logits, labels):
 
         return acc
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, compute_acc=False):
         device = idx.device
         b, t = idx.size()
         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
@@ -229,7 +229,10 @@ def forward(self, idx, targets=None):
         if targets is not None:
             # if we are given some desired targets also calculate the loss
             logits = self.lm_head(x)
-            acc = self.compute_accuracy(logits, targets)
+            if compute_acc:
+                acc = self.compute_accuracy(logits, targets)
+            else:
+                acc = torch.LongTensor([0])
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=self.config.ignore_idx)
         else:
             # inference-time mini-optimization: only forward the lm_head on the very last position
@@ -293,7 +296,7 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
         return mfu
 
     @torch.no_grad()
-    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, verbose=True, return_prime=True):
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, min_stop_token=-1, return_prime=True, verbose=True):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
@@ -325,6 +328,20 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, verbose=Tru
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
+
+            # stop token code
+            if min_stop_token >= 0:
+                for sa in idx_next:
+                    if sa >= min_stop_token:
+                        stop = True
+                        break
+                    else:
+                        stop = False
+
+                if stop:
+                    if verbose:
+                        print('Model called the end of sequence at:', s, '/', max_new_tokens)
+                    break
+
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
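
Usage sketch (not part of the patch). The hunks above add two opt-in knobs: compute_acc gates the per-step accuracy computation in forward (returning a placeholder torch.LongTensor([0]) when off, so ordinary training steps skip the extra work), and min_stop_token lets generate end early once any sampled token id reaches the threshold. The snippet below is a minimal illustration, not part of the repository: the names model, idx, targets, and prompt, and the token id 3087, are assumptions, as is the (logits, loss, acc) return signature of forward, whose return statement is outside the hunks shown.

    import torch

    # Training-style call: accuracy is only computed when explicitly requested;
    # with the default compute_acc=False a placeholder LongTensor is returned instead.
    # The (logits, loss, acc) return signature is assumed, not shown in the diff.
    logits, loss, acc = model(idx, targets, compute_acc=True)

    # Generation with early stopping: sampling any token id >= min_stop_token
    # (e.g. an end-of-sequence marker placed at the top of the vocabulary)
    # breaks out of the sampling loop before max_new_tokens is reached.
    out = model.generate(prompt,
                         max_new_tokens=512,
                         temperature=0.9,
                         top_k=40,
                         min_stop_token=3087,  # hypothetical EOS id, for illustration only
                         return_prime=True,
                         verbose=True)

Note that the stop check scans every row of the sampled batch, so in batched generation a single sequence hitting the stop token halts generation for all sequences; with the default min_stop_token=-1 the check is skipped entirely and behavior matches the previous code.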