revert yaml.
charles9304 committed Dec 26, 2024
1 parent 2940576 commit 3351522
Showing 5 changed files with 27 additions and 47 deletions.
examples/megatron/configs/llama2/vllm_param_sync.yaml (4 changes: 1 addition & 3 deletions)
@@ -31,8 +31,7 @@ models:
 
 runtime:
   colocation:
-    - policy
-    - ppo_policy
+    - policy,ppo_policy
   generation_batch_size: ${generation_batch_size:4}
   train_micro_batch_size: ${train_micro_batch_size:2}
   train_global_batch_size: ${train_global_batch_size:512}
@@ -50,4 +49,3 @@ runtime:
   exp_name: ${exp_name:chatlearn}
   debug: ${debug:False}
   validate_param_sync: ${validate_param_sync:False}
-  param_sync_comm_type: ${param_sync_comm_type:broadcast}
examples/megatron/configs/llama2/vllm_rlhf.yaml (4 changes: 1 addition & 3 deletions)
@@ -64,8 +64,7 @@ models:
 
 runtime:
   colocation:
-    - policy
-    - ppo_policy,reward,reference,value,ppo_value
+    - policy,ppo_policy,reward,reference,value,ppo_value
   generation_batch_size: ${generation_batch_size:4}
   train_micro_batch_size: ${train_micro_batch_size:2}
   train_global_batch_size: ${train_global_batch_size:512}
@@ -83,4 +82,3 @@ runtime:
   exp_name: ${exp_name:chatlearn}
   debug: ${debug:False}
   validate_param_sync: ${validate_param_sync:False}
-  param_sync_comm_type: ${param_sync_comm_type:broadcast}
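
The ${name:default} placeholders in both YAML files appear to be filled from environment variables at launch time, with the value after the colon as the fallback; that is how the env-prefixed launch lines in the scripts changed below feed them. A minimal sketch of driving the reverted vllm_rlhf.yaml directly from the shell (the working directory, log name, and override values are illustrative, not part of this commit):

    # Sketch only: supply the ${...:default} placeholders via the environment,
    # the same way scripts/train_rlhf_llama.sh does further down in this diff.
    cd examples/megatron
    generation_batch_size=512 \
    train_micro_batch_size=2 \
    train_global_batch_size=512 \
    python entry/train_rlhf.py -c configs/llama2/vllm_rlhf.yaml 2>&1 | tee rlhf.log
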
examples/megatron/models/vllm_policy_inference.py (1 change: 0 additions & 1 deletion)
@@ -138,7 +138,6 @@ def decode_internal(self, batched_outputs):
         prompt_sizes = torch.tensor([len(q) for q in no_padded_query_ids], device=all_tokens.device)
         loss_mask = get_loss_mask(all_tokens, self.tokenizer.tokenizer.eos_token_id, prompt_sizes)
         loss_mask = loss_mask.to("cpu")
-        print(f"str_outputs: {len(str_outputs)} {str_outputs}")
         return {"all_tokens": all_tokens, "str_outputs": str_outputs, "str_prompts": str_prompts,
                 "no_padded_query_ids": no_padded_query_ids, "logprobs": logprobs,
                 "loss_mask": loss_mask}
examples/megatron/scripts/train_rlhf_llama.sh (49 changes: 17 additions & 32 deletions)
@@ -42,46 +42,31 @@ export data_checkpoint_path=${output_dir}/data_checkpoint
 
 
 if [[ "$model_size" == "llama2-7B" ]]; then
-    export policy_tp=4
-    export policy_pp=1
-
-    # export policy_tp=8
-    # export policy_pp=1
-
-    export ppo_policy_tp=4
-    export ppo_policy_pp=1
-    # export ppo_policy_tp=4
-    # export ppo_policy_pp=2
-
-    # export policy_tp=8
-    # export policy_pp=1
-
-    # export ppo_policy_tp=4
-    # export ppo_policy_pp=1
-    # export ppo_policy_tp=2
-    # export ppo_policy_pp=1
-
-    export reward_tp=4
-    export ppo_value_pp=1
+    [ -z "$policy_tp" ] && export policy_tp=4
+    [ -z "$policy_pp" ] && export policy_pp=1
+    [ -z "$ppo_policy_tp" ] && export ppo_policy_tp=4
+    [ -z "$ppo_policy_pp" ] && export ppo_policy_pp=1
+    [ -z "$reward_tp" ] && export reward_tp=4
+    [ -z "$ppo_value_pp" ] && export ppo_value_pp=1
     export train_global_batch_size=128
     if [[ "$backend" == "megatron" ]]; then
-        export generation_batch_size=128
+        export generation_batch_size=256
     elif [[ "$backend" == "vllm" ]]; then
-        export generation_batch_size=128
+        export generation_batch_size=512
     fi
     export ref_generation_batch_size=64
     export value_generation_batch_size=64
     export reward_generation_batch_size=64
     export train_micro_batch_size=16
     export max_num_batched_tokens=65536
-    export gpu_memory_utilization=0.5
-    # export num_gpu_ref=4
-    # export num_gpu_value=4
-    # export num_gpu_ppo_policy=4
-    # export num_gpu_ppo_value=4
-    # export free_memory_reward=True
-    # export free_memory_ppo_policy=True
-    # export free_memory_ppo_value=True
+    export gpu_memory_utilization=0.9
+    export num_gpu_ref=4
+    export num_gpu_value=4
+    export num_gpu_ppo_policy=4
+    export num_gpu_ppo_value=4
+    export free_memory_reward=True
+    export free_memory_ppo_policy=True
+    export free_memory_ppo_value=True
 elif [[ "$model_size" == "llama2-13B" ]]; then
     export policy_tp=8
     export policy_pp=1
@@ -130,4 +115,4 @@ num_gpu=${num_gpu} \
 data_path=${DATASET_PATH} \
 eval_data_path=${EVAL_DATASET_PATH} \
 sample_per_episode=${sample_per_episode} \
-python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]}
+python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]}
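
The [ -z "$var" ] && export var=default guards added to the llama2-7B branch only assign a value when the caller has not already set one, so the parallel layout can now be changed without editing the script. A small sketch of such an override; the values are hypothetical, and it assumes model_size and backend are passed through the environment as the script's conditionals suggest:

    # Hypothetical override: because of the [ -z ... ] guards, these exports
    # take precedence over the script's llama2-7B defaults (tp=4, pp=1).
    export policy_tp=8
    export policy_pp=1
    export ppo_policy_tp=4
    export ppo_policy_pp=2
    model_size=llama2-7B backend=vllm bash scripts/train_rlhf_llama.sh
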
examples/megatron/tests/test_unbalanced_param_sync.sh (16 changes: 8 additions & 8 deletions)
@@ -23,10 +23,10 @@ config_dir=${CHATLEARN}/examples/megatron/configs/
 
 
 if [[ "$model_size" == "llama2-7B" ]]; then
-    export policy_tp=4
+    export policy_tp=8
     export policy_pp=1
-    export ppo_policy_tp=4
-    export ppo_policy_pp=1
+    export ppo_policy_tp=2
+    export ppo_policy_pp=4
     export train_global_batch_size=128
     if [[ "$backend" == "megatron" ]]; then
         export generation_batch_size=128
@@ -38,11 +38,11 @@ if [[ "$model_size" == "llama2-7B" ]]; then
     fi
     export train_micro_batch_size=16
     export max_num_batched_tokens=65536
-    export gpu_memory_utilization=0.5
+    export gpu_memory_utilization=0.8
 
-    export num_gpu_policy=4
-    export num_gpu_ppo_policy=4
-    export free_memory_policy=False
+    export num_gpu_policy=8
+    export num_gpu_ppo_policy=8
+    export free_memory_policy=True
     export free_memory_ppo_policy=True
 fi
 
@@ -55,4 +55,4 @@ num_episode=${num_ppo_episode:-0} \
 data_path=${DATASET_PATH} \
 eval_data_path=${EVAL_DATASET_PATH} \
 sample_per_episode=${sample_per_episode} \
-python tests/test_unbalanced_param_sync.py -c $config_file 2>&1 | tee ${output_dir}/log_${RANK}.log ; exit ${PIPESTATUS[0]}
+python tests/test_unbalanced_param_sync.py -c $config_file 2>&1 | tee ${output_dir}/log_${RANK}.log ; exit ${PIPESTATUS[0]}
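
The new llama2-7B settings give the two models deliberately mismatched layouts over the same number of GPUs, which is presumably the "unbalanced" case this test targets: policy uses tp=8, pp=1 while ppo_policy uses tp=2, pp=4. A quick arithmetic check, under the assumption that GPUs per model replica equals tp * pp:

    # Assumption: GPUs per replica = tensor_parallel * pipeline_parallel.
    policy_gpus=$((8 * 1))       # policy: tp=8, pp=1
    ppo_policy_gpus=$((2 * 4))   # ppo_policy: tp=2, pp=4
    echo "policy=${policy_gpus} ppo_policy=${ppo_policy_gpus}"   # both 8, sharded differently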
