Commit 2c469ca

refactor 1
1 parent 3835936 commit 2c469ca

1 file changed: +42 -57 lines changed

torchtitan/models/moe/moe.py (42 additions & 57 deletions)
@@ -239,69 +239,46 @@ def _debug_force_load_balance_routing(
         top_scores = scores.gather(dim=1, index=selected_experts_indices)  # [N,K]
         return selected_experts_indices, top_scores

-    def _node_limited_routing(
+    def _get_node_limited_routing_scores(
         self,
-        scores: torch.Tensor,
-        expert_bias: torch.Tensor | None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Select top_k experts, optionally limiting to a subset of expert groups.
-
-        If num_expert_groups is set, applies node-limited routing:
-        1. Select num_limited_groups groups based on group scores
-        2. Select top_k experts only from those groups
-
-        If expert_bias is provided, it is added to scores for selection, but
-        the returned top_scores are always from the original (unbiased) scores.
+        scores_for_choice: torch.Tensor,
+    ) -> torch.Tensor:
+        """Select num_limited_groups groups based on group scores,
+        and set expert scores in non-selected groups to -inf.

         Args:
-            scores: Router scores after sigmoid or softmax, shape (bs*slen, num_experts)
-            expert_bias: Optional bias for load balancing, shape (num_experts,)
+            scores_for_choice: Router scores with expert_bias (if any), shape (bs*slen, num_experts)

         Returns:
-            tuple of (selected_experts_indices, top_scores)
-            - selected_experts_indices: shape (bs*slen, top_k)
-            - top_scores: shape (bs*slen, top_k)
+            scores_for_choice: shape (bs*slen, num_experts)
         """
-        scores_for_choice = scores if expert_bias is None else scores + expert_bias
-
-        # Apply node-limited routing mask if configured
-        if self.num_expert_groups is not None:
-            if self.num_limited_groups is None:
-                raise ValueError(
-                    "num_limited_groups must be set when num_expert_groups is set"
-                )
-            if self.num_experts % self.num_expert_groups != 0:
-                raise ValueError(
-                    f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
-                )
-            experts_per_group = self.num_experts // self.num_expert_groups
-            if experts_per_group < 2:
-                raise ValueError(
-                    f"experts_per_group ({experts_per_group}) must be >= 2"
-                )
-            scores_grouped = scores_for_choice.view(
-                -1, self.num_expert_groups, experts_per_group
+        if self.num_limited_groups is None:
+            raise ValueError(
+                "num_limited_groups must be set when num_expert_groups is set"
             )
-            group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
-            group_idx = torch.topk(
-                group_scores, k=self.num_limited_groups, dim=-1, sorted=False
-            )[1]
-            group_mask = torch.ones_like(group_scores, dtype=torch.bool)
-            group_mask.scatter_(1, group_idx, False)  # False = selected groups (keep)
-            # Mask out experts from non-selected groups
-            scores_for_choice = scores_grouped.masked_fill(
-                group_mask.unsqueeze(-1), float("-inf")
-            ).view(-1, self.num_experts)
-
-        selected_experts_indices = torch.topk(
-            scores_for_choice, k=self.top_k, dim=-1, sorted=False
-        )[1]
-
-        # NOTE: The expert_bias is only used for routing. The gating value
-        # top_scores is still derived from the original scores.
-        top_scores = scores.gather(dim=1, index=selected_experts_indices)
+        if self.num_experts % self.num_expert_groups != 0:
+            raise ValueError(
+                f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
+            )
+        experts_per_group = self.num_experts // self.num_expert_groups
+        if experts_per_group < 2:
+            raise ValueError(f"experts_per_group ({experts_per_group}) must be >= 2")
+        scores_grouped = scores_for_choice.view(
+            -1, self.num_expert_groups, experts_per_group
+        )
+        top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
+        group_scores = top2_scores_in_group.sum(dim=-1)
+        _, group_idx = torch.topk(
+            group_scores, k=self.num_limited_groups, dim=-1, sorted=False
+        )
+        group_mask = torch.ones_like(group_scores, dtype=torch.bool)
+        group_mask.scatter_(1, group_idx, False)  # False = selected groups (keep)
+        # Mask out experts from non-selected groups
+        scores_for_choice = scores_grouped.masked_fill(
+            group_mask.unsqueeze(-1), float("-inf")
+        ).view(-1, self.num_experts)

-        return selected_experts_indices, top_scores
+        return scores_for_choice

     def forward(
         self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
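
For intuition, the masking step the new helper performs can be reproduced standalone. Below is a minimal sketch in plain PyTorch with toy sizes (8 experts in 4 groups, keeping the top-2 groups); the names mirror the method above, but none of this is torchtitan API.

```python
import torch

# Toy configuration (invented for illustration).
num_experts, num_expert_groups, num_limited_groups = 8, 4, 2
experts_per_group = num_experts // num_expert_groups  # 2

scores_for_choice = torch.rand(5, num_experts)  # (tokens, num_experts)

# Score each group by the sum of its two best expert scores.
grouped = scores_for_choice.view(-1, num_expert_groups, experts_per_group)
group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)  # (tokens, groups)

# Keep the best num_limited_groups groups; set every expert in the other
# groups to -inf so a later top-k over all experts can never select them.
_, group_idx = torch.topk(group_scores, k=num_limited_groups, dim=-1)
group_mask = torch.ones_like(group_scores, dtype=torch.bool)
group_mask.scatter_(1, group_idx, False)  # False = keep this group
masked = grouped.masked_fill(group_mask.unsqueeze(-1), float("-inf"))
masked = masked.view(-1, num_experts)

# Each token is left with exactly num_limited_groups * experts_per_group
# finite scores to route over.
assert (masked.isfinite().sum(dim=-1) == num_limited_groups * experts_per_group).all()
```

Scoring a group by its top-2 experts rather than its single best presumably keeps a group competitive only when it holds more than one strong expert, which would also explain the `experts_per_group >= 2` check.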
@@ -332,11 +309,19 @@ def forward(
         else:
             raise NotImplementedError(f"Unknown score function {self.score_func}")

-        # top scores shape (bs*slen, top_k)
-        selected_experts_indices, top_scores = self._node_limited_routing(
-            scores, expert_bias
+        scores_for_choice = scores if expert_bias is None else scores + expert_bias
+        # Apply node-limited routing if configured
+        if self.num_expert_groups is not None:
+            scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
+        _, selected_experts_indices = torch.topk(
+            scores_for_choice, k=self.top_k, dim=-1, sorted=False
         )

+        # top scores shape (bs*slen, top_k)
+        # NOTE: The expert_bias is only used for routing. The gating value
+        # top_scores is still derived from the original scores.
+        top_scores = scores.gather(dim=1, index=selected_experts_indices)
+
         # debug override: balanced round-robin routing
         if self._debug_force_load_balance:
             (
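
The forward path now makes the selection/gating split explicit: `expert_bias` is added only to the scores used for top-k selection, while `top_scores` is always gathered from the raw `scores`. A toy illustration of that invariant, with invented numbers (plain PyTorch, not torchtitan API):

```python
import torch

scores = torch.tensor([[0.30, 0.20, 0.25, 0.25]])  # raw router scores, 1 token
expert_bias = torch.tensor([0.0, 0.5, 0.0, 0.0])   # load-balancing bias

# Bias influences *which* experts win the top-k...
scores_for_choice = scores + expert_bias
_, idx = torch.topk(scores_for_choice, k=2, dim=-1, sorted=False)
assert sorted(idx[0].tolist()) == [0, 1]  # bias pulled expert 1 in

# ...but the gating values come from the unbiased scores.
top_scores = scores.gather(dim=1, index=idx)
assert torch.allclose(
    top_scores.sort(dim=-1).values, torch.tensor([[0.20, 0.30]])
)
```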
