
Commit

revert change.
charles9304 committed Dec 26, 2024
1 parent 9b2e36d commit 2940576
Showing 5 changed files with 8 additions and 22 deletions.
6 changes: 4 additions & 2 deletions chatlearn/models/vllm/hooks/ray_gpu_executor.py
@@ -20,7 +20,6 @@
 from vllm import envs
 from vllm.executor.ray_gpu_executor import RayGPUExecutor
 from vllm.executor.ray_utils import RayWorkerWrapper, ray
-
 from vllm.logger import init_logger
 from vllm.utils import (get_distributed_init_method,
                         get_ip, get_open_port, get_vllm_instance_id)
@@ -108,6 +107,7 @@ def sort_by_driver_then_worker_ip(worker):
         # Get the set of GPU IDs used on each node.
         worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                     use_dummy_driver=True)
+        # worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")

         node_workers = defaultdict(list) # node id -> list of worker ranks
         node_gpus = defaultdict(list) # node id -> list of gpu ids
@@ -156,7 +156,7 @@ def sort_by_driver_then_worker_ip(worker):
                           all_args_to_update_environment_variables)

         self._run_workers("update_environment_variables",
-                           all_args=self._get_env_vars_to_be_updated())
+                          all_args=self._get_env_vars_to_be_updated())

         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
@@ -186,6 +186,7 @@ def sort_by_driver_then_worker_ip(worker):
         self._run_workers("load_model",
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
+
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
@@ -217,4 +218,5 @@ def sort_by_driver_then_worker_ip(worker):
             else:
                 self.non_driver_workers.append(worker)

+
 RayGPUExecutor._init_workers_ray = _init_workers_ray
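
Note: the assignment on the last line above is the hook pattern the chatlearn/models/vllm/hooks modules rely on: a replacement function is defined at module level and then rebound onto vLLM's RayGPUExecutor at import time, so later calls to _init_workers_ray run the patched version. A minimal, self-contained sketch of that pattern (Engine and _patched_init_workers are hypothetical names, not chatlearn or vLLM APIs):

# Import-time monkey-patch sketch; all names here are illustrative only.
class Engine:
    def init_workers(self):
        return "original"

def _patched_init_workers(self):
    # Keep the original signature so existing call sites keep working.
    return "patched"

# Rebinding the class attribute swaps the implementation for all instances.
Engine.init_workers = _patched_init_workers

assert Engine().init_workers() == "patched"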
1 change: 1 addition & 0 deletions chatlearn/models/vllm_module_v2.py
@@ -68,6 +68,7 @@ def add_extra_args(self, parser):
         group.add_argument('--distributed-timeout-minutes', type=int, default=10,
                            help='Timeout minutes for torch.distributed.')
         return parser
+
     def init(self):
         """
         :meta private:
2 changes: 2 additions & 0 deletions chatlearn/synchronizer/parameter_sync.py
@@ -69,6 +69,7 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal):
         if self.num_src_tensor_parallel % 2 == 1 and self.num_dst_tensor_parallel % 2 == 1:
             logger.warning("Only support PARAM_SYNC_COMM_TYPE.BROADCAST when TP SIZE is even number, use P2P instead")
             self._comm_type = PARAM_SYNC_COMM_TYPE.P2P
+
         self.concurrent_comm = get_args().runtime_args.concurrent_comm
         self._enable_lora = self.src_model.module_args.lora.enable_lora
         # sync every n episodes, n = 0 for no param sync
@@ -1128,6 +1129,7 @@ def sync(self, requires_grad=None, validate=False):
             actor_mappings_list,
             requires_grad=requires_grad
         )
+
         assert len(actor_mappings_list) >= 1

         self.check_and_unfuse_lora(self._enable_lora, self.send_recv_actor_mappings)
6 changes: 1 addition & 5 deletions examples/megatron/models/__init__.py
@@ -22,11 +22,7 @@
 from .reward_inference import RewardInference
 from .reference import PolicyReference
 try:
-    from chatlearn.models.vllm import is_vllm_v2
-    if is_vllm_v2():
-        from .vllm_policy_inference import VLLMPolicyInferenceV2 as VLLMPolicyInference
-    else:
-        from .vllm_policy_inference import VLLMPolicyInference
+    from .vllm_policy_inference import VLLMPolicyInference
 except ImportError:
     print("Cannot import VLLMPolicyInference")
     VLLMPolicyInference = None
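
Note: the try/except above is a guarded optional import: if the vLLM-backed module cannot be imported, VLLMPolicyInference is bound to None instead of breaking the package at import time. A short sketch of how a caller might use that sentinel, assuming this package layout; vllm_available is an illustrative helper, not part of the repository:

from examples.megatron.models import VLLMPolicyInference  # may be None

def vllm_available():
    # The fallback above binds VLLMPolicyInference to None on ImportError,
    # so a None check tells callers whether the vLLM policy path is usable.
    return VLLMPolicyInference is not None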
15 changes: 0 additions & 15 deletions examples/megatron/models/vllm_policy_inference.py
@@ -142,18 +142,3 @@ def decode_internal(self, batched_outputs):
         return {"all_tokens": all_tokens, "str_outputs": str_outputs, "str_prompts": str_prompts,
                 "no_padded_query_ids": no_padded_query_ids, "logprobs": logprobs,
                 "loss_mask": loss_mask}
-
-class VLLMPolicyInferenceV2(VLLMPolicyInference):
-    """VLLMPolicyInferenceV2 is the model for VLLMModuleV2, which uses llm generate API"""
-
-    def eval_forward(self, data, iteration=0): # pylint: disable=invalid-overridden-method
-        return self._forward_step(data, iteration, True)
-
-    def _forward_step(self, data, iteration, is_eval): # pylint: disable=unused-argument,invalid-overridden-method
-        outputs = self.generate_vllm(data, is_eval)
-        if outputs is not None:
-            rets = self.decode_internal(outputs)
-        return rets
-
-    def forward_step(self, data, iteration=0): # pylint: disable=invalid-overridden-method
-        return self._forward_step(data, iteration, False)
