From db035d65edeade72cf4528cb6e5aa4390b5ebaad Mon Sep 17 00:00:00 2001
From: thuwzt
Date: Fri, 20 Dec 2024 03:58:01 +0000
Subject: [PATCH] fix: prevent double accumulation of load balancing loss and
 z-loss with activation checkpointing

---
 megatron/core/transformer/moe/router.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index e03bd5c98e..8f679e409d 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -154,7 +154,7 @@ def aux_loss_load_balancing(self, logits: torch.Tensor):
             deterministic_mode=self.config.deterministic_mode,
         )
 
-        if self.training:
+        if self.training and torch.is_grad_enabled():
             # Apply load balancing loss
             scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
             probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
@@ -213,7 +213,7 @@ def apply_z_loss(self, logits):
         Returns:
             torch.Tensor: The logits after applying the z-loss.
         """
-        if self.config.moe_z_loss_coeff is not None and self.training:
+        if self.config.moe_z_loss_coeff is not None and self.training and torch.is_grad_enabled():
             moe_z_loss_coeff = (
                 self.config.moe_z_loss_coeff
                 / parallel_state.get_tensor_and_context_parallel_world_size()
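
A minimal sketch (not part of the patch) of why the torch.is_grad_enabled() guard prevents double counting, assuming reentrant-style activation checkpointing: the checkpointed function runs twice, once under no_grad during the forward pass and once with gradients enabled during the backward-pass recomputation, so any unguarded side-effect accumulation fires twice per training step. The names router_like_fn and calls are illustrative stand-ins, not Megatron code.

import torch
from torch.utils.checkpoint import checkpoint

calls = {"unguarded": 0, "guarded": 0}

def router_like_fn(x):
    calls["unguarded"] += 1    # runs in both the forward pass and the recompute
    if torch.is_grad_enabled():
        calls["guarded"] += 1  # runs only once, when gradients are being tracked
    return x * 2

x = torch.randn(4, requires_grad=True)
y = checkpoint(router_like_fn, x, use_reentrant=True)  # reentrant: forward runs under no_grad
y.sum().backward()                                     # backward triggers the recomputation

print(calls)  # expected: {'unguarded': 2, 'guarded': 1}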