diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index e6d838df5adf..28064c63de67 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1486,7 +1486,6 @@ def _configure_distributed_model(self, model):
             summary += f"\t {self.sequence_parallel_size=}\n"
             summary += "***********************************************"
             logger.info(summary)
-
         if not (self.amp_enabled() or is_zero_init_model):
             self._broadcast_model()

@@ -1703,7 +1702,6 @@ def _configure_basic_optimizer(self, model_parameters):
             optimizer = MuSGD(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == MUON_OPTIMIZER:
             zero_stage = self.zero_optimization_stage()
-            assert zero_stage <= ZeroStageEnum.gradients, "Muon optimizer is not yet compatible with ZeRO Stage 3"
             if not all([hasattr(p, 'use_muon') for p in model_parameters]):
                 msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \
                       "please set by `param.use_muon = True / False` for all params"
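
For context on the check above: Muon is only applied to parameters that carry a `use_muon` attribute, and the ZeRO-3 changes below read the Muon group's `momentum` hyperparameter as `muon_beta`. A minimal sketch of how a client script might tag parameters and build the two param groups; the helper name `build_muon_param_groups` and the 2-D heuristic are assumptions for illustration, not part of this patch:

import torch

def build_muon_param_groups(model: torch.nn.Module, muon_lr=0.02, adam_lr=3e-4, momentum=0.95):
    # Tag every parameter so the engine-side check passes; the usual convention
    # is that only matrix-shaped weights use Muon, everything else uses Adam.
    muon_params, adam_params = [], []
    for name, p in model.named_parameters():
        p.use_muon = p.ndim >= 2 and "embed" not in name  # illustrative heuristic
        (muon_params if p.use_muon else adam_params).append(p)
    return [
        {"params": muon_params, "lr": muon_lr, "momentum": momentum, "use_muon": True},
        {"params": adam_params, "lr": adam_lr, "betas": (0.9, 0.95), "use_muon": False},
    ]
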
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 4f2a19e7431a..41d4b541eb3e 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -36,6 +36,8 @@
 from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
 from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER
 from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.zero.muon.original_muon import muon_update
+from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam

 # Toggle this to true to enable correctness test
 # with gradient partitioning and without
@@ -212,6 +214,7 @@ def __init__(
             raise SystemError("Cannot use fp16 without accelerator.")

         self.optimizer = init_optimizer
+        self.param_names = param_names

         # Use torch (un)flatten ops
@@ -329,12 +332,15 @@ def _enforce_optimizer_offload():
         self.all2all_process_group = all2all_process_group

         self.reduce_scatter = reduce_scatter
-
+        self.use_muon = isinstance(self.optimizer, MuonWithAuxAdam)
+        self.save_muon_momentum_buffer_in_memory = ds_config.get('save_muon_momentum_buffer_in_memory', False)
+        if self.use_muon and self.reduce_scatter:
+            raise ValueError("Muon and reduce scatter cannot be used together")
+        if self.use_muon and self.all2all_process_group is not None:
+            raise ValueError("Muon and all2all process group cannot be used together")
         self.dp_process_group = self.parameter_offload.dp_process_group
         self.sequence_parallel_size = groups._get_sequence_parallel_world_size()

-        self.all2all_process_group = all2all_process_group
-
         self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights

         self.partition_count = dist.get_world_size(group=self.dp_process_group)
@@ -385,6 +391,8 @@ def _enforce_optimizer_offload():
         #a single 32-bit partition of the parallel partitioned parameters
         #that this process will update
         self.fp32_partitioned_groups_flat = []
+        if self.use_muon and self.save_muon_momentum_buffer_in_memory:
+            self.muon_momentum_buffer_partitioned_groups_flat = []
         self.next_swappable_fp32_partitioned_groups = []

         # number of elements per partition in each group
@@ -780,6 +788,14 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
         param_groups: List[List[Parameter]] = tuple(
             self._create_fp16_sub_groups(param_group["params"]) for param_group in fp16_param_groups)

+        if self.use_muon:
+            self.sub_groups_using_muon = []
+            for idx, param_group in enumerate(fp16_param_groups):
+                if getattr(param_group['params'][0], 'use_muon', False):
+                    self.sub_groups_using_muon.extend([True] * len(param_groups[idx]))
+                    self.muon_beta = param_group['momentum']
+                else:
+                    self.sub_groups_using_muon.extend([False] * len(param_groups[idx]))
         # bookkeeping related to param groups
         for param_group_idx, param_group in enumerate(param_groups):
             for sub_group in param_group:
@@ -907,6 +923,20 @@ def _get_sub_group_partitions(self, sub_group_id):

         return sub_group_partitions

+    def _create_momentum_buffer(self, num_elements, i, ds_id):
+        if self.use_muon and self.sub_groups_using_muon[i]:
+            unpinned_fp32_buffer_momentum = torch.zeros(num_elements,
+                                                        device=self.device,
+                                                        dtype=self.communication_data_type)
+            unpinned_fp32_buffer_momentum.requires_grad = False
+            if self.fp32_partitioned_groups_flat[i] not in self.optimizer.state:
+                self.optimizer.state[self.fp32_partitioned_groups_flat[i]] = {}
+            self.optimizer.state[
+                self.fp32_partitioned_groups_flat[i]]["momentum_buffer"] = unpinned_fp32_buffer_momentum
+            if self.save_muon_momentum_buffer_in_memory:
+                self.muon_momentum_buffer_partitioned_groups_flat.append(unpinned_fp32_buffer_momentum)
+                self.muon_momentum_buffer_partitioned_groups_flat[i].ds_id = ds_id
+
     def _create_fp32_partitions(self):
         cpu_memory_usage = 0
         cpu_memory_sub_groups = 0
@@ -948,6 +978,7 @@ def _create_fp32_partitions(self):
                 self.fp32_partitioned_groups_flat[i].ds_id = ds_id
                 nvme_memory_usage += (fp32_element_size * num_elements)
                 num_swappable_partitions += 1
+                self._create_momentum_buffer(num_elements, i, ds_id)

                 if self.params_in_nvme_and_cpu and tensor is None:
                     num_swap_from_nvme_partitions += 1
@@ -979,20 +1010,24 @@ def _create_fp32_partitions(self):
                                                       dtype=self.master_weights_and_grads_dtype)
                     self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
                     self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer)
+                    self._create_momentum_buffer(num_elements, i, ds_id)
                 elif self.offload_optimizer:
                     converted = self.fp16_partitioned_groups_flat[i].to(self.subgroup_to_device[i],
                                                                         dtype=self.master_weights_and_grads_dtype)
                     self.fp32_partitioned_groups_flat.append(converted.clone().detach())
+                    self._create_momentum_buffer(num_elements, i, ds_id)
                 elif self.fp16_partitioned_groups_flat[i].dtype == self.master_weights_and_grads_dtype and \
                     self.fp16_partitioned_groups_flat[i].device == self.device:
                     # When torch autocast is enabled, weights in the provided model (and thus groups in the so-called
                     # "fp16" partitioned groups) are already in and updated using fp32. In such cases we don't need
                     # another copy of the weights.
                     self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i])
+                    self._create_momentum_buffer(num_elements, i, ds_id)
                 else:
                     converted = self.fp16_partitioned_groups_flat[i].to(self.device,
                                                                         dtype=self.master_weights_and_grads_dtype)
                     self.fp32_partitioned_groups_flat.append(converted.clone().detach())
+                    self._create_momentum_buffer(num_elements, i, ds_id)

             self.fp32_partitioned_groups_flat[i].ds_id = ds_id
             self.fp32_partitioned_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
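
The buffer created by `_create_momentum_buffer` above mirrors the flat fp32 partition element for element, so the same (sub-group, offset, numel) addressing can be used for weights, gradients, and momentum. A stripped-down illustration of that bookkeeping with toy sizes and no DeepSpeed classes (names and sizes here are made up):

import torch

optimizer_state = {}                        # stands in for self.optimizer.state
fp32_partition = torch.zeros(1024)          # flat fp32 shard owned by this rank
momentum = torch.zeros_like(fp32_partition)
momentum.requires_grad = False
# The momentum buffer is keyed by the flat partition tensor, just as regular
# optimizer state is keyed by its parameter.
optimizer_state.setdefault(fp32_partition, {})["momentum_buffer"] = momentum

# A single parameter shard is then addressed identically in both buffers.
dest_offset, numel = 256, 128               # hypothetical placement of one shard
shard_momentum = optimizer_state[fp32_partition]["momentum_buffer"].narrow(0, dest_offset, numel)
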
@@ -1420,6 +1455,119 @@ def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype)
             event.record()
             self.param_reduce_events.append(event)

+    def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, buffer_to_reduce: Tensor):
+        """
+        Update the momentum buffers and gradients of the Muon parameters in the current IPG bucket.
+
+        Args:
+            communication_data_type: torch.dtype
+            buffer_to_reduce: Tensor
+        Returns:
+            None
+        """
+        momentum_buffer = []
+        use_muon_params = []
+        params_to_subgroup_maps = {}
+        idx = 0
+        # Find the parameters that need a Muon update, indexed by sub-group,
+        # since parameters and optimizer state are swapped to/from NVMe per sub-group.
+        params_size_offset = 0
+        param_grad_offsets = {}
+        for param in self.ipg_buckets[communication_data_type].params:
+            i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
+            if self.use_muon and self.sub_groups_using_muon[i]:
+                use_muon_params.append(param)
+                param_grad_offsets[param] = params_size_offset
+                # copy the gradients back from the IPG buffer to the params for the Muon update
+                param.grad.data.copy_(buffer_to_reduce.narrow(0, params_size_offset,
+                                                              param.grad.numel()).view_as(param.grad),
+                                      non_blocking=False)
+                momentum_buffer.append(None)
+                if i not in params_to_subgroup_maps:
+                    params_to_subgroup_maps[i] = []
+                params_to_subgroup_maps[i].append((idx, dest_offset))
+                idx += 1
+            params_size_offset += param.grad.numel()
+        # If the optimizer state is swappable, swap in the momentum buffers of the Muon
+        # parameters and swap them out again afterwards; otherwise read them from memory.
+        for i in params_to_subgroup_maps:
+            if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
+                self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
+                for idx, dest_offset in params_to_subgroup_maps[i]:
+                    momentum_buffer[idx] = self.optimizer.state[
+                        self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow(
+                            0, dest_offset, use_muon_params[idx].partition_numel()).clone()
+                self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
+            elif self.save_muon_momentum_buffer_in_memory:
+                for idx, dest_offset in params_to_subgroup_maps[i]:
+                    momentum_buffer[idx] = self.muon_momentum_buffer_partitioned_groups_flat[i].narrow(
+                        0, dest_offset, use_muon_params[idx].partition_numel()).clone()
+            else:
+                raise ValueError(
+                    "Invalid momentum buffer save mode, momentum buffer should be saved in memory or swapped in and out to nvme"
+                )
+        # if there are parameters that need a Muon update
+        if momentum_buffer:
+            # All-gather the momentum buffers of the parameters into a global buffer,
+            # since the momentum buffers are partitioned just like the parameters themselves.
+            gathered_params_momentums = self._partitioned_buffers_all_gather(use_muon_params, momentum_buffer,
+                                                                             communication_data_type)
+            for i in params_to_subgroup_maps:
+                if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
+                    self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
+
+                # Distribute the work across the ranks in round-robin fashion,
+                # because the Muon update is an expensive operation.
+                world_sz = dist.get_world_size(self.dp_process_group)
+                rank = dist.get_rank(self.dp_process_group)
+                params = [use_muon_params[idx] for idx, _ in params_to_subgroup_maps[i]]
+                gathered_momentums = [gathered_params_momentums[idx] for idx, _ in params_to_subgroup_maps[i]]
+                grads_pad = [param.grad for param in params] + [torch.empty_like(params[-1].grad)] * (
+                    (world_sz - len(params) % world_sz) % world_sz)
+                gathered_momentums_pad = gathered_momentums + [torch.empty_like(gathered_momentums[-1])] * (
+                    (world_sz - len(gathered_momentums) % world_sz) % world_sz)
+                for base_i in range(len(params))[::world_sz]:
+                    if base_i + rank < len(params):
+                        param = params[base_i + rank]
+                        g = param.grad
+                        m = gathered_momentums_pad[base_i + rank]
+                        update = muon_update(g, m, beta=self.muon_beta)
+                        g.data.copy_(update, non_blocking=False)
+                        buffer_to_reduce.narrow(0, param_grad_offsets[param],
+                                                param.grad.numel()).data.copy_(g.view(-1), non_blocking=False)
+                    dist.all_gather(grads_pad[base_i:base_i + world_sz], grads_pad[base_i + rank])
+                    dist.all_gather(gathered_momentums_pad[base_i:base_i + world_sz],
+                                    gathered_momentums_pad[base_i + rank])
+                # Now every rank holds the updated momentum buffers and gradients;
+                # write this rank's shard back to the optimizer state.
+                for idx, dest_offset in params_to_subgroup_maps[i]:
+                    param = use_muon_params[idx]
+                    gathered_momentum = gathered_params_momentums[idx]
+                    chunk_sz = math.ceil(param.grad.numel() / world_sz)
+                    start_offset = rank * chunk_sz
+                    end_offset = start_offset + chunk_sz
+                    if end_offset > param.grad.numel():
+                        buffer_to_update = torch.zeros(chunk_sz, device=param.grad.device, dtype=param.grad.dtype)
+                        buffer_to_update[:param.grad.numel() - start_offset] = gathered_momentum.view(
+                            -1).data[start_offset:param.grad.numel()]
+                    else:
+                        buffer_to_update = gathered_momentum.view(-1).data[start_offset:end_offset]
+                    if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
+                        self.optimizer.state[self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow(
+                            0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False)
+                    elif self.save_muon_momentum_buffer_in_memory:
+                        self.muon_momentum_buffer_partitioned_groups_flat[i].narrow(
+                            0, dest_offset, param.partition_numel()).data.copy_(buffer_to_update, non_blocking=False)
+                        # keep the optimizer state pointing at the in-memory momentum buffer
+                        self.optimizer.state[self.fp32_partitioned_groups_flat[i]][
+                            "momentum_buffer"] = self.muon_momentum_buffer_partitioned_groups_flat[i]
+                    else:
+                        raise ValueError(
+                            "Invalid momentum buffer save mode, momentum buffer should be saved in memory or swapped in and out to nvme"
+                        )
+                if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
+                    self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
+
     @instrument_w_nvtx
     def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor,
                                        communication_data_type: torch.dtype) -> List[Tensor]:
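
The round-robin pattern above is easier to see in isolation: rank r runs the expensive update for tensors r, r + world_size, r + 2 * world_size, and so on, and one all_gather per slot block shares the results so every rank ends up with every updated tensor. A self-contained sketch assuming an initialized torch.distributed process group and equally shaped tensors; `heavy_update` is a placeholder for `muon_update`, not the real Muon math:

import torch
import torch.distributed as dist

def heavy_update(grad: torch.Tensor, momentum: torch.Tensor, beta: float = 0.95) -> torch.Tensor:
    momentum.mul_(beta).add_(grad)  # placeholder momentum step
    return momentum                 # real Muon would also orthogonalize the update

def round_robin_updates(grads, momenta, group=None):
    world = dist.get_world_size(group)
    rank = dist.get_rank(group)
    pad = (world - len(grads) % world) % world
    grads_pad = grads + [torch.empty_like(grads[-1])] * pad
    mom_pad = momenta + [torch.empty_like(momenta[-1])] * pad
    for base in range(0, len(grads), world):
        if base + rank < len(grads):
            # This rank owns slot (base + rank): do the heavy update locally.
            grads_pad[base + rank].copy_(heavy_update(grads_pad[base + rank], mom_pad[base + rank]))
        # Share every slot in [base, base + world) so all ranks see all results.
        dist.all_gather(grads_pad[base:base + world], grads_pad[base + rank], group=group)
        dist.all_gather(mom_pad[base:base + world], mom_pad[base + rank], group=group)
    return grads_pad[:len(grads)], mom_pad[:len(momenta)]
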
@@ -1443,6 +1591,7 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor,

         grad_partitions = []
         grad_offset_in_buffer = 0
+        self._apply_distributed_muon_update(communication_data_type, buffer_to_reduce)
         for param in self.ipg_buckets[communication_data_type].params:
             grad = param.grad
             chunk_sz = math.ceil(grad.numel() / world_sz)
@@ -1617,6 +1766,54 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                                                            gradient_tensors=offload_fp32_gradients[i])
         return buffers

+    def _partitioned_buffers_all_gather(self, params: List[Parameter], buffers_to_allgather: List[Tensor],
+                                        communication_data_type: torch.dtype):
+        """
+        All-gather the partitioned buffers of the given parameters into full, unpartitioned tensors.
+
+        Args:
+            params: List[Parameter]
+            buffers_to_allgather: List[Tensor]
+            communication_data_type: torch.dtype
+        Returns:
+            List[Tensor]
+        """
+        assert len(params) == len(buffers_to_allgather), "params and buffers_to_allgather must have the same length"
+        assert all(param.partition_numel() == buffer.numel() for param, buffer in zip(
+            params, buffers_to_allgather)), "params and buffers_to_allgather must have the same numel"
+        coalesced_buffer = instrument_w_nvtx(torch.cat)(buffers_to_allgather)
+        buffer_numel = coalesced_buffer.numel()
+        reduce_buffer = torch.empty(self.partition_count * buffer_numel,
+                                    dtype=communication_data_type,
+                                    device=params[0].device)
+        rearrange_buffer = torch.empty(self.partition_count * buffer_numel,
+                                       dtype=communication_data_type,
+                                       device=params[0].device)
+        my_rank = dist.get_rank(group=self.dp_process_group)
+        partition = reduce_buffer.narrow(0, buffer_numel * my_rank, buffer_numel)
+        partition.data.copy_(coalesced_buffer.data, non_blocking=False)
+        dist.all_gather_into_tensor(reduce_buffer, partition, group=self.dp_process_group)
+        param_partition_offsets = [0]
+        rearranged_offset = 0
+        for idx, param in enumerate(params):
+            param_partition_offsets.append(param_partition_offsets[idx] + param.partition_numel())
+        for idx, param in enumerate(params):
+            num_elements = param.partition_numel()
+            for partition_idx in range(self.partition_count):
+                sliced = reduce_buffer.narrow(0, buffer_numel * partition_idx + param_partition_offsets[idx],
+                                              num_elements)
+                rearrange_buffer.narrow(0, rearranged_offset, num_elements).copy_(sliced.data, non_blocking=False)
+                rearranged_offset += num_elements
+        param_full_offsets = [0]
+        for idx, param in enumerate(params):
+            # the offset is the sum of the numel of all partitions of the preceding parameters, including padding
+            param_full_offsets.append(param_full_offsets[idx] +
+                                      buffers_to_allgather[idx].numel() * self.partition_count)
+        output = []
+        for idx, param in enumerate(params):
+            output.append(rearrange_buffer.narrow(0, param_full_offsets[idx], param.ds_numel).view(param.ds_shape))
+        return output
+
     def reduce_ready_partitions_and_remove_grads(self, param):
         #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
         self.reduce_independent_p_g_buckets_and_remove_grads(param)
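
For intuition, `_partitioned_buffers_all_gather` deals with the layout produced by `all_gather_into_tensor`: each rank contributes its concatenated shards of all requested parameters, so the gathered buffer is ordered rank-major and has to be re-interleaved per parameter before it can be viewed with the full parameter shape. A simplified standalone sketch of the same idea; `gather_full_tensors` is a hypothetical helper that assumes equal shard sizes on every rank and ignores the communication dtype handling:

import math
import torch
import torch.distributed as dist

def gather_full_tensors(local_shards, full_shapes, group=None):
    # local_shards: this rank's flat shard of each tensor, in the same order as full_shapes.
    world = dist.get_world_size(group)
    coalesced = torch.cat(local_shards)  # [p0_shard | p1_shard | ...] for this rank
    gathered = torch.empty(world * coalesced.numel(), dtype=coalesced.dtype, device=coalesced.device)
    dist.all_gather_into_tensor(gathered, coalesced, group=group)
    # gathered layout is rank-major: [rank0: p0 p1 ...][rank1: p0 p1 ...] ...
    # so each full tensor is rebuilt by pulling its shard out of every rank block.
    full_tensors, offset = [], 0
    for shard, shape in zip(local_shards, full_shapes):
        pieces = [gathered.narrow(0, r * coalesced.numel() + offset, shard.numel()) for r in range(world)]
        full_tensors.append(torch.cat(pieces)[:math.prod(shape)].view(shape))
        offset += shard.numel()
    return full_tensors
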
diff --git a/tests/unit/ops/muon/test_muon.py b/tests/unit/ops/muon/test_muon.py
index f12cbb358a82..c71363ddf14e 100644
--- a/tests/unit/ops/muon/test_muon.py
+++ b/tests/unit/ops/muon/test_muon.py
@@ -17,17 +17,18 @@
 muon_configs = []
 for optimizer_name in ['muon', 'adam']:
-    for stage in [1, 2]:
+    for stage in [1, 2, 3]:
         for lr in [0.01, 0.05]:
             for model_dim in [32, 128]:
                 for nlayer in [5, 10]:
-                    muon_configs.append([optimizer_name, stage, lr, model_dim, nlayer])
+                    for offload_optimizer in [True, False]:
+                        muon_configs.append([optimizer_name, stage, lr, model_dim, nlayer, offload_optimizer])


-@pytest.mark.parametrize('optimizer_type, zero_stage, lr, hidden_dim, nlayer', muon_configs)
+@pytest.mark.parametrize('optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer', muon_configs)
 class TestMuonConfigs(DistributedTest):

-    def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer):
+    def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optimizer):
         optimizer_params = {"lr": lr}
         batch_size = 8
         config_dict = {
@@ -42,8 +43,16 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer):
             },
             "zero_optimization": {
                 "stage": zero_stage,
-            }
+                "reduce_scatter": False,
+            },
+            "save_muon_momentum_buffer_in_memory": True,
         }
+        if offload_optimizer:
+            config_dict["zero_optimization"]["offload_optimizer"] = {
+                "device": "cpu",
+                "pin_memory": True,
+            }
+
         # Perform a few training steps to ensure the optimizer works correctly
         model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayer)
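
For reference, the test above drives a user-facing configuration along these lines. This is a sketch only: the Muon-specific pieces taken from this patch are the ZeRO-3 stage, the disabled reduce_scatter (the stage-3 engine now rejects Muon combined with reduce scatter), the optional optimizer offload, and the top-level save_muon_momentum_buffer_in_memory flag; all other values are illustrative.

ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "muon",
        "params": {"lr": 0.01},
    },
    "zero_optimization": {
        "stage": 3,
        "reduce_scatter": False,
        # Optional: exercise the optimizer-offload path, as the test does.
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    # Keep the Muon momentum buffers resident instead of routing them through
    # the optimizer swapper.
    "save_muon_momentum_buffer_in_memory": True,
}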