[model] feat: add NPU GRPO training scripts for Qwen2.5-32B/Qwen3-30B (Megatron/vLLM backends) #4984
First file (new file, 185 lines added): GRPO training script for Qwen2.5-32B on NPU, using the Megatron training backend with vLLM rollout.
#!/bin/bash
set -xeuo pipefail
mkdir -p logs

# Project Configuration
project_name='GRPO-Qwen2.5-32B-BASE-MATH'
exp_name='GRPO-Qwen2.5-32B-BASE-Megatron-vLLM'

# Node Info
NNODES=${NNODES:-1}
NPUS_PER_NODE=${NPUS_PER_NODE:-16}

# Model Weights Paths
MODEL_PATH=Qwen/Qwen2.5-32B
MCORE_MODEL_PATH=Qwen/Qwen2.5-32B-dist
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}

# File System Paths
TRAIN_FILE=$RAY_DATA_HOME/dataset/gsm8k/train.parquet
TEST_FILE=$RAY_DATA_HOME/dataset/gsm8k/test.parquet
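
Note: the train/test parquet files referenced above are the GSM8K preprocessing output expected by verl. As a hedged pointer (the helper path and flag may differ across verl versions), they can typically be produced with something like:

# Hypothetical data preparation step, assuming verl's bundled GSM8K preprocessor
python3 examples/data_preprocess/gsm8k.py --local_dir "${HOME}/verl/dataset/gsm8k"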

# Data Configuration
max_prompt_length=$((1024 * 1))
max_response_length=$((1024 * 1))

# Training Batch Configuration
train_prompt_bsz=128
train_prompt_mini_bsz=32
n_resp_per_prompt=16
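
A quick back-of-the-envelope on what these defaults imply, assuming the usual verl convention that the batch sizes count prompts and that rollout.n responses are sampled per prompt (illustrative only, not part of the script):

# Derived quantities for the defaults above
prompts_per_step=128                  # train_prompt_bsz
responses_per_step=$((128 * 16))      # 128 prompts x 16 responses = 2048 sequences scored per step
minibatches_per_step=$((128 / 32))    # train_prompt_bsz / train_prompt_mini_bsz = 4 gradient mini-batches
echo "${prompts_per_step} prompts -> ${responses_per_step} responses -> ${minibatches_per_step} mini-batches per step"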

# Algorithm Configuration
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=True
kl_loss_coef=0.001

# Performance and Memory Management Configuration
all_offload=True
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 8))
optimizer_offload_fraction=1
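
With both sequence-length caps at 1024, the two multipliers translate into per-NPU token budgets for dynamic batching: roughly two full-length sequences per actor-update micro-batch and eight per log-prob/inference micro-batch. A sketch of the arithmetic under the defaults above (illustrative only):

# Token budgets implied by the defaults
echo "actor update budget: $(((1024 + 1024) * 2)) tokens per device"        # 4096
echo "inference/log-prob budget: $(((1024 + 1024) * 8)) tokens per device"  # 16384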

# Megatron Configuration
train_tp=4
train_ep=1
train_etp=1
train_pp=4
train_cp=1

# vLLM Configuration
gen_tp=2
gen_dp=1
gen_ep=1
gpu_memory_utilization=0.8
max_model_len=$((max_prompt_length + max_response_length))
max_num_batched_tokens=$(((max_prompt_length + max_response_length) * 1))
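
For the dense 32B model, tp=4 x pp=4 occupies all 16 NPUs of a single node for one training replica (data-parallel size 1), while rollout uses tp=2 per vLLM engine, which typically corresponds to eight rollout groups per node. A hedged sanity-check sketch for adapting these sizes to other node counts (not part of the script):

# Illustrative layout check, assuming NNODES=1 and NPUS_PER_NODE=16
world_size=$((1 * 16))
train_mp=$((4 * 4 * 1))   # train_tp * train_pp * train_cp
(( world_size % train_mp == 0 )) || echo "warning: tp*pp*cp does not divide world size" >&2
echo "training data-parallel size: $((world_size / train_mp))"   # 1
echo "rollout groups at gen_tp=2: $((world_size / 2))"           # 8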

# Data Configuration
DATA_CONFIG=(
    data.train_files="${TRAIN_FILE}"
    data.val_files="${TEST_FILE}"
    data.prompt_key=prompt
    data.train_batch_size=${train_prompt_bsz}
    data.max_prompt_length=${max_prompt_length}
    data.max_response_length=${max_response_length}
    data.filter_overlong_prompts=False
    data.truncation='left'
)

# Model Configuration
MODEL_CONFIG=(
    actor_rollout_ref.model.path="${MODEL_PATH}"
    actor_rollout_ref.model.use_remove_padding=True
)

# Algorithm Configuration
ALGORITHM_CONFIG=(
    algorithm.adv_estimator=${adv_estimator}
    algorithm.use_kl_in_reward=${use_kl_in_reward}
    algorithm.kl_ctrl.kl_coef=${kl_coef}
)

# Actor Model Configuration
ACTOR_CONFIG=(
    actor_rollout_ref.actor.use_torch_compile=False
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.ppo_epochs=1
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.optim.lr=1e-6
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction}
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp}
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp}
    actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp}
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep}
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${train_etp}
    actor_rollout_ref.actor.megatron.param_offload=${all_offload}
    actor_rollout_ref.actor.megatron.optimizer_offload=${all_offload}
    actor_rollout_ref.actor.megatron.grad_offload=${all_offload}
    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH}
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
    +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True

Contributor review on the use_flash_attn line above: This line has a '+' prefix, which is invalid syntax for a bash array assignment. This appears to be a copy-paste error from a diff and will cause the script to fail. Suggested change: remove the '+' prefix.
)

# Reference Model Configuration
REF_CONFIG=(
    actor_rollout_ref.ref.use_torch_compile=False
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp}
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp}
    actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp}
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep}
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${train_etp}
    actor_rollout_ref.ref.megatron.param_offload=${all_offload}
    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH}
    actor_rollout_ref.ref.megatron.use_dist_checkpointing=False
)

# Rollout Configuration
ROLLOUT_CONFIG=(
    actor_rollout_ref.rollout.name=vllm
    actor_rollout_ref.rollout.n=${n_resp_per_prompt}
    actor_rollout_ref.rollout.top_p=1.0
    actor_rollout_ref.rollout.top_k=-1
    actor_rollout_ref.rollout.temperature=1.0
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization}
    actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens}
    actor_rollout_ref.rollout.max_model_len=${max_model_len}
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp}
    actor_rollout_ref.rollout.data_parallel_size=${gen_dp}
    actor_rollout_ref.rollout.expert_parallel_size=${gen_ep}
    actor_rollout_ref.rollout.enable_chunked_prefill=True
    actor_rollout_ref.rollout.enable_prefix_caching=True
    actor_rollout_ref.rollout.enforce_eager=True
    actor_rollout_ref.rollout.free_cache_engine=True
    actor_rollout_ref.rollout.val_kwargs.n=1
    actor_rollout_ref.rollout.val_kwargs.do_sample=True
    actor_rollout_ref.rollout.val_kwargs.top_p=1.0
    actor_rollout_ref.rollout.val_kwargs.top_k=-1
    actor_rollout_ref.rollout.val_kwargs.temperature=1.0
)

# Trainer Configuration
TRAINER_CONFIG=(
    trainer.logger='["console","tensorboard"]'
    trainer.project_name="${project_name}"
    trainer.experiment_name="${exp_name}"
    trainer.nnodes="${NNODES}"
    trainer.n_gpus_per_node="${NPUS_PER_NODE}"
    trainer.device='npu'
    trainer.total_epochs=15
    trainer.val_before_train=False
    trainer.test_freq=-1
    trainer.save_freq=-1
    trainer.default_local_dir="${CKPTS_DIR}"
)

# Main GRPO Training Command
python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA_CONFIG[@]}" \
    "${MODEL_CONFIG[@]}" \
    "${ACTOR_CONFIG[@]}" \
    "${REF_CONFIG[@]}" \
    "${ROLLOUT_CONFIG[@]}" \
    "${ALGORITHM_CONFIG[@]}" \
    "${TRAINER_CONFIG[@]}" \
    "$@" | tee logs/run_qwen2_5-32b_grpo_megatron_vllm_npu.log

Second file (new file, 188 lines added): GRPO training script for Qwen3-30B-A3B on NPU, using the Megatron training backend with vLLM rollout.
#!/bin/bash
set -xeuo pipefail
mkdir -p logs

# Project Configuration
project_name='GRPO-Qwen3-30b-A3B-BASE-MATH'
exp_name='GRPO-Qwen3-30B-A3B-BASE-Megatron-vLLM'

# Node Info
NNODES=${NNODES:-1}
NPUS_PER_NODE=${NPUS_PER_NODE:-16}

# Model Weights Paths
MODEL_PATH=Qwen/Qwen3-30B-A3B-Base
MCORE_MODEL_PATH=Qwen/Qwen3-30B-A3B-Base-dist
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}

# File System Paths
TRAIN_FILE=$RAY_DATA_HOME/dataset/gsm8k/train.parquet
TEST_FILE=$RAY_DATA_HOME/dataset/gsm8k/test.parquet

# Data Configuration
max_prompt_length=$((1024 * 1))
max_response_length=$((1024 * 1))

# Training Batch Configuration
train_prompt_bsz=128
train_prompt_mini_bsz=32
n_resp_per_prompt=16

# Algorithm Configuration
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=True
kl_loss_coef=0.001

# Performance and Memory Management Configuration
all_offload=True
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 4))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 8))
optimizer_offload_fraction=1

# Megatron Configuration
train_tp=2
train_ep=8
train_etp=1
train_pp=2
train_cp=1
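
Because Qwen3-30B-A3B is a mixture-of-experts model, this script differs from the dense 32B one mainly in the parallel layout: smaller tp/pp plus expert parallelism (train_ep=8). Roughly, with 16 NPUs the dense layers are sharded tp=2 x pp=2, leaving four data-parallel groups, while expert layers are additionally spread across eight expert-parallel ranks. The arithmetic below is illustrative only, and the exact expert-parallel constraints depend on the Megatron-Core version:

# Rough layout arithmetic for the MoE defaults above (not part of the script)
world_size=$((1 * 16))        # NNODES * NPUS_PER_NODE
dense_mp=$((2 * 2 * 1))       # train_tp * train_pp * train_cp
dp=$((world_size / dense_mp)) # 4 data-parallel groups for dense layers
echo "dense layers: tp=2 pp=2 -> dp=${dp}"
# Expert layers: ep=8 with etp=1; with tp*dp = 8 ranks per pipeline stage, the 8 expert-parallel ranks fit exactly.
echo "expert layers: ep=8 etp=1 over $((dp * 2)) ranks per pipeline stage"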

# vLLM Configuration
gen_tp=2
gen_dp=1
gen_ep=1
gpu_memory_utilization=0.8
max_model_len=$((max_prompt_length + max_response_length))
max_num_batched_tokens=$(((max_prompt_length + max_response_length) * 1))

# Data Configuration
DATA_CONFIG=(
    data.train_files="${TRAIN_FILE}"
    data.val_files="${TEST_FILE}"
    data.prompt_key=prompt
    data.train_batch_size=${train_prompt_bsz}
    data.max_prompt_length=${max_prompt_length}
    data.max_response_length=${max_response_length}
    data.filter_overlong_prompts=False
    data.truncation='left'
)

# Model Configuration
MODEL_CONFIG=(
    actor_rollout_ref.model.path="${MODEL_PATH}"
    actor_rollout_ref.model.use_remove_padding=True
)

# Algorithm Configuration
ALGORITHM_CONFIG=(
    algorithm.adv_estimator=${adv_estimator}
    algorithm.use_kl_in_reward=${use_kl_in_reward}
    algorithm.kl_ctrl.kl_coef=${kl_coef}
)

# Actor Model Configuration
ACTOR_CONFIG=(
    actor_rollout_ref.actor.use_torch_compile=False
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.ppo_epochs=1
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.optim.lr=1e-6
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction}
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True

Contributor review on lines +100 to +102 (the three override_optimizer_config entries above): These lines have a '+' prefix, which is invalid syntax for a bash array assignment. This appears to be a copy-paste error from a diff and will cause the script to fail. Suggested change: remove the '+' prefixes.

    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp}
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp}
    actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp}
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep}
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${train_etp}
    actor_rollout_ref.actor.megatron.param_offload=${all_offload}
    actor_rollout_ref.actor.megatron.optimizer_offload=${all_offload}
    actor_rollout_ref.actor.megatron.grad_offload=${all_offload}
    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH}
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False
    +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1

Contributor review on lines +113 to +116 (the four override_transformer_config entries above): These lines have the same problem, a '+' prefix that is invalid syntax for a bash array assignment. Suggested change: remove the '+' prefixes.
)

# Reference Model Configuration
REF_CONFIG=(
    actor_rollout_ref.ref.use_torch_compile=False
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp}
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp}
    actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp}
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep}
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${train_etp}
    actor_rollout_ref.ref.megatron.param_offload=${all_offload}
    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH}
    actor_rollout_ref.ref.megatron.use_dist_checkpointing=False
)

# Rollout Configuration
ROLLOUT_CONFIG=(
    actor_rollout_ref.rollout.name=vllm
    actor_rollout_ref.rollout.n=${n_resp_per_prompt}
    actor_rollout_ref.rollout.top_p=1.0
    actor_rollout_ref.rollout.top_k=-1
    actor_rollout_ref.rollout.temperature=1.0
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization}
    actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens}
    actor_rollout_ref.rollout.max_model_len=${max_model_len}
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp}
    actor_rollout_ref.rollout.data_parallel_size=${gen_dp}
    actor_rollout_ref.rollout.expert_parallel_size=${gen_ep}
    actor_rollout_ref.rollout.enable_chunked_prefill=True
    actor_rollout_ref.rollout.enable_prefix_caching=True
    actor_rollout_ref.rollout.enforce_eager=True
    actor_rollout_ref.rollout.free_cache_engine=True
    actor_rollout_ref.rollout.val_kwargs.n=1
    actor_rollout_ref.rollout.val_kwargs.do_sample=True
    actor_rollout_ref.rollout.val_kwargs.top_p=1.0
    actor_rollout_ref.rollout.val_kwargs.top_k=-1
    actor_rollout_ref.rollout.val_kwargs.temperature=1.0
)

# Trainer Configuration
TRAINER_CONFIG=(
    trainer.logger='["console","tensorboard"]'
    trainer.project_name="${project_name}"
    trainer.experiment_name="${exp_name}"
    trainer.nnodes="${NNODES}"
    trainer.n_gpus_per_node="${NPUS_PER_NODE}"
    trainer.device='npu'
    trainer.total_epochs=15
    trainer.val_before_train=False
    trainer.test_freq=-1
    trainer.save_freq=-1
    trainer.default_local_dir="${CKPTS_DIR}"
)

# Main GRPO Training Command
python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA_CONFIG[@]}" \
    "${MODEL_CONFIG[@]}" \
    "${ACTOR_CONFIG[@]}" \
    "${REF_CONFIG[@]}" \
    "${ROLLOUT_CONFIG[@]}" \
    "${ALGORITHM_CONFIG[@]}" \
    "${TRAINER_CONFIG[@]}" \
    "$@" | tee logs/run_qwen3moe-30b_grpo_megatron_vllm_npu.log
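
For completeness, the mechanical edit the reviewer requests in the comments above could be applied with a one-liner such as the following sketch; the script filenames are placeholders, since the actual paths are not shown in this view, and any equivalent edit works:

# Hypothetical cleanup matching the reviewer's suggestion (filenames are placeholders)
sed -i 's/^\([[:space:]]*\)+actor_rollout_ref/\1actor_rollout_ref/' \
    run_qwen2_5-32b_grpo_megatron_vllm_npu.sh \
    run_qwen3moe-30b_grpo_megatron_vllm_npu.sh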