From 05ca7367f6e3f5cfc4551f491e07cbbf05280707 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Thu, 4 Jul 2024 22:21:50 -0700
Subject: [PATCH] Update comments in the train() method

---
 train.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/train.py b/train.py
index 2770bb2..60b8bcb 100644
--- a/train.py
+++ b/train.py
@@ -130,7 +130,26 @@ def train(
     epochs=10,
     gradient_accumulation_steps=4,
 ):
+    """
+    Train the LongRoPE model.
+
+    Args:
+        model (nn.Module): The LongRoPE model to train.
+        train_loader (DataLoader): DataLoader for training data.
+        val_loader (DataLoader): DataLoader for validation data.
+        optimizer (Optimizer): Optimizer for updating model parameters.
+        criterion (nn.Module): Loss function.
+        scheduler (LRScheduler): Learning rate scheduler.
+        epochs (int): Number of training epochs.
+        gradient_accumulation_steps (int): Number of steps to accumulate gradients.
+
+    Returns:
+        None
+    """
+    # Initialize the gradient scaler for mixed precision training
     scaler = GradScaler()
+
+    # Variables for early stopping
    best_val_loss = float("inf")
    patience = 0
    max_patience = 3
@@ -138,26 +157,33 @@ def train(
     for epoch in range(epochs):
         model.train()
         total_loss = 0
+
         for i, (inputs, targets) in enumerate(train_loader):
+            # Move data to the appropriate device (CPU or GPU)
             inputs, targets = (
                 inputs.to(accelerator.device),
                 targets.to(accelerator.device),
             )
 
+            # Use mixed precision training
             with autocast():
                 outputs = model(inputs)
                 loss = criterion(outputs.permute(0, 2, 1), targets)
 
+            # Normalize the loss to account for gradient accumulation
             loss = loss / gradient_accumulation_steps
 
+            # Backpropagate and accumulate gradients
             scaler.scale(loss).backward()
 
             if (i + 1) % gradient_accumulation_steps == 0:
+                # Update weights and reset gradients
                 scaler.step(optimizer)
                 scaler.update()
                 optimizer.zero_grad()
 
             total_loss += loss.item()
 
+        # Calculate average training loss and perplexity
         avg_train_loss = total_loss / len(train_loader)
         train_perplexity = compute_perplexity(avg_train_loss)
@@ -174,11 +200,14 @@ def train(
                 loss = criterion(outputs.permute(0, 2, 1), targets)
                 val_loss += loss.item()
 
+        # Calculate average validation loss and perplexity
         avg_val_loss = val_loss / len(val_loader)
         val_perplexity = compute_perplexity(avg_val_loss)
 
+        # Update learning rate
         scheduler.step()
 
+        # Log metrics
         wandb.log(
             {
                 "epoch": epoch,
@@ -190,6 +219,7 @@ def train(
             }
         )
 
+        # Print epoch results
         print(
             f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, "
             f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}"
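
For reference, the pattern the new comments describe (autocast mixed precision, loss scaling, and gradient accumulation) can be exercised in isolation with the minimal sketch below. The stand-in model, synthetic data, and the compute_perplexity helper are illustrative assumptions, not the repository's code.

    import math
    import torch
    import torch.nn as nn
    from torch.cuda.amp import GradScaler, autocast
    from torch.utils.data import DataLoader, TensorDataset

    def compute_perplexity(avg_loss):
        # Assumed definition: perplexity as the exponential of the mean cross-entropy loss
        return math.exp(avg_loss)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"

    model = nn.Linear(16, 4).to(device)  # stand-in for the LongRoPE model
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
    train_loader = DataLoader(dataset, batch_size=8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=use_amp)
    gradient_accumulation_steps = 4

    model.train()
    total_loss = 0
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass under autocast (mixed precision when a GPU is available)
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Normalize the loss so the accumulated gradient matches a full-batch step
        loss = loss / gradient_accumulation_steps
        scaler.scale(loss).backward()

        if (i + 1) % gradient_accumulation_steps == 0:
            scaler.step(optimizer)  # unscales gradients, then steps the optimizer
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Train Loss: {avg_train_loss:.4f}, "
          f"Train Perplexity: {compute_perplexity(avg_train_loss):.4f}")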