Skip to content

Commit

Permalink
Update the train() method to evaluate the passkey retrieval and log t…
Browse files Browse the repository at this point in the history
…he results
  • Loading branch information
jshuadvd committed Jul 6, 2024
1 parent 3fc24d5 commit 6f20445
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,13 @@ def train(
# Update learning rate
scheduler.step()

# Evaluate passkey retrieval at the end of each epoch
evaluate_passkey_retrieval(model, tokenizer, model.max_len)
# Evaluate passkey retrieval at the end of each epoch and log results
passkey_accuracies = evaluate_passkey_retrieval(model, tokenizer, model.max_len)
for length, accuracy in passkey_accuracies.items():
wandb.log({f"passkey_retrieval_{length}": accuracy})
logger.info(
f"Passkey retrieval accuracy at {length} tokens: {accuracy:.2f}"
)

# Log metrics
wandb.log(
Expand Down

0 comments on commit 6f20445

Please sign in to comment.