diff --git a/train.py b/train.py index 3ca61b9..2770bb2 100644 --- a/train.py +++ b/train.py @@ -120,30 +120,46 @@ def compute_perplexity(loss): return torch.exp(loss) -def train(model, train_loader, val_loader, optimizer, criterion, epochs=10): - """Training loop for the model.""" - model.train() +def train( + model, + train_loader, + val_loader, + optimizer, + criterion, + scheduler, + epochs=10, + gradient_accumulation_steps=4, +): + scaler = GradScaler() + best_val_loss = float("inf") + patience = 0 + max_patience = 3 + for epoch in range(epochs): - for inputs, targets in train_loader: + model.train() + total_loss = 0 + for i, (inputs, targets) in enumerate(train_loader): inputs, targets = ( inputs.to(accelerator.device), targets.to(accelerator.device), ) - print(f"Input shape: {inputs.shape}") - print(f"Target shape: {targets.shape}") + with autocast(): + outputs = model(inputs) + loss = criterion(outputs.permute(0, 2, 1), targets) + loss = loss / gradient_accumulation_steps + + scaler.scale(loss).backward() - if inputs.size(1) > model.rope.max_len: - print( - f"Warning: Batch with input size {inputs.size(1)} exceeds the maximum length of {model.rope.max_len}." - ) - continue # Skip this batch + if (i + 1) % gradient_accumulation_steps == 0: + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() - optimizer.zero_grad() - outputs = model(inputs) - loss = criterion(outputs.permute(0, 2, 1), targets) - accelerator.backward(loss) - optimizer.step() + total_loss += loss.item() + + avg_train_loss = total_loss / len(train_loader) + train_perplexity = compute_perplexity(avg_train_loss) # Validation step model.eval() @@ -157,10 +173,42 @@ def train(model, train_loader, val_loader, optimizer, criterion, epochs=10): outputs = model(inputs) loss = criterion(outputs.permute(0, 2, 1), targets) val_loss += loss.item() + + avg_val_loss = val_loss / len(val_loader) + val_perplexity = compute_perplexity(avg_val_loss) + + scheduler.step() + + wandb.log( + { + "epoch": epoch, + "train_loss": avg_train_loss, + "train_perplexity": train_perplexity, + "val_loss": avg_val_loss, + "val_perplexity": val_perplexity, + "learning_rate": scheduler.get_last_lr()[0], + } + ) + print( - f"Epoch {epoch+1}, Training Loss: {loss.item()}, Validation Loss: {val_loss / len(val_loader)}" + f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, " + f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}" ) - model.train() + + # Save checkpoint + accelerator.save_state(f"checkpoint_epoch_{epoch}.pt") + + # Early stopping + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + patience = 0 + # Save best model + accelerator.save_state("best_model.pt") + else: + patience += 1 + if patience >= max_patience: + print("Early stopping triggered") + break # %%