Skip to content

Commit

Permalink
Remove checkpointing functions in lieu of the huggingface accelerator…
Browse files Browse the repository at this point in the history
… library
  • Loading branch information
jshuadvd committed Jul 7, 2024
1 parent 35f4d07 commit 463678b
Showing 1 changed file with 0 additions and 21 deletions.
21 changes: 0 additions & 21 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,6 @@ def preprocess_data(data, tokenizer, max_length, overlap):
return sequences


def save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss):
checkpoint = {
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"best_val_loss": best_val_loss,
}
torch.save(checkpoint, f"checkpoint_epoch_{epoch}.pt")
logger.info(f"Checkpoint saved for epoch {epoch}")


def load_checkpoint(model, optimizer, scheduler, filename):
checkpoint = torch.load(filename)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
logger.info(f"Loaded checkpoint from {filename}")
return checkpoint["epoch"], checkpoint["best_val_loss"]


def compute_perplexity(loss):
return torch.exp(loss)

Expand Down

0 comments on commit 463678b

Please sign in to comment.