Skip to content

Commit

Permalink
Update the train method to use the passkey test
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 5, 2024
1 parent 7f2f11c commit b601dcf
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def train(
optimizer,
criterion,
scheduler,
tokenizer,
epochs=10,
gradient_accumulation_steps=4,
):
Expand All @@ -142,6 +143,7 @@ def train(
optimizer (Optimizer): Optimizer for updating model parameters.
criterion (nn.Module): Loss function.
scheduler (LRScheduler): Learning rate scheduler.
tokenizer: Tokenizer for encoding/decoding text.
epochs (int): Number of training epochs.
gradient_accumulation_steps (int): Number of steps to accumulate gradients.
Expand Down Expand Up @@ -209,6 +211,9 @@ def train(
# Update learning rate
scheduler.step()

# Evaluate passkey retrieval at the end of each epoch
evaluate_passkey_retrieval(model, tokenizer, model.max_len)

# Log metrics
wandb.log(
{
Expand Down

0 comments on commit b601dcf

Please sign in to comment.