Some thoughts for improvements
Lines 582 to 608 in 4540740
```python
for name, param in model.named_parameters():
    count += 1  # empty_cache at last param

    # Fire all vllm engines for broadcast
    if self.strategy.is_rank_0():
        shape = (
            param.shape
            if self.strategy.args.zero_stage != 3
            else param.ds_shape
        )
        futs = [
            actor.futures.update_weight(
                name,
                dtype=torch_type_codec(param.dtype),
                shape=shape,
                empty_cache=count == num_params,
            )
            for actor in self.actors
        ]

    # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
    with deepspeed.zero.GatheredParameters(
        [param], enabled=self.strategy.args.zero_stage == 3
    ):
        if self.strategy.is_rank_0():
            dist.broadcast(param.data, 0, group=self._model_update_group)
            _ = [fut.result() for fut in futs]
```
- Instead of calling `fut.result()` after each param, we would save the dispatch latency by calling `update_weight` and `broadcast` for every param first and then waiting on all the futures at the end. My understanding is that they will be dispatched as a series of NCCL calls and will be executed in the order in which they were dispatched.
- It may be possible to broadcast different params from different learners so that the available communication bandwidth is fully used, with some caveats: 1. we may need different communication groups; 2. we need a coordination mechanism to make sure the `broadcast`/`update_weight` pairs are issued in the right order. Is it possible to add all the actors to the DeepSpeed communication group so that they receive parameter updates without participating in training?
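The first suggestion can be sketched as a toy model in plain Python. This is not the repository's code: `FakeActor`, `sync_per_param`, and `sync_batched` are hypothetical names, and a single-worker `ThreadPoolExecutor` per actor stands in for the assumption that weight-update calls are consumed in the order they are dispatched. The sketch only shows the control-flow change (dispatch every `update_weight`, then wait once) and checks that each actor still sees the parameters in order; it does not model real NCCL or Ray timing.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


class FakeActor:
    """Hypothetical stand-in for a vLLM actor (not the real actor API).

    A single worker thread runs update_weight requests strictly in
    submission order, loosely mimicking ordered calls on one NCCL stream.
    """

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.received: list[str] = []

    def update_weight(self, name: str) -> Future:
        def _recv() -> str:
            time.sleep(0.001)  # pretend per-parameter receive cost
            self.received.append(name)
            return name

        return self._pool.submit(_recv)


def sync_per_param(actors: list[FakeActor], names: list[str]) -> None:
    # Current pattern: block on every actor's future after each parameter.
    for name in names:
        futs = [a.update_weight(name) for a in actors]
        # (rank 0 would call dist.broadcast for this param here)
        for f in futs:
            f.result()


def sync_batched(actors: list[FakeActor], names: list[str]) -> None:
    # Proposed pattern: dispatch every parameter first, then wait once.
    futs: list[Future] = []
    for name in names:
        futs.extend(a.update_weight(name) for a in actors)
        # (rank 0 would call dist.broadcast for this param here)
    for f in futs:
        f.result()


if __name__ == "__main__":
    actors = [FakeActor() for _ in range(3)]
    names = [f"layers.{i}.weight" for i in range(8)]
    sync_batched(actors, names)
    print(all(a.received == names for a in actors))
```

With a single barrier at the end, rank 0 no longer waits for a round trip per parameter. Whether this is safe in the real code depends on the receiving side actually processing the requests in submission order, which is exactly the assumption the suggestion relies on.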
Lines 575 to 579 in 4540740
```python
while True:
    time.sleep(0.1)
    actors_busy = [actor.is_generating() for actor in self.actors]
    if not any(actors_busy):
        break
```
- Can we avoid this polling? One idea is to issue the NCCL broadcast independently of the vLLM `update_weight` call. The actors receive weight updates in a separate thread and cache them. Then, in the actor's `step` function, we check this cache and apply the new weights when they are available. This way we maximize the overlap between communication and computation.
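A minimal sketch of the background-receiver idea, assuming an in-process `queue.Queue` as a stand-in for the NCCL broadcast channel. `ActorSketch`, `inbox`, `step`, and `close` are hypothetical and do not reflect vLLM's or the repository's API; weights are plain floats instead of tensors.

```python
import queue
import threading


class ActorSketch:
    """Hypothetical generation actor; not vLLM's or the repository's API."""

    def __init__(self, weights: dict[str, float]) -> None:
        self.weights = dict(weights)
        self._pending: dict[str, float] = {}
        self._lock = threading.Lock()
        # Stand-in for the broadcast channel: (name, value) pairs, None = stop.
        self.inbox: "queue.Queue[tuple[str, float] | None]" = queue.Queue()
        self._receiver = threading.Thread(target=self._recv_loop, daemon=True)
        self._receiver.start()

    def _recv_loop(self) -> None:
        # Runs concurrently with generation; it only caches updates and
        # never touches self.weights directly.
        while True:
            item = self.inbox.get()
            try:
                if item is None:  # shutdown sentinel
                    return
                name, value = item
                with self._lock:
                    self._pending[name] = value
            finally:
                self.inbox.task_done()

    def step(self) -> dict[str, float]:
        # Swap in any cached updates between generation steps, so the
        # trainer does not need to poll is_generating().
        with self._lock:
            pending, self._pending = self._pending, {}
        self.weights.update(pending)
        # ... one generation step would run here using self.weights ...
        return self.weights

    def close(self) -> None:
        self.inbox.put(None)
        self._receiver.join()


if __name__ == "__main__":
    actor = ActorSketch({"w": 0.0, "b": 0.0})
    actor.inbox.put(("w", 1.5))  # "broadcast" arrives while idle
    actor.inbox.join()           # wait until the receiver has cached it
    print(actor.step())
    actor.close()
```

Because updates are swapped in only at the start of `step`, a generation step never sees weights change halfway through. A real version would probably also need a version counter or an "update complete" marker so that `step` does not apply a partially received set of parameters.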