Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: prevent double accumulation of load balancing loss and z-loss wi… #1331

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions megatron/core/transformer/moe/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def aux_loss_load_balancing(self, logits: torch.Tensor):
deterministic_mode=self.config.deterministic_mode,
)

if self.training:
if self.training and torch.is_grad_enabled():
# Apply load balancing loss
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
Expand Down Expand Up @@ -213,7 +213,7 @@ def apply_z_loss(self, logits):
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
if self.config.moe_z_loss_coeff is not None and self.training:
if self.config.moe_z_loss_coeff is not None and self.training and torch.is_grad_enabled():
moe_z_loss_coeff = (
self.config.moe_z_loss_coeff
/ parallel_state.get_tensor_and_context_parallel_world_size()
Expand Down