@@ -239,69 +239,46 @@ def _debug_force_load_balance_routing(
         top_scores = scores.gather(dim=1, index=selected_experts_indices)  # [N,K]
         return selected_experts_indices, top_scores

-    def _node_limited_routing(
+    def _get_node_limited_routing_scores(
         self,
-        scores: torch.Tensor,
-        expert_bias: torch.Tensor | None,
+        scores_for_choice: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
-        """Select top_k experts, optionally limiting to a subset of expert groups.
-
-        If num_expert_groups is set, applies node-limited routing:
-        1. Select num_limited_groups groups based on group scores
-        2. Select top_k experts only from those groups
-
-        If expert_bias is provided, it is added to scores for selection, but
-        the returned top_scores are always from the original (unbiased) scores.
+        """Select num_limited_groups groups based on group scores,
+        and set expert scores in non-selected groups to -inf.

         Args:
-            scores: Router scores after sigmoid or softmax, shape (bs*slen, num_experts)
-            expert_bias: Optional bias for load balancing, shape (num_experts,)
+            scores_for_choice: Router scores with expert_bias (if any), shape (bs*slen, num_experts)

         Returns:
-            tuple of (selected_experts_indices, top_scores)
-            - selected_experts_indices: shape (bs*slen, top_k)
-            - top_scores: shape (bs*slen, top_k)
+            scores_for_choice: shape (bs*slen, num_experts)
         """
-        scores_for_choice = scores if expert_bias is None else scores + expert_bias
-
-        # Apply node-limited routing mask if configured
-        if self.num_expert_groups is not None:
-            if self.num_limited_groups is None:
-                raise ValueError(
-                    "num_limited_groups must be set when num_expert_groups is set"
-                )
-            if self.num_experts % self.num_expert_groups != 0:
-                raise ValueError(
-                    f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
-                )
-            experts_per_group = self.num_experts // self.num_expert_groups
-            if experts_per_group < 2:
-                raise ValueError(
-                    f"experts_per_group ({experts_per_group}) must be >= 2"
-                )
-            scores_grouped = scores_for_choice.view(
-                -1, self.num_expert_groups, experts_per_group
+        if self.num_limited_groups is None:
+            raise ValueError(
+                "num_limited_groups must be set when num_expert_groups is set"
             )
-            group_scores = scores_grouped.topk(2, dim=-1)[0].sum(dim=-1)
-            group_idx = torch.topk(
-                group_scores, k=self.num_limited_groups, dim=-1, sorted=False
-            )[1]
-            group_mask = torch.ones_like(group_scores, dtype=torch.bool)
-            group_mask.scatter_(1, group_idx, False)  # False = selected groups (keep)
-            # Mask out experts from non-selected groups
-            scores_for_choice = scores_grouped.masked_fill(
-                group_mask.unsqueeze(-1), float("-inf")
-            ).view(-1, self.num_experts)
-
-        selected_experts_indices = torch.topk(
-            scores_for_choice, k=self.top_k, dim=-1, sorted=False
-        )[1]
-
-        # NOTE: The expert_bias is only used for routing. The gating value
-        # top_scores is still derived from the original scores.
-        top_scores = scores.gather(dim=1, index=selected_experts_indices)
+        if self.num_experts % self.num_expert_groups != 0:
+            raise ValueError(
+                f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
+            )
+        experts_per_group = self.num_experts // self.num_expert_groups
+        if experts_per_group < 2:
+            raise ValueError(f"experts_per_group ({experts_per_group}) must be >= 2")
+        scores_grouped = scores_for_choice.view(
+            -1, self.num_expert_groups, experts_per_group
+        )
+        top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
+        group_scores = top2_scores_in_group.sum(dim=-1)
+        _, group_idx = torch.topk(
+            group_scores, k=self.num_limited_groups, dim=-1, sorted=False
+        )
+        group_mask = torch.ones_like(group_scores, dtype=torch.bool)
+        group_mask.scatter_(1, group_idx, False)  # False = selected groups (keep)
+        # Mask out experts from non-selected groups
+        scores_for_choice = scores_grouped.masked_fill(
+            group_mask.unsqueeze(-1), float("-inf")
+        ).view(-1, self.num_experts)

-        return selected_experts_indices, top_scores
+        return scores_for_choice

     def forward(
         self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
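
The refactored helper above now only produces the masked score matrix. As a rough illustration of that masking step, the following standalone sketch (illustrative shapes and a hypothetical node_limited_scores helper, not code from this commit) scores each group by its two strongest experts, keeps the top num_limited_groups groups per token, and fills the remaining experts with -inf:

import torch

def node_limited_scores(
    scores_for_choice: torch.Tensor,  # (num_tokens, num_experts), expert_bias already added
    num_expert_groups: int,
    num_limited_groups: int,
) -> torch.Tensor:
    num_tokens, num_experts = scores_for_choice.shape
    experts_per_group = num_experts // num_expert_groups
    grouped = scores_for_choice.view(num_tokens, num_expert_groups, experts_per_group)
    # Score each group by the sum of its two strongest experts.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    keep_idx = torch.topk(group_scores, k=num_limited_groups, dim=-1).indices
    mask = torch.ones_like(group_scores, dtype=torch.bool)
    mask.scatter_(1, keep_idx, False)  # False marks groups that stay routable
    # Experts in non-selected groups can never win the later top-k.
    return grouped.masked_fill(mask.unsqueeze(-1), float("-inf")).view(
        num_tokens, num_experts
    )

# Example: 4 tokens, 8 experts in 4 groups of 2, routing limited to 2 groups per token.
scores = torch.randn(4, 8).softmax(dim=-1)
masked = node_limited_scores(scores, num_expert_groups=4, num_limited_groups=2)

Each row of masked then has exactly num_limited_groups * experts_per_group finite entries, so a subsequent top-k can only pick experts from the selected groups.
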
@@ -332,11 +309,19 @@ def forward(
         else:
             raise NotImplementedError(f"Unknown score function {self.score_func}")

-        # top scores shape (bs*slen, top_k)
-        selected_experts_indices, top_scores = self._node_limited_routing(
-            scores, expert_bias
+        scores_for_choice = scores if expert_bias is None else scores + expert_bias
+        # Apply node-limited routing if configured
+        if self.num_expert_groups is not None:
+            scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
+        _, selected_experts_indices = torch.topk(
+            scores_for_choice, k=self.top_k, dim=-1, sorted=False
         )

+        # top scores shape (bs*slen, top_k)
+        # NOTE: The expert_bias is only used for routing. The gating value
+        # top_scores is still derived from the original scores.
+        top_scores = scores.gather(dim=1, index=selected_experts_indices)
+
         # debug override: balanced round-robin routing
         if self._debug_force_load_balance:
             (
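
Putting the pieces of the reworked forward() path together: the bias is added only for selection, the group mask is applied when configured, top_k indices are taken from the masked scores, and the gate values come from the unbiased scores. A minimal free-standing sketch of that flow (hypothetical, reusing the illustrative node_limited_scores helper from the earlier sketch rather than the module's attributes) might look like:

import torch

def route(
    scores: torch.Tensor,              # (num_tokens, num_experts), post sigmoid/softmax
    expert_bias: torch.Tensor | None,  # (num_experts,) load-balancing bias, or None
    top_k: int,
    num_expert_groups: int | None,
    num_limited_groups: int | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # The bias only influences which experts are chosen, never the gate values.
    scores_for_choice = scores if expert_bias is None else scores + expert_bias
    if num_expert_groups is not None:
        scores_for_choice = node_limited_scores(
            scores_for_choice, num_expert_groups, num_limited_groups
        )
    selected = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False).indices
    top_scores = scores.gather(dim=1, index=selected)  # gate from the unbiased scores
    return selected, top_scores

selected, gates = route(
    torch.randn(4, 8).sigmoid(), torch.zeros(8),
    top_k=2, num_expert_groups=4, num_limited_groups=2,
)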