From ee2e5f323c54c2443ed5dd1a1cd412dbc36c58f6 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Fri, 5 Jul 2024 22:10:01 -0700
Subject: [PATCH] Log the gradient norm

---
 train.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/train.py b/train.py
index 6c408e9..174a73c 100644
--- a/train.py
+++ b/train.py
@@ -226,6 +226,16 @@ def train(
                 f"Passkey retrieval accuracy at {length} tokens: {accuracy:.2f}"
             )
 
+        # Log gradient norm
+        total_norm = 0
+        for p in model.parameters():
+            if p.grad is not None:
+                param_norm = p.grad.data.norm(2)
+                total_norm += param_norm.item() ** 2
+        total_norm = total_norm**0.5
+        wandb.log({"gradient_norm": total_norm})
+        logger.info(f"Gradient norm: {total_norm:.4f}")
+
         # Log metrics
         wandb.log(
             {
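
Note on the added block: it is the standard PyTorch idiom for the global
gradient norm (square each per-parameter L2 norm, sum, take the square
root). A minimal standalone sketch of the same math, assuming a model
whose gradients were already populated by backward(); grad_norm is a
hypothetical helper name, not part of this patch:

    import torch

    def grad_norm(model: torch.nn.Module) -> float:
        # Per-parameter L2 norms; skip parameters without gradients,
        # mirroring the `if p.grad is not None` guard in the patch.
        norms = [
            p.grad.detach().norm(2)
            for p in model.parameters()
            if p.grad is not None
        ]
        if not norms:
            return 0.0
        # Global L2 norm == sqrt(sum of squared per-parameter norms).
        return torch.linalg.vector_norm(torch.stack(norms), 2).item()

If the training loop also clips gradients, torch.nn.utils.clip_grad_norm_
returns this same total norm (computed before clipping), so the logged
value could be taken from its return value instead of a second pass over
the parameters.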