diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index e03bd5c98e..8f679e409d 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -154,7 +154,7 @@ def aux_loss_load_balancing(self, logits: torch.Tensor):
             deterministic_mode=self.config.deterministic_mode,
         )

-        if self.training:
+        if self.training and torch.is_grad_enabled():
             # Apply load balancing loss
             scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
             probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
@@ -213,7 +213,7 @@ def apply_z_loss(self, logits):
         Returns:
             torch.Tensor: The logits after applying the z-loss.
         """
-        if self.config.moe_z_loss_coeff is not None and self.training:
+        if self.config.moe_z_loss_coeff is not None and self.training and torch.is_grad_enabled():
             moe_z_loss_coeff = (
                 self.config.moe_z_loss_coeff
                 / parallel_state.get_tensor_and_context_parallel_world_size()
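
The sketch below is not the Megatron-LM code; it is a minimal standalone illustration of what the added `torch.is_grad_enabled()` check changes. A module can be in training mode while gradients are disabled (e.g. inside `torch.no_grad()`), in which case gating on `self.training` alone would still run the auxiliary-loss branch. The `ToyRouter` class and `aux_loss_calls` counter are hypothetical names used only for this demonstration.

```python
import torch


class ToyRouter(torch.nn.Module):
    """Toy router showing the self.training / torch.is_grad_enabled() gating pattern."""

    def __init__(self, hidden: int = 4, num_experts: int = 2):
        super().__init__()
        self.gate = torch.nn.Linear(hidden, num_experts)
        self.aux_loss_calls = 0  # hypothetical counter, for demonstration only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.gate(x)
        # Mirrors the gating in the diff: skip auxiliary-loss work when
        # gradients are disabled, even if the module is in training mode.
        if self.training and torch.is_grad_enabled():
            self.aux_loss_calls += 1
        return torch.softmax(logits, dim=-1)


router = ToyRouter()
router.train()
x = torch.randn(3, 4)

router(x)               # grad enabled: aux-loss branch runs
with torch.no_grad():
    router(x)           # grad disabled: aux-loss branch is skipped

print(router.aux_loss_calls)  # prints 1, not 2
```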