diff --git a/megatron/arguments.py b/megatron/arguments.py
index fbd0a59053..07c9088e7b 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -97,7 +97,9 @@ def validate_args(args, defaults={}):
     if args.ds_sequence_parallel_size > 1:
         assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+"
-
+        if args.ds_sequence_parallel_overlap_comm:
+            assert args.split_qkv_linear, \
+                "ds_sequence_parallel_overlap_comm requires split_qkv_linear to be True"
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size * \
                           args.ds_sequence_parallel_size
@@ -924,6 +926,9 @@ def _add_training_args(parser):
     group.add_argument('--disable-moe-top2-2nd-expert-sampling', action='store_false',
                        help='Disable MoE top2 sampling of the 2nd expert. Instead of sampling, use argmax.',
                        dest='moe_top2_2nd_expert_sampling')
+    group.add_argument('--split-qkv-linear', action='store_true',
+                       help='Separate linear computations for query, key, and value.',
+                       dest='split_qkv_linear')
     group.add_argument('--use-flash-attn', '--use-flash-attn-v1', dest='use_flash_attn_v1',
                        action='store_true',
                        help='use first version FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')
@@ -975,14 +980,15 @@ def _add_training_args(parser):
                        help='Enable DeepSpeed\'s sequence parallel. Cannot be combined with "--sequence-parallel", which enables Megatron-LM\'s sequence parallel.')
     group.add_argument('--force-ds-sequence-parallel', action='store_true',
                        help='use DeepSpeed sequence parallelism regardless of sequence parallel size.')
-
+    group.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true',
+                       help='Overlap communication with computation for DeepSpeed sequence parallelism.',
+                       dest='ds_sequence_parallel_overlap_comm')
     group.add_argument('--ds-sequence-parallel-fpdt', action='store_true',
                        help='use DeepSpeed sequence parallelism with FPDT.')
     group.add_argument('--ds-sequence-parallel-fpdt-chunk-size', type=int, default=65536,
                        help='Chunk size used in FPDT attention.')
     group.add_argument('--ds-sequence-parallel-fpdt-offloading', action='store_true',
                        help='use DeepSpeed sequence parallelism FPDT with offloading.')
-
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
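
For context on how the two new flags fit together: the overlap path only has something to hide the communication behind when the QKV projection is split, since a fused QKV GEMM produces Q, K, and V all at once. Below is a minimal, standalone sketch of the flag coupling and the validate_args() rule above; only the flag names, dests, and the rule come from this patch, the bare parser is illustrative, and in the patch the check is applied only when sequence parallelism is enabled.

import argparse

# Standalone sketch of the flag coupling; the surrounding parser is illustrative.
parser = argparse.ArgumentParser()
parser.add_argument('--split-qkv-linear', action='store_true', dest='split_qkv_linear',
                    help='Separate linear computations for query, key, and value.')
parser.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true',
                    dest='ds_sequence_parallel_overlap_comm',
                    help='Overlap communication with computation for DeepSpeed sequence parallelism.')

args = parser.parse_args(['--ds-sequence-parallel-overlap-comm', '--split-qkv-linear'])

# Mirrors the check added to validate_args(): the overlap path relies on the
# per-projection (split) QKV linears, so enabling overlap without them is rejected.
if args.ds_sequence_parallel_overlap_comm:
    assert args.split_qkv_linear, \
        'ds_sequence_parallel_overlap_comm requires split_qkv_linear to be True'
print(args.split_qkv_linear, args.ds_sequence_parallel_overlap_comm)
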
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index 3dd3299ae0..58e1d1f976 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -20,6 +20,7 @@
 from megatron import get_args
 from megatron.core.model_parallel_config import ModelParallelConfig
+from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore
 from megatron.core.parallel_state import (
     get_tensor_model_parallel_rank,
@@ -248,13 +249,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce, sequence_parallel):
+                async_grad_allreduce, sequence_parallel, bwd_stream=None):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
         ctx.sequence_parallel = sequence_parallel
-
+        ctx.bwd_stream = bwd_stream
+
         if sequence_parallel:
             world_size = get_tensor_model_parallel_world_size()
             dim_size = list(input.size())
@@ -314,6 +316,7 @@ def backward(ctx, grad_output):
             total_input = all_gather_buffer
         else:
             total_input = input
+
         grad_input = grad_output.matmul(weight)

         if ctx.sequence_parallel:
@@ -368,23 +371,30 @@ def backward(ctx, grad_output):
         #     grad_weight = None
         # else:
         #     grad_weight = grad_output.t().matmul(total_input)
-        if args.enable_zbh1_pipeline:
-            from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore
+
+        if ctx.bwd_stream is not None:
+            # for sp overlap communication
+            ctx.bwd_stream.wait_stream(get_accelerator().current_stream())
+            with get_accelerator().stream(ctx.bwd_stream):
+                WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
+                grad_weight = None
+        elif args.enable_zbh1_pipeline:
             WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
             grad_weight = None
         else:
             grad_weight = grad_output.t().matmul(total_input)

         grad_bias = grad_output.sum(dim=0) if use_bias else None
-
+        if ctx.bwd_stream is not None:
+            total_input.record_stream(ctx.bwd_stream)
+            grad_output.record_stream(ctx.bwd_stream)
         if ctx.sequence_parallel:
             handle.wait()
             return sub_grad_input, grad_weight, grad_bias, None, None, None

         if ctx.async_grad_allreduce:
             handle.wait()
-
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None

 def linear_with_grad_accumulation_and_async_allreduce(
     input: torch.Tensor,
@@ -393,6 +403,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
     gradient_accumulation_fusion: bool,
     async_grad_allreduce: bool,
     sequence_parallel: bool,
+    async_sp_all2all_stream=None
 ) -> torch.Tensor:
     """Linear layer execution with asynchronous communication and
     gradient accumulation fusion in backprop.
@@ -453,6 +464,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
         gradient_accumulation_fusion,
         async_grad_allreduce,
         sequence_parallel,
+        async_sp_all2all_stream
     ]

     if not linear_with_grad_accumulation_and_async_allreduce.warned:
@@ -607,7 +619,6 @@ def __init__(self, input_size, output_size, *,
                 "cannot be enabled at the same time."
             )

-
     def forward(self,
                 input_: torch.Tensor,
                 weight: Optional[torch.Tensor] = None):
@@ -706,9 +717,10 @@ def __init__(self, input_size: int, output_size: int, *,
                  stride: int = 1,
                  keep_master_weight_for_test: bool = False,
                  skip_bias_add: bool = False,
-                 moe=False, enable_expert_tensor_parallelism=False):
+                 moe=False, enable_expert_tensor_parallelism=False, ds_sp_async_stream=None):
         torch.nn.Module.__init__(self)
-
+        self.ds_sp_async_stream = ds_sp_async_stream
+
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
@@ -784,6 +796,7 @@ def forward(self, input_):
             assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
+
         output_parallel = linear_with_grad_accumulation_and_async_allreduce(
             input=input_parallel,
             weight=self.weight,
@@ -791,6 +804,7 @@
             gradient_accumulation_fusion=self.gradient_accumulation_fusion,
             async_grad_allreduce=False,
             sequence_parallel=False,
+            async_sp_all2all_stream=self.ds_sp_async_stream
         )

         # All-reduce across all the partitions.
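
The core of the layers.py change is in backward(): grad_input is computed immediately on the current stream, while the weight-gradient work is handed to a side stream (via WeightGradStore) so it can run underneath the sequence-parallel communication; record_stream() keeps total_input and grad_output alive until that deferred work has run. Below is a minimal sketch of the same stream pattern in plain PyTorch, not the repository's code: names are illustrative and the wgrad GEMM is done eagerly on the side stream instead of being queued in a store.

import torch

def backward_with_deferred_wgrad(total_input, grad_output, weight, side_stream=None):
    # grad_input is needed right away by the preceding layer.
    grad_input = grad_output.matmul(weight)

    if side_stream is None:
        # No overlap requested: compute the weight gradient inline.
        grad_weight = grad_output.t().matmul(total_input)
        return grad_input, grad_weight

    # Overlap path: let the side stream catch up with the producer of the
    # inputs, then enqueue the weight-gradient GEMM on it.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        grad_weight = grad_output.t().matmul(total_input)
    # Tell the caching allocator these tensors are still in use on side_stream,
    # so their memory is not reused before the deferred GEMM runs.
    total_input.record_stream(side_stream)
    grad_output.record_stream(side_stream)
    return grad_input, grad_weight

if torch.cuda.is_available():
    x = torch.randn(8, 16, device='cuda')   # stand-in for total_input
    gy = torch.randn(8, 32, device='cuda')  # stand-in for grad_output
    w = torch.randn(32, 16, device='cuda')  # stand-in for the linear weight
    gi, gw = backward_with_deferred_wgrad(x, gy, w, side_stream=torch.cuda.Stream())
    torch.cuda.synchronize()
    print(gi.shape, gw.shape)
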
""" - + sp_stream=None + + def get_sp_stream(self): + if not self.ds_sp_overlap: + return None + if ParallelAttention.sp_stream is None: + ParallelAttention.sp_stream=get_accelerator().Stream() + return ParallelAttention.sp_stream def __init__(self, config, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding): @@ -524,7 +531,8 @@ def __init__(self, config, layer_number, self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.use_gqa = (self.num_attention_heads != self.num_key_value_heads) - + self.split_qkv = args.split_qkv_linear + self.ds_sp_overlap = args.ds_sequence_parallel_overlap_comm self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \ args.use_flash_attn_builder) \ and attention_type == AttnType.self_attn \ @@ -577,13 +585,31 @@ def __init__(self, config, layer_number, # Strided linear layer. if attention_type == AttnType.self_attn: - self.query_key_value = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - projection_size + 2 * kv_projection_size, - config=config, - init_method=config.init_method, - bias=args.add_bias_linear, - gather_output=False) + if not self.split_qkv: + self.query_key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + projection_size + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) + + else: + linear_configs = [ + ("query_linear", projection_size), + ("key_linear", kv_projection_size), + ("value_linear", kv_projection_size), + ] + + for attr_name, output_size in linear_configs: + setattr(self, attr_name, tensor_parallel.ColumnParallelLinear( + config.hidden_size, + output_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False + )) else: assert attention_type == AttnType.cross_attn self.query = tensor_parallel.ColumnParallelLinear( @@ -614,12 +640,14 @@ def __init__(self, config, layer_number, self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \ or args.force_ds_sequence_parallel if self.enable_ds_sequence_parallel: + assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 + self.dist_attn = DistributedAttention( local_attn, parallel_state.get_sequence_parallel_group(), - gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0) + gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0,sp_stream=self.get_sp_stream()) # flash_attn_cuda assumes [b, s, nh, hd] layout, we need to make sure all2all gathers into the correct sequence dimension. 
@@ -722,22 +752,41 @@ def forward(self, hidden_states, attention_mask,
         # =====================
         # Query, Key, and Value
         # =====================

         if self.attention_type == AttnType.self_attn:
-            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
-            mixed_x_layer, _ = self.query_key_value(hidden_states)
-
-            if self.enable_ds_sequence_parallel:
-                assert self.projection_size == self.kv_projection_size
-                seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
-                query_layer = mixed_x_layer[:, :, :self.projection_size].reshape(seq_len, bs, -1, self.head_dim)
-                key_layer = mixed_x_layer[:, :, self.projection_size:self.projection_size+self.kv_projection_size].reshape(seq_len, bs, -1, self.head_dim)
-                value_layer = mixed_x_layer[:, :, self.projection_size+self.kv_projection_size:].reshape(seq_len, bs, -1, self.head_dim)
-            if self.sequence_parallel or not self.enable_ds_sequence_parallel:
-                seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
-                each_hidden_size = mixed_x_layer.shape[-1] // 3
-                query_layer = mixed_x_layer[:, :, :each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
-                key_layer = mixed_x_layer[:, :, each_hidden_size:each_hidden_size+each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
-                value_layer = mixed_x_layer[:, :, each_hidden_size+each_hidden_size:].reshape(seq_len, bs, -1, self.head_dim)
-
+            if not self.split_qkv:
+                # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
+                mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+                if self.enable_ds_sequence_parallel:
+                    assert self.projection_size == self.kv_projection_size
+                    seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
+                    query_layer = mixed_x_layer[:, :, :self.projection_size].reshape(seq_len, bs, -1, self.head_dim)
+                    key_layer = mixed_x_layer[:, :, self.projection_size:self.projection_size+self.kv_projection_size].reshape(seq_len, bs, -1, self.head_dim)
+                    value_layer = mixed_x_layer[:, :, self.projection_size+self.kv_projection_size:].reshape(seq_len, bs, -1, self.head_dim)
+                if self.sequence_parallel or not self.enable_ds_sequence_parallel:
+                    seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
+                    each_hidden_size = mixed_x_layer.shape[-1] // 3
+                    query_layer = mixed_x_layer[:, :, :each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
+                    key_layer = mixed_x_layer[:, :, each_hidden_size:each_hidden_size+each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
+                    value_layer = mixed_x_layer[:, :, each_hidden_size+each_hidden_size:].reshape(seq_len, bs, -1, self.head_dim)
+            else:
+                assert self.ds_sp_overlap, """
+                    Currently, the split_qkv operation is only applicable
+                    when ds_sp_overlap is enabled.
+                """
+                self.get_sp_stream().wait_stream(get_accelerator().current_stream())
+                with get_accelerator().stream(self.get_sp_stream()):
+                    query_layer,_ = self.query_linear(hidden_states)
+                    query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1)
+                    fwd_query_layer_done_event = get_accelerator().Event()
+                    fwd_query_layer_done_event.record(self.get_sp_stream())
+                    key_layer,_ = self.key_linear(hidden_states)
+                    key_layer=key_layer.reshape(key_layer.shape[0],key_layer.shape[1],self.num_attention_heads,-1)
+
+                    fwd_key_layer_done_event = get_accelerator().Event()
+                    fwd_key_layer_done_event.record(self.get_sp_stream())
+                    value_layer,_ = self.value_linear(hidden_states)
+                    value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1)
+
         # Repeat kv
         if self.use_gqa:
             key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
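
In the split-QKV forward path above, the three projections run on the shared side stream and an event is recorded after the query and key GEMMs, so a consumer can start communicating Q (and then K) while the remaining projections are still in flight. A simplified sketch of that producer side, using plain torch.cuda streams and events with illustrative names:

import torch

def project_qkv_overlapped(hidden_states, q_proj, k_proj, v_proj, side_stream):
    # Side stream must first catch up with whatever produced hidden_states.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        q = q_proj(hidden_states)
        q_done = torch.cuda.Event()
        q_done.record(side_stream)   # Q is ready; its comm can start now

        k = k_proj(hidden_states)
        k_done = torch.cuda.Event()
        k_done.record(side_stream)   # K is ready while V is still running

        v = v_proj(hidden_states)
    return (q, q_done), (k, k_done), v

if torch.cuda.is_available():
    h = torch.randn(128, 2, 64, device='cuda')
    q_proj, k_proj, v_proj = (torch.nn.Linear(64, 64).cuda() for _ in range(3))
    (q, q_done), (k, k_done), v = project_qkv_overlapped(
        h, q_proj, k_proj, v_proj, torch.cuda.Stream())
    # A consumer on another stream waits on the per-projection event instead of
    # synchronizing with the whole side stream.
    torch.cuda.current_stream().wait_event(q_done)
    torch.cuda.synchronize()
    print(q.shape, k.shape, v.shape)
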
+ """ + self.get_sp_stream().wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(self.get_sp_stream()): + query_layer,_ = self.query_linear(hidden_states) + query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1) + fwd_query_layer_done_event = get_accelerator().Event() + fwd_query_layer_done_event.record(self.get_sp_stream()) + key_layer,_ = self.key_linear(hidden_states) + key_layer=key_layer.reshape(key_layer.shape[0],key_layer.shape[1],self.num_attention_heads,-1) + + fwd_key_layer_done_event = get_accelerator().Event() + fwd_key_layer_done_event.record(self.get_sp_stream()) + value_layer,_ = self.value_linear(hidden_states) + value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1) + # Repeat kv if self.use_gqa: key_layer = self.repeat_kv(key_layer, self.num_key_value_groups) @@ -833,6 +882,9 @@ def forward(self, hidden_states, attention_mask, # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) if self.enable_ds_sequence_parallel: + if self.ds_sp_overlap: + key_layer.done_event=fwd_key_layer_done_event + query_layer.done_event=fwd_query_layer_done_event batch_dim_idx = 1 if self.use_flash_attn: if not self.use_flash_attn_triton: