From 80686b7f56587be8823a41bece55069dc1d3ab84 Mon Sep 17 00:00:00 2001 From: kunlunl Date: Tue, 30 Jan 2024 13:43:39 +0800 Subject: [PATCH] Add THD format support for Context Parallel Signed-off-by: kunlunl --- qa/L1_pytorch_context_parallel_test/test.sh | 2 +- .../fused_attn/run_fused_attn_with_cp.py | 79 ++- .../fused_attn/test_fused_attn_with_cp.py | 2 +- transformer_engine/pytorch/attention.py | 135 +++- transformer_engine/pytorch/csrc/extensions.h | 42 ++ .../pytorch/csrc/extensions/attention.cu | 619 ++++++++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 15 + 7 files changed, 846 insertions(+), 48 deletions(-) diff --git a/qa/L1_pytorch_context_parallel_test/test.sh b/qa/L1_pytorch_context_parallel_test/test.sh index 2c77a9d6c3..7f3c289b36 100644 --- a/qa/L1_pytorch_context_parallel_test/test.sh +++ b/qa/L1_pytorch_context_parallel_test/test.sh @@ -6,5 +6,5 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==6.2.5 onnxruntime==1.13.1 +pip install pytest==7.2.0 onnxruntime==1.13.1 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 1af8391bce..2ee6754f9a 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist from transformer_engine.pytorch.attention import DotProductAttention +import transformer_engine_extensions as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16} @@ -58,12 +59,27 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim) kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim) attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim) + cu_seqlens_q = None + cu_seqlens_kv = None elif qkv_format == "sbhd": q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim) kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim) attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim) + cu_seqlens_q = None + cu_seqlens_kv = None + elif qkv_format == "thd": + seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32) + seqlens_q = seqlens_q - seqlens_q % (world_size * 2) + cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)]) + cu_seqlens_kv = cu_seqlens_q + q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim) + kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim) + attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim) + cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda() + cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda() else: assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() @@ -79,6 +95,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= # make sure all GPU ranks have same inputs for x in [q, k, v, dout] + ([] if bias is None else [bias]): dist.broadcast(x, 0, group=cp_comm_group) + if qkv_format == "thd": + for x in [cu_seqlens_q, cu_seqlens_kv]: + dist.broadcast(x, 0, group=cp_comm_group) # run core_attn without CP for x in [q, k, v]: @@ -87,28 +106,48 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= q, k, v, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) out.backward(dout) # run core_attn wit CP - q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])] - bias_ = rest[0] if len(rest) else None - seq_dim = qkv_format.index('s') - q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \ - for x in [q_, k_, v_, dout_]] - seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device) - q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] - q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]] + if qkv_format == "bshd" or qkv_format == "sbhd": + q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])] + bias_ = rest[0] if len(rest) else None + seq_dim = qkv_format.index('s') + q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \ + for x in [q_, k_, v_, dout_]] + seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device) + q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] + q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]] + elif qkv_format == "thd": + q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]] + seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank) + seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank) + q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] + k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] + cu_seqlens_q = cu_seqlens_q // world_size + cu_seqlens_kv = cu_seqlens_kv // world_size + bias_ = None + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" 
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1]) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream()) + max_seqlen_q = config.max_seqlen_q + max_seqlen_kv = config.max_seqlen_kv out_ = core_attn( q_, k_, v_, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, ) out_.backward(dout_) @@ -120,11 +159,20 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= tols = dict(atol=5e-3, rtol=5e-3) if dtype == 'bf16': tols = dict(atol=2.5e-2, rtol=2.5e-2) - dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \ - for x in [q.grad, k.grad, v.grad, out]] - dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \ - for x in [q_.grad, k_.grad, v_.grad, out_]] + + if qkv_format == "bshd" or qkv_format == "sbhd": + dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \ + for x in [q.grad, k.grad, v.grad, out]] + dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] + dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \ + for x in [q_.grad, k_.grad, v_.grad, out_]] + elif qkv_format == "thd": + dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]] + dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]] + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" + if qkv_format == "bshd": torch.testing.assert_close(out_[:, 0], out[:, 0], **tols) torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols) @@ -143,6 +191,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= torch.testing.assert_close(dq_[1], dq[1], **tols) torch.testing.assert_close(dk_[1], dk[1], **tols) torch.testing.assert_close(dv_[1], dv[1], **tols) + elif qkv_format == "thd": + torch.testing.assert_close(out_, out, **tols) + torch.testing.assert_close(dq_, dq, **tols) + torch.testing.assert_close(dk_, dk, **tols) + torch.testing.assert_close(dv_, dv, **tols) else: assert False, f"{qkv_format} is an unsupported qkv_format!" 
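Note on the THD partitioning exercised by the test above: for load-balanced causal attention, every sequence is split into 2*world_size chunks and rank r keeps chunks r and 2*world_size-1-r; tex.thd_get_partitioned_indices returns the token indices of exactly those chunks. Below is a minimal pure-PyTorch sketch of that index computation, for illustration only (it is not part of the patch, and the helper name is made up here):

    import torch

    def thd_partition_indices_ref(cu_seqlens, world_size, rank):
        """Reference for tex.thd_get_partitioned_indices; assumes every
        sequence length is divisible by 2*world_size."""
        indices = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            chunk = (end - start) // (2 * world_size)
            first = start + rank * chunk                          # chunk `rank`
            second = start + (2 * world_size - 1 - rank) * chunk  # mirrored chunk
            indices.append(torch.arange(first, first + chunk))
            indices.append(torch.arange(second, second + chunk))
        return torch.cat(indices).to(torch.int32)

Pairing chunk r with chunk 2*world_size-1-r gives every rank roughly the same amount of causal-attention work, which is why the bshd/sbhd path above selects seq_idx = [rank, 2*world_size-rank-1] along the sequence dimension.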
diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 43280ecdde..e46859759f 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -33,7 +33,7 @@ def get_bash_arguments(**kwargs): @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", ['bf16', 'fp16']) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) -@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd']) +@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd']) def test_cp_with_flash_attention(dtype, model, qkv_format): subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 2f5a6aa671..a4246c97d2 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -675,8 +675,13 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i%2] = q.view(-1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() + if qkv_format == "thd": + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i%2] = tex.thd_read_half_tensor( + kv_inputs[i%2], cu_seqlens_k, 0) + else: + # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] + kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: @@ -722,8 +727,13 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, if len(rest) > 0: attn_biases[i] = rest[0] else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) + if qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + else: + # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] + q_inputs[i%2] = \ + q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: @@ -781,7 +791,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, if i == 1: out = torch.empty_like(q).zero_() softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) - if causal: + if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2 @@ -790,8 +800,14 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step[i-1]) else: - flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :], - softmax_lse_per_step[i-1]) + if qkv_format == "thd": + tex.thd_second_half_lse_correction(softmax_lse, + softmax_lse_per_step[i-1], + cu_seqlens_q, + q.size(0)) + else: + flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :], + softmax_lse_per_step[i-1]) if i < cp_size: flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done) @@ -799,7 +815,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) softmax_lse = 
softmax_lse.to(torch.float) - seq_dim = qkv_format.index("s") + if qkv_format == "bshd" or qkv_format == "sbhd": + seq_dim = qkv_format.index("s") for i in range(cp_size): if qkv_format == "bshd": out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) @@ -807,18 +824,39 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, elif qkv_format == "sbhd": out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) out_ = out[1] + if i <= rank or not causal: - flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), - out_per_step[i], - seq_dim, - softmax_lse, - softmax_lse_per_step[i]) + if qkv_format == "bshd" or qkv_format == "sbhd": + flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), + out_per_step[i], + seq_dim, + softmax_lse, + softmax_lse_per_step[i]) + elif qkv_format == "thd": + tex.thd_out_correction(out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q, + False) + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" else: - flash_attn_fwd_out_correction(out_, - out_per_step[i], - seq_dim, - softmax_lse_[..., 1, :], - softmax_lse_per_step[i]) + if qkv_format == "bshd" or qkv_format == "sbhd": + flash_attn_fwd_out_correction(out_, + out_per_step[i], + seq_dim, + softmax_lse_[..., 1, :], + softmax_lse_per_step[i]) + elif qkv_format == "thd": + tex.thd_out_correction(out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q, + True) + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" kv = p2p_comm_buffers[-1] if use_fused_attention: @@ -828,6 +866,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, out = out.view(-1, *out.shape[-3:]) else: out = out.view(-1, *out.shape[-2:]) + ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) ctx.rng_states = rng_states ctx.cp_group = cp_group @@ -872,12 +911,17 @@ def backward(ctx, dout): attn_dbias = None if ctx.causal: - # [b, np, sq] -> [b, np, 2, sq//2] - softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) - softmax_lse_ = softmax_lse_[..., 1, :].contiguous() - if ctx.use_fused_attention: - # [b, np, sq//2] -> [b, np, sq//2, 1] - softmax_lse_.unsqueeze_(-1) + if ctx.qkv_format == "thd": + softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0)) + else: + # [b, np, sq] -> [b, np, 2, sq//2] + softmax_lse_ = \ + softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) + softmax_lse_ = softmax_lse_[..., 1, :].contiguous() + if ctx.use_fused_attention: + # [b, np, sq//2] -> [b, np, sq//2, 1] + softmax_lse_.unsqueeze_(-1) + if ctx.use_fused_attention: # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) @@ -1012,8 +1056,12 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) dq_ = torch.empty_like(q_) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) + if ctx.qkv_format == "thd": + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) + else: + # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] + kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) @@ -1062,15 +1110,23 @@ def backward(ctx, dout): attn_bias_type=ctx.attn_bias_type, ) else: - # [b, 2, sq//2, 
np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) + if ctx.qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + else: + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] + q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) dq_ = torch.empty_like(q_) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) - dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) + if ctx.qkv_format == "thd": + out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) + dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) + else: + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] + out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) + dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: fa_optional_backward_kwargs["window_size"] = [-1, -1] _flash_attn_backward( @@ -1143,16 +1199,22 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq[0].copy_(dq_[0]) dq[1].add_(dq_[1]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add") elif i > 0: if ctx.qkv_format == "bshd": dq[:, 1, ...].add_(dq_) elif ctx.qkv_format == "sbhd": dq[1].add_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add") else: if ctx.qkv_format == "bshd": dq[:, 1, ...].copy_(dq_) elif ctx.qkv_format == "sbhd": dq[1].copy_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy") else: if i == 0: dq.copy_(dq_) @@ -1203,6 +1265,8 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].add_(dkv_[:, 0, ...]) dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy") else: dkv.add_(dkv_) elif i >= (cp_size-rank-1): @@ -1211,11 +1275,15 @@ def backward(ctx, dout): dkv[:, :, 0, ...].copy_(dkv_) elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none") else: if ctx.qkv_format == "bshd": dkv[:, :, 0, ...].add_(dkv_) elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].add_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none") elif i > 0: dkv.add_(dkv_) else: @@ -1253,10 +1321,12 @@ def attn_forward_func_with_cp( use_fused_attention=False ) -> torch.Tensor: """Attention implementation with context parallelism""" - assert(qkv_format in ["bshd", "sbhd"] + assert(qkv_format in ["bshd", "sbhd", "thd"] ), f"QKV format of {qkv_format} is not supported with context parallelism!" assert(qkv_format != "sbhd" or use_fused_attention ), "FlashAttention does not support sbhd format!" + assert(not(qkv_format == "thd" and use_fused_attention) + ), "FusedAttention does not support thd format!" assert (attn_mask_type in ["causal", "no_mask"] ), f"Mask type of {attn_mask_type} is not supported with context parallelism!" assert (attn_bias is None or use_fused_attention @@ -2050,7 +2120,6 @@ def forward( key_layer.device, ) elif qkv_format == 'thd': - assert not context_parallel, "thd format not supported with context parallelism!" 
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" if max_seqlen_q is None: diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index abbecb1609..9713700c11 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -637,3 +637,45 @@ size_t get_cudnn_version(); bool userbuf_comm_available(); void placeholder(); + + +/*************************************************************************************************** + * Support THD format for Context Parallel + **************************************************************************************************/ + +at::Tensor thd_read_half_tensor(const at::Tensor &tensor, + const at::Tensor &cu_seqlens, + int half_idx +); + +void thd_second_half_lse_correction(at::Tensor lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, + int total_tokens +); + +at::Tensor thd_read_second_half_lse(const at::Tensor &lse, + const at::Tensor &cu_seqlens, + int total_tokens +); + +void thd_out_correction(at::Tensor out, + const at::Tensor &out_per_step, + const at::Tensor &lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, + bool only_second_half +); + +void thd_grad_correction(at::Tensor grad, + const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, + const std::string &first_half, + const std::string &second_half +); + +at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, + int total_tokens, + int world_size, + int rank +); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index cc747655c4..c3e00ceeae 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1440,3 +1440,622 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } + +/*************************************************************************************************** + * Support THD format for Context Parallel: Binary search + **************************************************************************************************/ + +__forceinline__ +__device__ int binary_search(int target, int *array, int len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ + +__global__ void thd_read_half_tensor_kernel(void *half, + void *tensor, + int *cu_seqlens, + int batch, + int hidden_size_in_bytes, + int half_idx, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = 
reinterpret_cast(reinterpret_cast(half) + offset/2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4* cur_half_token = reinterpret_cast(reinterpret_cast(half) + \ + offset_in_bytes); + + offset_in_bytes = (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * \ + hidden_size_in_bytes; + float4* cur_token = reinterpret_cast(reinterpret_cast(tensor) + \ + offset_in_bytes); + + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } + } +} + +at::Tensor thd_read_half_tensor(const at::Tensor &tensor, + const at::Tensor &cu_seqlens, + int half_idx) { + NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + + // Shapes of q and dq are [t, h, d], so the dimension of "t" is 0 + // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 + int seq_dim = tensor.dim() == 3 ? 0 : 1; + + int batch = cu_seqlens.size(0) - 1; + int num_heads = tensor.size(seq_dim + 1); + int dim_per_head = tensor.size(seq_dim + 2); + int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); + + // For 128-bits load/store + NVTE_CHECK(hidden_size_in_bytes % 16 == 0); + + // Generate output + std::vector shape(tensor.dim()); + for (size_t i = 0; i < shape.size(); i++) { + shape[i] = tensor.size(i); + } + shape[seq_dim] /= 2; + at::Tensor half = at::empty(shape, at::CUDA(tensor.scalar_type())); + + // Launch Kernel + constexpr unsigned int block = 256; + unsigned int grid_x = (tensor.size(seq_dim) / 2 * 32 + block - 1) / block; + unsigned int grid_y = 1; + for (int i = 0; i < seq_dim; i++) { + grid_y *= tensor.size(i); + } + dim3 grid = {grid_x, grid_y}; + thd_read_half_tensor_kernel<<>>( + half.data_ptr(), + tensor.data_ptr(), + cu_seqlens.data_ptr(), + batch, + hidden_size_in_bytes, + half_idx, + tensor.size(seq_dim)); + + return half; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ + +template +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, + int batch, int num_heads, int max_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + int num_total_tokens = cu_seqlens_s[batch]; + + for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + + size_t idx = row * max_seqlen + col + seq_len; + size_t half_idx = row * max_seqlen / 2 + col; + + Functor::run(lse, half_lse, idx, half_idx); + } + } +} + +struct 
LseCorrectionFunctor { + __forceinline__ + __device__ static void run(double *lse, float *half_lse, size_t idx, size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +void thd_second_half_lse_correction(at::Tensor lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, + int total_tokens) { + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); + NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + + NVTE_CHECK(lse.dim() == 3); + NVTE_CHECK(lse_per_step.dim() == 3); + NVTE_CHECK(cu_seqlens.dim() == 1); + + int batch = lse.size(0); + int num_heads = lse.size(1); + int max_seqlen = lse.size(2); + + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + constexpr unsigned int block = 256; + unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_y = num_heads; + dim3 grid = {grid_x, grid_y}; + thd_lse_kernel<<>>( + lse.data_ptr(), + lse_per_step.data_ptr(), + cu_seqlens.data_ptr(), + batch, + num_heads, + max_seqlen); +} + +struct ReadLseFunctor { + __forceinline__ + __device__ static void run(float *lse, float *half_lse, size_t idx, size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +at::Tensor thd_read_second_half_lse(const at::Tensor &lse, + const at::Tensor &cu_seqlens, + int total_tokens) { + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(lse.dim() == 3); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + + int batch = lse.size(0); + int num_heads = lse.size(1); + int max_seqlen = lse.size(2); + + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + std::vector shape = {batch, num_heads, max_seqlen / 2}; + at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); + + constexpr unsigned int block = 256; + unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_y = num_heads; + dim3 grid = {grid_x, grid_y}; + thd_lse_kernel<<>>( + lse.data_ptr(), + half_lse.data_ptr(), + cu_seqlens.data_ptr(), + batch, + num_heads, + max_seqlen); + + return half_lse; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out correction in forward + **************************************************************************************************/ + +template +__global__ void thd_out_correction_kernel(dtype *out, + dtype *out_per_step, + float *lse, + float *lse_per_step, + int *cu_seqlens, + int batch, + int num_heads, + int dim_per_head, + int max_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); + } + __syncthreads(); + + int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; + int lane_id = threadIdx.x % tile_size; + int num_tiles = (blockDim.x * gridDim.x) / tile_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); + + for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { + int seq_id = binary_search(token_id, 
cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, idx_per_step; + + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * max_seqlen + col + seq_len * only_second_half; + idx_per_step = row * max_seqlen / (only_second_half + 1) + col; + float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); + + idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx = (idx * num_heads + head_id) * dim_per_head; + idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; + dtype *cur_out = out + idx; + dtype *cur_out_per_step = out_per_step + idx_per_step; + + for (int j = lane_id; j < num_loops_per_head; j += tile_size) { + float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); + for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { + p[k] += p_per_step[k] * lse_corrected_exp; + } + reinterpret_cast(cur_out)[j] = data; + } + } + } +} + +template +static void thd_out_correction_helper(at::Tensor out, + const at::Tensor &out_per_step, + const at::Tensor &lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens) { + NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type()); + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + + int total_tokens = out.size(0); + int num_heads = out.size(1); + int dim_per_head = out.size(2); + int batch = lse.size(0); + int max_seqlen = lse.size(2); + + NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1)); + NVTE_CHECK(out_per_step.size(1) == num_heads); + NVTE_CHECK(out_per_step.size(2) == dim_per_head); + NVTE_CHECK(lse.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1)); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + constexpr int tile = 16; + constexpr int block = 512; + unsigned int grid_x = (static_cast(total_tokens) / (only_second_half + 1) * \ + tile + block - 1) / block; + dim3 grid = {grid_x, (unsigned int)num_heads}; + + thd_out_correction_kernel<<>>( + out.data_ptr(), + out_per_step.data_ptr(), + lse.data_ptr(), + lse_per_step.data_ptr(), + cu_seqlens.data_ptr(), + batch, + num_heads, + dim_per_head, + max_seqlen); +} + +void thd_out_correction(at::Tensor out, + const at::Tensor &out_per_step, + const at::Tensor &lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, + bool only_second_half) { + if (only_second_half) { + if (out.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::Float) { + using dtype = float; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else { + NVTE_ERROR("Unsupported dtype of out\n"); + } + } else { + if (out.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + 
thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::Float) { + using dtype = float; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else { + NVTE_ERROR("Unsupported dtype of out\n"); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +template +__global__ void thd_grad_correction_kernel(dtype *grad, + dtype *grad_per_step, + int *cu_seqlens, + int batch, + int hidden_size, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + if constexpr (functor_idx < 2) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } else { + cu_seqlens_s[i] = cu_seqlens[i]; + } + } + __syncthreads(); + + int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; + int lane_id = threadIdx.x % group_size; + int num_groups = (blockDim.x * gridDim.x) / group_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size; + if constexpr (functor_idx < 2) { + grad_per_step = grad_per_step + offset / 2 * blockIdx.y; + } else { + grad_per_step = grad_per_step + offset * blockIdx.y; + } + grad = grad + offset * blockIdx.y; + + for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + + int token_offset; + bool is_first_half; + if constexpr (functor_idx < 2) { + token_offset = cu_seqlens_s[seq_id + functor_idx]; + is_first_half = (functor_idx == 0); + } else { + token_offset = 0; + int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); + } + + dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; + dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; + for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { + if (is_first_half) { + Functor_0::run(token, token_per_step, idx); + } else { + Functor_1::run(token, token_per_step, idx); + } + } + } +} + +struct EmptyFunctor { + __forceinline__ + __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ + __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ + __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + + #pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + +template +static void thd_grad_correction_helper(at::Tensor grad, + const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens) { + NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4); + 
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + + // Shape of dq is [t, h, d], so the dimension of "t" is 0 + // Shape of dkv is [2, t, h, d], so the dimension of "t" is 1 + int seq_dim = grad.dim() == 3 ? 0 : 1; + + int total_tokens = grad.size(seq_dim); + int num_heads = grad.size(seq_dim + 1); + int dim_per_head = grad.size(seq_dim + 2); + int batch = cu_seqlens.size(0) - 1; + + if constexpr (functor_idx < 2) { + NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens / 2); + } else { + NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens); + } + NVTE_CHECK(grad_per_step.size(seq_dim + 1) == num_heads); + NVTE_CHECK(grad_per_step.size(seq_dim + 2) == dim_per_head); + + size_t hidden_size = num_heads * dim_per_head; + NVTE_CHECK((hidden_size * c10::elementSize(grad.scalar_type())) % 16 == 0); + + constexpr unsigned int block = 256; + unsigned int grid_x; + if constexpr (functor_idx < 2) { + grid_x = (total_tokens / 2 * 32 + block - 1) / block; + } else { + grid_x = (total_tokens * 32 + block - 1) / block; + } + unsigned int grid_y = 1; + for (int i = 0; i < seq_dim; i++) { + grid_y *= grad.size(i); + } + dim3 grid = {grid_x, grid_y}; + + thd_grad_correction_kernel + <<>>( + grad.data_ptr(), + grad_per_step.data_ptr(), + cu_seqlens.data_ptr(), + batch, + hidden_size, + total_tokens); +} + +template +static void thd_grad_dispatcher(at::Tensor grad, + const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, + const std::string &first_half, + const std::string &second_half) { + if (first_half == "add" && second_half == "none") { + thd_grad_correction_helper, EmptyFunctor, 0>( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "copy" && second_half == "none") { + thd_grad_correction_helper( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "none" && second_half == "add") { + thd_grad_correction_helper, 1>( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "none" && second_half == "copy") { + thd_grad_correction_helper( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "add" && second_half == "copy") { + thd_grad_correction_helper, CopyFunctor, 2>( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "copy" && second_half == "add") { + thd_grad_correction_helper, 2>( + grad, grad_per_step, cu_seqlens); + } else { + NVTE_ERROR("Unsupported Functor of first half and second_half\n"); + } +} + +void thd_grad_correction(at::Tensor grad, + const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, + const std::string &first_half, + const std::string &second_half) { + if (grad.scalar_type() == at::ScalarType::Half) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else if (grad.scalar_type() == at::ScalarType::BFloat16) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else if (grad.scalar_type() == at::ScalarType::Float) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else { + NVTE_ERROR("Unsupported dtype of grad\n"); + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Generate partitioned indices for input tokens + **************************************************************************************************/ + +__global__ void thd_partition_indices_kernel(int *output, + int *cu_seqlens, + int batch, + int total_tokens, + int world_size, + int rank) { + 
+  extern __shared__ int cu_seqlens_s[];
+  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
+    int seqlen = cu_seqlens[i];
+    // Currently we assume that each sequence length is divisible by (world_size*2) since we have
+    // to distribute each sequence evenly to different GPUs.
+    assert(seqlen % (world_size*2) == 0);
+    cu_seqlens_s[i] = seqlen / world_size;
+  }
+  __syncthreads();
+
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+
+  for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
+    int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
+    int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
+    int index = token_id - cu_seqlens_s[seq_id];
+    int offset = index < seq_len/2 ? rank : (world_size-1) * 2 - rank;
+    index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
+    output[token_id] = index;
+  }
+}
+
+at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens,
+                                       int total_tokens,
+                                       int world_size,
+                                       int rank) {
+  NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+  NVTE_CHECK(cu_seqlens.size(0) >= 2);
+  NVTE_CHECK(rank >= 0 && rank < world_size);
+  NVTE_CHECK(world_size > 0);
+  NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
+
+  int batch = cu_seqlens.size(0) - 1;
+
+  std::vector<int64_t> shape = {total_tokens / world_size};
+  at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
+
+  constexpr unsigned int block = 256;
+  unsigned int grid = (output.size(0) + block - 1) / block;
+  thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1),
+                                 at::cuda::getCurrentCUDAStream()>>>(
+    output.data_ptr<int>(),
+    cu_seqlens.data_ptr<int>(),
+    batch,
+    total_tokens,
+    world_size,
+    rank);
+
+  return output;
+}
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 4a7d51cada..b512fac203 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -102,6 +102,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
   m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");

+  // Support THD format for Context Parallel
+  m.def("thd_read_half_tensor", &thd_read_half_tensor,
+        "Read the first half (half_idx=0) or the second half (half_idx=1) of each sequence in a THD "
+        "tensor");
+  m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
+        "Correct the second half of the softmax_lse");
+  m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
+        "Read the second half of the softmax_lse");
+  m.def("thd_out_correction", &thd_out_correction,
+        "Correct the THD format output of context parallelism in forward pass");
+  m.def("thd_grad_correction", &thd_grad_correction,
+        "Correct the THD format gradients of context parallelism in backward pass");
+  m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
+        "Generate partitioned indices for inputs in THD format");
+
   // Data structures
   py::class_<FP8TensorMeta>(m, "FP8TensorMeta")
     .def(py::init<>())
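For reference, the correction math implemented by thd_second_half_lse_correction and thd_out_correction above is the standard log-sum-exp merge used when combining partial attention results: the running softmax_lse is updated as max(a, b) + log(1 + exp(min(a, b) - max(a, b))), and each per-step output is accumulated after rescaling by exp(lse_per_step - lse). A minimal PyTorch sketch of that math follows; it is illustrative only, and the per-token tensor layout assumed in the comments is an assumption, not the patch's API:

    import torch

    def merge_lse(lse, lse_per_step):
        # Numerically stable log(exp(lse) + exp(lse_per_step)),
        # mirroring LseCorrectionFunctor in attention.cu.
        max_lse = torch.maximum(lse, lse_per_step)
        min_lse = torch.minimum(lse, lse_per_step)
        return max_lse + torch.log1p(torch.exp(min_lse - max_lse))

    def correct_out(out, out_per_step, lse, lse_per_step):
        # Accumulate a per-step partial output into `out` in place, mirroring
        # thd_out_correction. Assumed layout: out/out_per_step are [t, np, hn]
        # and lse/lse_per_step are [t, np], so the scale broadcasts over hn.
        out.add_(out_per_step * torch.exp(lse_per_step - lse).unsqueeze(-1))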