From 463678be2ab978379251c5d143ef289095d4c539 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Sat, 6 Jul 2024 21:54:29 -0700
Subject: [PATCH] Remove checkpointing functions in lieu of the huggingface accelerator library

---
 train.py | 21 ---------------------
 1 file changed, 21 deletions(-)

diff --git a/train.py b/train.py
index 0a0271f..44ac730 100644
--- a/train.py
+++ b/train.py
@@ -125,27 +125,6 @@ def preprocess_data(data, tokenizer, max_length, overlap):
     return sequences
 
 
-def save_checkpoint(model, optimizer, scheduler, epoch, best_val_loss):
-    checkpoint = {
-        "epoch": epoch,
-        "model_state_dict": model.state_dict(),
-        "optimizer_state_dict": optimizer.state_dict(),
-        "scheduler_state_dict": scheduler.state_dict(),
-        "best_val_loss": best_val_loss,
-    }
-    torch.save(checkpoint, f"checkpoint_epoch_{epoch}.pt")
-    logger.info(f"Checkpoint saved for epoch {epoch}")
-
-
-def load_checkpoint(model, optimizer, scheduler, filename):
-    checkpoint = torch.load(filename)
-    model.load_state_dict(checkpoint["model_state_dict"])
-    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
-    logger.info(f"Loaded checkpoint from {filename}")
-    return checkpoint["epoch"], checkpoint["best_val_loss"]
-
-
 def compute_perplexity(loss):
     return torch.exp(loss)
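
Note (not part of the commit): the patch removes the hand-rolled torch.save/torch.load helpers without showing their Accelerate replacement. The following is a minimal sketch of how Hugging Face Accelerate's save_state/load_state covers the same state; the model, optimizer, and scheduler here are placeholders standing in for the real training objects in train.py.

# Sketch only: mapping the removed save_checkpoint/load_checkpoint helpers
# onto Accelerate's built-in checkpointing. Placeholder objects below are
# assumptions, not taken from train.py.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# prepare() wraps the objects so Accelerate can track (and later restore) them.
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

# Replaces save_checkpoint(): writes model, optimizer, scheduler, and RNG
# states into a directory instead of a single .pt file.
accelerator.save_state("checkpoint_epoch_1")

# Replaces load_checkpoint(): restores everything saved above in place.
accelerator.load_state("checkpoint_epoch_1")

Scalars the old helpers tracked explicitly, such as epoch and best_val_loss, are not captured by save_state; they would need to be registered via accelerator.register_for_checkpointing() on an object exposing state_dict()/load_state_dict(), or persisted separately.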