Update the train() function. The update provides more comprehensive tracking of the model's performance. The improvements make the training process more robust, efficient, and easier to monitor and debug.

Key differences and improvements:

Gradient Accumulation:

The updated function introduces gradient accumulation (gradient_accumulation_steps).
This allows for effectively larger batch sizes without increasing memory usage.
The loss is divided by gradient_accumulation_steps and gradients are accumulated over multiple forward passes before updating the model.
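
A minimal, self-contained sketch of this pattern (illustrative names, not the exact code in train.py):

def train_one_epoch_accumulated(model, data_loader, optimizer, criterion,
                                gradient_accumulation_steps=4):
    """One epoch with gradient accumulation (illustrative sketch)."""
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(data_loader):
        outputs = model(inputs)
        # Divide the loss so the accumulated gradient matches one large-batch step.
        loss = criterion(outputs, targets) / gradient_accumulation_steps
        loss.backward()
        # Update the weights only every `gradient_accumulation_steps` mini-batches.
        if (i + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()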

Mixed Precision Training:

The updated function uses autocast and GradScaler for mixed precision training.
This can speed up training and reduce memory usage, especially beneficial for large models.
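
A minimal sketch of the autocast/GradScaler pattern, assuming a CUDA device (illustrative, not the exact train.py code):

from torch.cuda.amp import autocast, GradScaler

def amp_step(model, inputs, targets, optimizer, criterion, scaler):
    """One mixed-precision optimizer step (illustrative sketch)."""
    with autocast():                   # run the forward pass in float16 where safe
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()      # scale the loss to avoid float16 gradient underflow
    scaler.step(optimizer)             # unscales gradients, then calls optimizer.step()
    scaler.update()                    # adjust the scale factor for the next iteration
    optimizer.zero_grad()
    return loss.item()

scaler = GradScaler()                  # created once, before the training loop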

Learning Rate Scheduling:

A learning rate scheduler is now stepped once per epoch (scheduler.step()), which can improve convergence.
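
The scheduler is constructed outside train() and passed in, so its concrete type is not shown in this diff; the sketch below uses a toy model and CosineAnnealingLR purely as an illustration:

import torch

model = torch.nn.Linear(10, 10)        # toy stand-in for the real model
epochs = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

for epoch in range(epochs):
    # ... one epoch of training and validation would run here ...
    scheduler.step()                          # advance the schedule once per epoch
    current_lr = scheduler.get_last_lr()[0]   # the value train() logs to wandb
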
Improved Loss Tracking:
The updated function keeps track of total loss over an epoch, allowing for a more accurate average loss calculation.

Perplexity Calculation:
The updated function calculates perplexity (the exponential of the average loss), a standard metric for language models.
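
A quick worked example of the relationship used by compute_perplexity():

import math

avg_loss = 2.0                    # example average cross-entropy over an epoch, in nats
perplexity = math.exp(avg_loss)
print(f"{perplexity:.2f}")        # 7.39 -- roughly "as uncertain as" choosing among 7 tokens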

Logging:
The updated function uses wandb for logging various metrics, which allows for better tracking and visualization of the training process.
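
A minimal sketch of the wandb wiring; the project name and metric values here are placeholders, not taken from the repository:

import wandb

wandb.init(project="my-training-runs")   # placeholder project name

# train() logs all of its per-epoch metrics in a single dictionary, e.g.:
wandb.log({"epoch": 0, "train_loss": 2.31, "val_loss": 2.45, "learning_rate": 3e-4})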

Checkpointing:
The updated function saves the model state after each epoch, so training can be resumed if it is interrupted.
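
A sketch of the Accelerate checkpointing pattern; it assumes the model and optimizer have been set up through accelerator.prepare(...), which is outside this diff:

from accelerate import Accelerator

accelerator = Accelerator()
epochs = 10

for epoch in range(epochs):
    # ... one epoch of training and validation would run here ...
    # save_state writes model, optimizer and RNG state so a run can be resumed
    # later with accelerator.load_state(...).
    accelerator.save_state(f"checkpoint_epoch_{epoch}")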

Early Stopping:
The updated function implements early stopping, which can prevent overfitting and save computation time.
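
A runnable toy sketch of the patience logic used in train() (the validation losses below are made up):

best_val_loss = float("inf")
patience, max_patience = 0, 3

for epoch, avg_val_loss in enumerate([2.50, 2.30, 2.35, 2.40, 2.60]):  # dummy losses
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience = 0               # improvement: reset the counter (and save the best model)
    else:
        patience += 1              # no improvement this epoch
        if patience >= max_patience:
            print(f"Early stopping triggered at epoch {epoch}")
            break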

Best Model Saving:
The updated function saves the best model based on validation loss, ensuring you keep the best-performing model.

Removed Input Size Check:
The original function skipped any batch whose input size exceeded model.rope.max_len; this check has been removed in the updated version.
jshuadvd committed Jul 5, 2024
1 parent fb52462 commit c342992
Showing 1 changed file with 66 additions and 18 deletions.
84 changes: 66 additions & 18 deletions train.py
@@ -120,30 +120,46 @@ def compute_perplexity(loss):
return torch.exp(loss)


def train(model, train_loader, val_loader, optimizer, criterion, epochs=10):
"""Training loop for the model."""
model.train()
def train(
model,
train_loader,
val_loader,
optimizer,
criterion,
scheduler,
epochs=10,
gradient_accumulation_steps=4,
):
scaler = GradScaler()
best_val_loss = float("inf")
patience = 0
max_patience = 3

for epoch in range(epochs):
for inputs, targets in train_loader:
model.train()
total_loss = 0
for i, (inputs, targets) in enumerate(train_loader):
inputs, targets = (
inputs.to(accelerator.device),
targets.to(accelerator.device),
)

print(f"Input shape: {inputs.shape}")
print(f"Target shape: {targets.shape}")
with autocast():
outputs = model(inputs)
loss = criterion(outputs.permute(0, 2, 1), targets)
loss = loss / gradient_accumulation_steps

scaler.scale(loss).backward()

if inputs.size(1) > model.rope.max_len:
print(
f"Warning: Batch with input size {inputs.size(1)} exceeds the maximum length of {model.rope.max_len}."
)
continue # Skip this batch
if (i + 1) % gradient_accumulation_steps == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs.permute(0, 2, 1), targets)
accelerator.backward(loss)
optimizer.step()
total_loss += loss.item()

avg_train_loss = total_loss / len(train_loader)
train_perplexity = compute_perplexity(avg_train_loss)

# Validation step
model.eval()
@@ -157,10 +173,42 @@ def train(model, train_loader, val_loader, optimizer, criterion, epochs=10):
outputs = model(inputs)
loss = criterion(outputs.permute(0, 2, 1), targets)
val_loss += loss.item()

avg_val_loss = val_loss / len(val_loader)
val_perplexity = compute_perplexity(avg_val_loss)

scheduler.step()

wandb.log(
{
"epoch": epoch,
"train_loss": avg_train_loss,
"train_perplexity": train_perplexity,
"val_loss": avg_val_loss,
"val_perplexity": val_perplexity,
"learning_rate": scheduler.get_last_lr()[0],
}
)

print(
f"Epoch {epoch+1}, Training Loss: {loss.item()}, Validation Loss: {val_loss / len(val_loader)}"
f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}"
)
model.train()

# Save checkpoint
accelerator.save_state(f"checkpoint_epoch_{epoch}.pt")

# Early stopping
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
patience = 0
# Save best model
accelerator.save_state("best_model.pt")
else:
patience += 1
if patience >= max_patience:
print("Early stopping triggered")
break


# %%
