Skip to content

Commit

Permalink
Add a simple validation step after short context recovery
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 10, 2024
1 parent 118c02e commit 7f2c9ef
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,16 @@ def main():
tokenizer=tokenizer,
)

# Add a simple validation step after short context recovery
model.eval()
with torch.no_grad():
val_loss = sum(
criterion(model(inputs), targets).item()
for inputs, targets in val_loader
) / len(val_loader)
logger.info(f"Validation loss after short context recovery: {val_loss:.4f}")
wandb.log({"short_context_val_loss": val_loss})

# Finish logging and close the Weights & Biases run
wandb.finish()

Expand Down

0 comments on commit 7f2c9ef

Please sign in to comment.