Update comments in the train() method
jshuadvd committed Jul 5, 2024
1 parent c342992 commit 05ca736
Showing 1 changed file with 30 additions and 0 deletions.

train.py
@@ -130,34 +130,60 @@ def train(
    epochs=10,
    gradient_accumulation_steps=4,
):
"""
Train the LongRoPE model.
Args:
model (nn.Module): The LongRoPE model to train.
train_loader (DataLoader): DataLoader for training data.
val_loader (DataLoader): DataLoader for validation data.
optimizer (Optimizer): Optimizer for updating model parameters.
criterion (nn.Module): Loss function.
scheduler (LRScheduler): Learning rate scheduler.
epochs (int): Number of training epochs.
gradient_accumulation_steps (int): Number of steps to accumulate gradients.
Returns:
None
"""
    # Initialize the gradient scaler for mixed precision training
    scaler = GradScaler()

    # Variables for early stopping
    best_val_loss = float("inf")
    patience = 0
    max_patience = 3

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for i, (inputs, targets) in enumerate(train_loader):
            # Move data to the appropriate device (CPU or GPU)
            inputs, targets = (
                inputs.to(accelerator.device),
                targets.to(accelerator.device),
            )

            # Use mixed precision training
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs.permute(0, 2, 1), targets)
                # Normalize the loss to account for gradient accumulation
                loss = loss / gradient_accumulation_steps

            # Backpropagate and accumulate gradients
            scaler.scale(loss).backward()

            if (i + 1) % gradient_accumulation_steps == 0:
                # Update weights and reset gradients
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Un-scale the normalized loss so the logged average reflects the
            # true per-batch loss (it was divided by gradient_accumulation_steps above)
            total_loss += loss.item() * gradient_accumulation_steps

        # Calculate average training loss and perplexity
        avg_train_loss = total_loss / len(train_loader)
        train_perplexity = compute_perplexity(avg_train_loss)

@@ -174,11 +200,14 @@ def train(
                loss = criterion(outputs.permute(0, 2, 1), targets)
                val_loss += loss.item()

        # Calculate average validation loss and perplexity
        avg_val_loss = val_loss / len(val_loader)
        val_perplexity = compute_perplexity(avg_val_loss)

        # Update learning rate
        scheduler.step()

        # Log metrics
        wandb.log(
            {
                "epoch": epoch,
@@ -190,6 +219,7 @@ def train(
            }
        )

        # Print epoch results
        print(
            f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, "
            f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}"
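
Note on the helper used above: compute_perplexity is defined outside this hunk. A minimal sketch of what it presumably does, assuming it simply exponentiates the average cross-entropy loss:

import math

def compute_perplexity(avg_loss):
    # Perplexity is the exponential of the average cross-entropy loss
    return math.exp(avg_loss)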

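One edge case the accumulation loop leaves open: if the number of batches is not a multiple of gradient_accumulation_steps, the gradients from the final partial window are never applied. A hedged sketch of a post-loop flush, reusing the scaler and optimizer names from the diff:

# Placed after the inner batch loop; flushes gradients left over from a
# partial accumulation window
if len(train_loader) % gradient_accumulation_steps != 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()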
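For context, a hypothetical invocation of train(). The argument order is inferred from the docstring, and the model and dataset names are illustrative, not taken from the repository:

import torch
import torch.nn as nn
import wandb
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()  # train() reads accelerator.device
wandb.init(project="longrope")  # train() logs metrics via wandb.log

model = LongRoPEModel()  # hypothetical constructor
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)  # hypothetical datasets
val_loader = DataLoader(val_dataset, batch_size=8)

train(model, train_loader, val_loader, optimizer, criterion, scheduler,
      epochs=10, gradient_accumulation_steps=4)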