Draft
84 commits
da13813
Move padding
JacobHelwig Jan 17, 2026
0157dca
Import and doc string
JacobHelwig Jan 17, 2026
843eb35
Max response len
JacobHelwig Jan 17, 2026
1650b3f
Prompt lens
JacobHelwig Jan 17, 2026
07510d2
Tests
JacobHelwig Jan 15, 2026
0ccefe2
Multiple 3D tensors test
JacobHelwig Jan 15, 2026
50967f6
Merge tests
JacobHelwig Jan 15, 2026
9af25b4
init
JacobHelwig Jan 12, 2026
d352d0e
Init debug script
JacobHelwig Jan 12, 2026
22eab0a
Plan
JacobHelwig Jan 13, 2026
fb61894
Add top-k log probs
JacobHelwig Jan 14, 2026
89f1590
Stage-wise top-k
JacobHelwig Jan 15, 2026
c3f980e
Distillation cfg
JacobHelwig Jan 16, 2026
d14002a
Distillation losses
JacobHelwig Jan 16, 2026
b01422d
Re-factor distillation
JacobHelwig Jan 16, 2026
b10f879
RM unused
JacobHelwig Jan 17, 2026
204a7e6
Working: Add full JSD/KL, rm FSDP cfg, routing by distillation loss type
JacobHelwig Jan 17, 2026
13d5dca
rm example
JacobHelwig Jan 17, 2026
d3a2e10
dp distillation cfg
JacobHelwig Jan 17, 2026
1fff408
Fix clamping
JacobHelwig Jan 17, 2026
5bff8be
Enable ref w distillation
JacobHelwig Jan 18, 2026
237e277
Teacher model cfg
JacobHelwig Jan 18, 2026
1b4e9ac
Legacy distillation
JacobHelwig Jan 18, 2026
1efb975
Distillation cfg to actor_rollout_ref
JacobHelwig Jan 18, 2026
f681201
Clamping and pass distillation cfg instead of actor
JacobHelwig Jan 18, 2026
9093570
Ruff
JacobHelwig Jan 18, 2026
dec2233
Distillation training script
JacobHelwig Jan 18, 2026
900d0b7
Update train script
JacobHelwig Jan 18, 2026
e556c56
Decouple distillation and ref configs
JacobHelwig Jan 18, 2026
6fa8506
Distillation cfg validation
JacobHelwig Jan 18, 2026
fb1ee0e
Loss settings in cfg
JacobHelwig Jan 18, 2026
28f9258
Use estimator
JacobHelwig Jan 18, 2026
50388a5
Distillation_config->distillation
JacobHelwig Jan 18, 2026
787c268
Distillatoin loss signatures
JacobHelwig Jan 18, 2026
2326e98
Running w no rm pad
JacobHelwig Jan 18, 2026
673f2fa
Ulysses working
JacobHelwig Jan 18, 2026
52592f9
SP in script
JacobHelwig Jan 18, 2026
efe4377
Log eps
JacobHelwig Jan 18, 2026
d0b20c3
Add distillation tests
JacobHelwig Jan 18, 2026
ce1329a
Reduce tests
JacobHelwig Jan 18, 2026
843509b
Doc strings
JacobHelwig Jan 18, 2026
759320f
Ruff
JacobHelwig Jan 18, 2026
3652bd6
Not implemented
JacobHelwig Jan 18, 2026
6b297ad
Clamp
JacobHelwig Jan 19, 2026
6ea8b2b
Divergence note
JacobHelwig Jan 19, 2026
ba3aa4c
take abs and distillation loss inputs
JacobHelwig Jan 19, 2026
2930308
Take abs in name
JacobHelwig Jan 19, 2026
46761bb
loss fix
JacobHelwig Jan 19, 2026
7294716
null
JacobHelwig Jan 19, 2026
e4bb2c5
Take abs fix
JacobHelwig Jan 19, 2026
59bd0f0
RM blank lines
JacobHelwig Jan 19, 2026
5d0ff5b
Long line and generate cfg
JacobHelwig Jan 19, 2026
ccbb86d
rl dataset
JacobHelwig Jan 19, 2026
8ee18e4
rm TODO
JacobHelwig Jan 19, 2026
e2ee7e5
CI fixes
JacobHelwig Jan 19, 2026
7fe1381
PC
JacobHelwig Jan 19, 2026
e28e562
None cfg
JacobHelwig Jan 19, 2026
82b608a
Distillation
JacobHelwig Jan 19, 2026
3729565
cfg to test
JacobHelwig Jan 19, 2026
5197059
None config
JacobHelwig Jan 19, 2026
926f4c7
Dist cfg
JacobHelwig Jan 19, 2026
7bd50ee
Distillation enabled
JacobHelwig Jan 19, 2026
118f058
Dist cfg test eng
JacobHelwig Jan 19, 2026
cf015a9
Update generated cfg
JacobHelwig Jan 19, 2026
c57bfdb
Entropy bonus
JacobHelwig Jan 19, 2026
627fc25
Types
JacobHelwig Jan 19, 2026
af68a78
Ruff
JacobHelwig Jan 26, 2026
802ea53
Generated PPO trainer
JacobHelwig Jan 26, 2026
74f728e
RM take abs
JacobHelwig Jan 26, 2026
85958e0
Only teacher topk in utils
JacobHelwig Jan 26, 2026
846a836
Re-factor stages
JacobHelwig Jan 26, 2026
378ce7f
FSDP losses
JacobHelwig Jan 27, 2026
58e1e28
Update training config
JacobHelwig Jan 27, 2026
16493c1
Ruff
JacobHelwig Jan 27, 2026
c170898
FSDP utils
JacobHelwig Jan 27, 2026
54f782f
Fix teacher logprobs for top-k
JacobHelwig Jan 27, 2026
b67ea55
Doc string
JacobHelwig Jan 27, 2026
de16be6
Distllation loss range
JacobHelwig Jan 27, 2026
4c5fdfa
Mask bug fix
JacobHelwig Jan 27, 2026
4e4f07a
Disable clamp
JacobHelwig Jan 27, 2026
1fbf2ec
Clamp distillation loss to 10
JacobHelwig Jan 27, 2026
13bc855
Clamp probs
JacobHelwig Jan 27, 2026
7a2d848
log prob clamping
JacobHelwig Jan 28, 2026
d0cad55
Clipping defaults
JacobHelwig Jan 28, 2026
139 changes: 139 additions & 0 deletions examples/on_policy_distillation_trainer/run_qwen_gsmk8k.sh
@@ -0,0 +1,139 @@
#!/usr/bin/env bash
eval "$(conda shell.bash hook)"
conda activate verl
export PATH=$CONDA_PREFIX/bin:$PATH
export NCCL_P2P_DISABLE=1
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export CUDA_VISIBLE_DEVICES=6,7
export DATA_PATH=$PWD/../verlData
export HF_HOME=$DATA_PATH
export VLLM_CACHE_DIR=$DATA_PATH/vllm_cache

set -xeuo pipefail

############################ Quick Config ############################

ROLLOUT_NAME="vllm" # sglang or vllm

FAMILY="Qwen"
STUDENT_MODEL=Qwen2.5-0.5B
TEACHER_MODEL=Qwen2.5-3B-Instruct

# DISTILLATION_LOSS_MODE="k3"
DISTILLATION_LOSS_MODE="forward_kl_topk"

DISTILLATION_LOSS_MAX_CLAMP=null
DISTILLATION_LOG_PROB_MIN_CLAMP=-10.0

PROJECT_NAME='verl_on_policy_distillation_example_gsm8k'
EXP_NAME="${FAMILY}/student-${STUDENT_MODEL}/teacher-${TEACHER_MODEL}/loss-${DISTILLATION_LOSS_MODE}-maxclamp-${DISTILLATION_LOSS_MAX_CLAMP}-logprobminclamp-${DISTILLATION_LOG_PROB_MIN_CLAMP}"

MAX_PROMPT=256
MAX_RESPONSE_LENGTH=512
TRAIN_PROMPT_BSZ=128
STUDENT_MICRO_BATCH_SIZE_PER_GPU=2
STUDENT_MAX_TOKEN_LEN_PER_GPU=$(( STUDENT_MICRO_BATCH_SIZE_PER_GPU * (MAX_PROMPT + MAX_RESPONSE_LENGTH) ))
TEACHER_MICRO_BATCH_SIZE_PER_GPU=2
TEACHER_MAX_TOKEN_LEN_PER_GPU=$(( TEACHER_MICRO_BATCH_SIZE_PER_GPU * (MAX_PROMPT + MAX_RESPONSE_LENGTH) ))

WORLD_SIZE=2
SP_SIZE=1

############################ Paths ############################

gsm8k_train_path=$DATA_PATH/gsm8k/train.parquet
gsm8k_test_path=$DATA_PATH/gsm8k/test.parquet

TRAIN_FILES="['$gsm8k_train_path']"
TEST_FILES="['$gsm8k_test_path']"

############################ Parameter Groups ############################

DATA=(
data.train_files="$TRAIN_FILES"
data.val_files="$TEST_FILES"
data.max_prompt_length=$MAX_PROMPT
data.max_response_length=$MAX_RESPONSE_LENGTH
data.train_batch_size=$TRAIN_PROMPT_BSZ
data.filter_overlong_prompts=True
data.truncation='error'
data.shuffle=False
)

MODEL=(
actor_rollout_ref.model.path="${FAMILY}/${STUDENT_MODEL}"
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.model.use_remove_padding=True
)

DISTILLATION=(
actor_rollout_ref.distillation.enabled=True
actor_rollout_ref.distillation.loss_mode=$DISTILLATION_LOSS_MODE
actor_rollout_ref.distillation.jsd_beta=0.5
actor_rollout_ref.distillation.topk=64
actor_rollout_ref.distillation.loss_max_clamp=$DISTILLATION_LOSS_MAX_CLAMP
actor_rollout_ref.distillation.log_prob_min_clamp=$DISTILLATION_LOG_PROB_MIN_CLAMP
actor_rollout_ref.distillation.log_prob_use_dynamic_bsz=True
actor_rollout_ref.distillation.log_prob_micro_batch_size_per_gpu=$TEACHER_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.distillation.log_prob_max_token_len_per_gpu=$TEACHER_MAX_TOKEN_LEN_PER_GPU
actor_rollout_ref.distillation.fsdp_config.param_offload=True
actor_rollout_ref.distillation.teacher_model.path="${FAMILY}/${TEACHER_MODEL}"
actor_rollout_ref.distillation.teacher_model.use_remove_padding=True
actor_rollout_ref.distillation.ulysses_sequence_parallel_size=$SP_SIZE
)

ACTOR=(
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_PROMPT_BSZ
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
actor_rollout_ref.actor.use_dynamic_bsz=True
actor_rollout_ref.actor.fsdp_config.param_offload=True
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$SP_SIZE
)

ROLLOUT=(
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True
actor_rollout_ref.rollout.tensor_model_parallel_size=1
actor_rollout_ref.rollout.name=$ROLLOUT_NAME
actor_rollout_ref.rollout.gpu_memory_utilization=0.3
actor_rollout_ref.rollout.n=1
)

ALGORITHM=(
algorithm.adv_estimator=grpo
algorithm.use_kl_in_reward=False
)

TRAINER=(
trainer.logger='["console","wandb"]'
trainer.project_name=$PROJECT_NAME
trainer.experiment_name=$EXP_NAME
trainer.n_gpus_per_node=$WORLD_SIZE
trainer.nnodes=1
trainer.save_freq=200
trainer.test_freq=5
trainer.total_epochs=15
trainer.val_before_train=True
trainer.use_legacy_worker_impl=disable
trainer.resume_mode=disable
)



############################ Launch ############################

python3 -m verl.trainer.main_ppo \
--config-path=config \
--config-name='ppo_trainer.yaml' \
"${DATA[@]}" \
"${ALGORITHM[@]}" \
"${MODEL[@]}" \
"${DISTILLATION[@]}" \
"${ROLLOUT[@]}" \
"${ACTOR[@]}" \
"${TRAINER[@]}" \
"$@"
12 changes: 11 additions & 1 deletion tests/models/test_engine.py
@@ -42,6 +42,7 @@
from verl.workers.config import (
ActorConfig,
CriticConfig,
DistillationConfig,
FSDPEngineConfig,
FSDPOptimizerConfig,
HFModelConfig,
@@ -100,12 +101,15 @@ def create_training_config(model_type, strategy, device_count, model):
else:
raise NotImplementedError(f"strategy {strategy} is not supported")

distillation_config = DistillationConfig(strategy=strategy, rollout_n=-1, ppo_micro_batch_size_per_gpu=-1)

config = TrainingWorkerConfig(
model_type=model_type,
model_config=model_config,
engine_config=engine_config,
optimizer_config=optimizer_config,
checkpoint_config=None,
distillation_config=distillation_config,
)
return config

@@ -204,8 +208,11 @@ def test_actor_engine(strategy):
# construct actor config
actor_config = ActorConfig(strategy=strategy, rollout_n=1, ppo_micro_batch_size_per_gpu=-1)

# construct distillation config
distillation_config = DistillationConfig(strategy=strategy, rollout_n=-1, ppo_micro_batch_size_per_gpu=-1)

# set ppo loss
ppo_loss_ = partial(ppo_loss, config=actor_config)
ppo_loss_ = partial(ppo_loss, config=actor_config, distillation_config=distillation_config)
wg.set_loss_fn(ppo_loss_)

# update again
@@ -395,6 +402,8 @@ def _worker(rank: int, world_size: int, rendezvous_file: str, strategy: str, mod

checkpoint_config = CheckpointConfig()

distillation_config = DistillationConfig(strategy=strategy, rollout_n=-1, ppo_micro_batch_size_per_gpu=-1)

# build model engine
engine: BaseEngine = EngineRegistry.new(
model_type="language_model",
@@ -403,6 +412,7 @@
engine_config=engine_config,
optimizer_config=optimizer_config,
checkpoint_config=checkpoint_config,
distillation_config=distillation_config,
)

engine.initialize()
109 changes: 109 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -329,6 +329,115 @@ actor_rollout_ref:
speculative_num_draft_tokens: 4
method: mtp
num_speculative_tokens: 1
distillation:
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
strategy: ${actor_rollout_ref.actor.strategy}
use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
log_prob_micro_batch_size: null
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
profiler:
_target_: verl.utils.profiler.ProfilerConfig
tool: ${oc.select:global_profiler.tool,null}
enable: false
all_ranks: false
ranks: []
save_path: ${oc.select:global_profiler.save_path,null}
tool_config:
nsys:
_target_: verl.utils.profiler.config.NsightToolConfig
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
npu:
_target_: verl.utils.profiler.config.NPUToolConfig
contents: []
level: level0
analysis: true
discrete: false
torch:
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
contents: []
discrete: false
torch_memory:
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
router_replay:
_target_: verl.workers.config.RouterReplayConfig
mode: disabled
record_file: null
replay_file: null
teacher_model:
_target_: verl.workers.config.HFModelConfig
path: ~/models/deepseek-llm-7b-chat
hf_config_path: null
tokenizer_path: null
use_shm: false
trust_remote_code: false
custom_chat_template: null
external_lib: null
override_config: {}
enable_gradient_checkpointing: true
enable_activation_offload: false
use_remove_padding: true
lora_rank: 0
lora_alpha: 16
target_modules: all-linear
exclude_modules: null
lora_adapter_path: null
use_liger: false
use_fused_kernels: false
fused_kernel_options:
impl_backend: torch
tiled_mlp:
enabled: false
num_shards: 4
mtp:
_target_: verl.workers.config.MtpConfig
enable: false
enable_train: false
enable_rollout: false
detach_encoder: false
mtp_loss_scaling_factor: 0.1
speculative_algorithm: EAGLE
speculative_num_steps: 2
speculative_eagle_topk: 2
speculative_num_draft_tokens: 4
method: mtp
num_speculative_tokens: 1
_target_: verl.workers.config.FSDPDistillationConfig
enabled: false
loss_mode: k3
topk: 32
use_policy_loss: false
distillation_loss_coef: 1.0
jsd_beta: 0.5
loss_max_clamp: null
log_prob_min_clamp: null
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
min_num_params: 0
param_offload: false
optimizer_offload: false
offload_policy: false
reshard_after_forward: true
fsdp_size: -1
forward_prefetch: false
model_dtype: fp32
use_orig_params: false
seed: 42
full_determinism: false
ulysses_sequence_parallel_size: 1
entropy_from_logits_with_chunking: false
use_torch_compile: true
entropy_checkpointing: false
forward_only: true
strategy: fsdp
dtype: bfloat16
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
entropy_from_logits_with_chunking: false
entropy_checkpointing: false
hybrid_engine: true
nccl_timeout: 600
data:
40 changes: 40 additions & 0 deletions verl/trainer/config/distillation/distillation.yaml
@@ -0,0 +1,40 @@
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.DistillationConfig

# specify the default per-component configs
defaults:

# distillation config, inheriting from trainer/config/ref/ref.yaml
# Use absolute path + @_here_ to flatten fields at current level
- /ref/ref@_here_

# Model config for teacher
- ../model@teacher_model: hf_model

# load the default config, then apply the fields in the current yaml
# self config override anything above
- _self_

# Whether to enable distillation.
enabled: false

# Loss function mode for distillation
loss_mode: k3

# If distillation loss requires top-k logits, this is the value of k
topk: 32

# Whether to use only the distillation loss (false) or a weighted combination of the distillation loss and the policy loss (true)
use_policy_loss: false

# Coefficient of the distillation loss when use_policy_loss is true
distillation_loss_coef: 1.0

# Jensen-Shannon Divergence beta for jsd loss mode.
jsd_beta: 0.5

# Optional max clamp value for distillation loss. If null, no clamping is applied.
loss_max_clamp: null

# Optional min clamp value for log probabilities, for numerical stability in terms such as log q - log p where p or q is very close to zero. If null, no clamping is applied.
log_prob_min_clamp: null
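
The jsd_beta field above only matters for the jsd loss mode. One common parameterization of the generalized Jensen-Shannon divergence used for distillation is sketched below; whether this PR follows exactly this convention (in particular the direction of the mixture weighting) is an assumption, not something shown in the diff.

# Hedged sketch of a generalized JSD; not necessarily the convention used in this PR.
#   M = beta * P_teacher + (1 - beta) * P_student
#   JSD_beta = beta * KL(P_teacher || M) + (1 - beta) * KL(P_student || M)
import torch
import torch.nn.functional as F


def generalized_jsd(teacher_logp, student_logp, beta=0.5):
    """Per-position generalized JSD over the vocabulary dimension (inputs are log-probs)."""
    mixture = beta * teacher_logp.exp() + (1.0 - beta) * student_logp.exp()
    log_m = mixture.clamp(min=1e-12).log()
    kl_teacher_m = F.kl_div(log_m, teacher_logp, log_target=True, reduction="none").sum(-1)
    kl_student_m = F.kl_div(log_m, student_logp, log_target=True, reduction="none").sum(-1)
    return beta * kl_teacher_m + (1.0 - beta) * kl_student_m

At beta=0.5, the default above, this is exactly the symmetric Jensen-Shannon divergence.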
30 changes: 30 additions & 0 deletions verl/trainer/config/distillation/dp_distillation.yaml
@@ -0,0 +1,30 @@
# defaults specify the default config from each component
defaults:

# dp distillation config, inheriting from trainer/config/distillation/distillation.yaml
- distillation

# fsdp engine config
- ../engine@fsdp_config: fsdp

# load the default config, then apply the fields in the current yaml
- _self_

# Target class for this configuration
_target_: verl.workers.config.FSDPDistillationConfig

# fsdp config
fsdp_config:

# distillation model is forward only
forward_only: True

# sequence parallel size
# same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}

# calculate entropy with chunking to reduce memory peak
entropy_from_logits_with_chunking: False

# recompute entropy
entropy_checkpointing: False
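
Because fsdp_config.forward_only is True, the teacher engine only ever runs inference. Below is a hedged sketch of the kind of forward pass it performs, producing top-k log-probs for the student to match; the helper name and the HF-style model call are assumptions for illustration, not the worker API defined elsewhere in this PR.

# Illustrative only; the actual teacher worker/engine API lives in the PR's Python changes.
import torch


@torch.no_grad()
def compute_teacher_topk_logprobs(teacher_model, input_ids, attention_mask, topk=32):
    """Forward-only teacher pass returning top-k log-probs and token ids per position."""
    logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
    logp = torch.log_softmax(logits.float(), dim=-1)     # (batch, seq_len, vocab)
    topk_logp, topk_ids = logp.topk(topk, dim=-1)        # teacher's top-k per position
    return topk_logp, topk_ids                           # student log-probs are later gathered at topk_ids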
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -36,6 +36,9 @@ defaults:
# Rollout correction config.
- algorithm@algorithm.rollout_correction: rollout_correction

# distillation config
- distillation@actor_rollout_ref.distillation: dp_distillation

# load the reference default config, then apply the fields in the current yaml
# self config override anything above
- _self_