2 changes: 1 addition & 1 deletion examples/router_replay/README.md
@@ -68,4 +68,4 @@ actor_rollout_ref.actor.router_replay.mode="R3"
actor_rollout_ref.rollout.enable_rollout_routing_replay=True
```

R3 mode requires the rollout backend to support returning router selection results. Currently, this functionality is being tested based on the vllm implementation at https://github.com/vllm-project/vllm/pull/28284 and SGLang implementation at https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051.
R3 mode requires the rollout backend to support returning router selection results. Currently, this functionality is being tested against the vLLM implementation at https://github.com/vllm-project/vllm/pull/28284 (together with the bug fix at https://github.com/vllm-project/vllm/pull/33013) and the SGLang implementation at https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051.
4 changes: 3 additions & 1 deletion examples/router_replay/run_qwen30_a3b_megatron_vllm.sh
@@ -6,7 +6,9 @@ NODES=1
# R2: enable routing replay
# R3: enable rollout routing replay
# If enabling R3, please set actor_rollout_ref.rollout.enable_rollout_routing_replay=True
# R3 example is based on vllm related pr https://github.com/vllm-project/vllm/pull/5322
# The R3 example is based on the following vLLM PRs:
# - https://github.com/vllm-project/vllm/pull/28284
# - https://github.com/vllm-project/vllm/pull/33013

ROUTING_REPLAY_MODE="R2"
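
# --- Illustrative sketch (not part of this script's diff) ---------------------
# How the settings above could be forwarded to the trainer. The actual launch
# command is elided from this excerpt; the entry point, the ENABLE_R3 helper
# variable, and the "$@" passthrough below are assumptions for illustration.
# The override names come from examples/router_replay/README.md.
if [ "${ROUTING_REPLAY_MODE}" = "R3" ]; then
    ENABLE_R3=True
else
    ENABLE_R3=False
fi

python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.router_replay.mode="${ROUTING_REPLAY_MODE}" \
    actor_rollout_ref.rollout.enable_rollout_routing_replay="${ENABLE_R3}" \
    "$@"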

8 changes: 8 additions & 0 deletions verl/utils/megatron/router_replay_patch.py
@@ -84,6 +84,7 @@ def __init__(self):
self.recorded_topk_idx = None # For recording
self.router_replay_action = None # Router replay action for this layer
self.replay_backward_list = [] # List of tensors for backward pass replay
self.layer_number = None # Global layer number (1-based), set via set_layer_number if available
RouterReplay.router_instances.append(self)

def set_target_indices(self, topk_indices: torch.Tensor):
@@ -336,6 +337,12 @@ def patched_tf_config_init(self, *args, **kwargs):
return

original_init = TopKRouter.__init__
original_set_layer_number = TopKRouter.set_layer_number

# Record the global (1-based) layer number on the attached RouterReplay instance
# so replay data indexed by global layer can be matched to this router.
def patched_set_layer_number(self, layer_number: int):
    original_set_layer_number(self, layer_number)
    if self.router_replay is not None:
        self.router_replay.layer_number = layer_number

# Step 3: Define the new __init__ method
def patched_init(self, *args, **kwargs):
@@ -374,4 +381,5 @@ def patched_preprocess(self, routing_map):
# Step 5: Apply the patches
TopKRouter.__init__ = patched_init
TopKRouter.routing = patched_routing
TopKRouter.set_layer_number = patched_set_layer_number
TopKRouter._router_replay_patched = True
15 changes: 14 additions & 1 deletion verl/utils/megatron/router_replay_utils.py
@@ -242,11 +242,24 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=N
layers_topk_idx_reshape = layers_topk_idx_rmpad_split.permute(0, 2, 1, 3).squeeze(
    dim=0
)  # layer_num, dynamic_bs_all, topk
num_layers_in_data = layers_topk_idx_reshape.shape[0]
# Replay data is indexed by global layer number only when it covers every model layer.
use_global_layer_index = getattr(tf_config, "num_layers", None) == num_layers_in_data
local_rank_info = get_current_rank_layer_info(tf_config, vp_rank)
offset, _ = local_rank_info["start"], local_rank_info["end"]
router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank)
for i, router in enumerate(router_instances_list):
    router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64))
    # Prefer the router's recorded global layer number (1-based); fall back to its
    # local position plus the pipeline offset when no global number is available.
    layer_idx = None
    if use_global_layer_index:
        layer_number = getattr(router, "layer_number", None)
        if layer_number is not None:
            layer_idx = layer_number - 1
    if layer_idx is None:
        layer_idx = i + offset
    if layer_idx < 0 or layer_idx >= num_layers_in_data:
        raise ValueError(
            f"router replay layer index {layer_idx} out of range for data with {num_layers_in_data} layers"
        )
    router.set_target_indices(layers_topk_idx_reshape[layer_idx].to(torch.int64))
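
# --- Illustrative sketch (not part of the patch) ------------------------------
# The index selection above, restated as a standalone helper for clarity. The
# function name and signature below are hypothetical, for illustration only.
def _resolve_replay_layer_idx(local_pos, offset, layer_number, use_global_layer_index, num_layers_in_data):
    if use_global_layer_index and layer_number is not None:
        layer_idx = layer_number - 1  # recorded layer numbers are 1-based
    else:
        layer_idx = local_pos + offset  # local router position plus pipeline-stage offset
    if not 0 <= layer_idx < num_layers_in_data:
        raise ValueError(f"router replay layer index {layer_idx} out of range for data with {num_layers_in_data} layers")
    return layer_idx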


def reorder_and_merge_vpp_layers(