diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
index cbbb5a8168..d1369f59a3 100644
--- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -67,8 +67,7 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
         cu_seqlens_q = None
         cu_seqlens_kv = None
-    else:
-        assert(config.max_seqlen_q % (world_size * 2) == 0)
+    elif qkv_format == "thd":
         seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
         seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
         cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
@@ -78,6 +77,8 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
         cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
         cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
+    else:
+        assert False, f"{qkv_format} is an unsupported qkv_format!"

     q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
     k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
@@ -94,7 +95,7 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
     # run core_attn without CP
     for x in [q, k, v]:
         x.requires_grad = True
-    out = core_attn(q, k, v, cu_seqlens_q = cu_seqlens_q, cu_seqlens_kv = cu_seqlens_kv)
+    out = core_attn(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
     out.backward(dout)

     # run core_attn wit CP
@@ -106,20 +107,21 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
         q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
         q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
-    else:
+    elif qkv_format == "thd":
         q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
         seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
         seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
         k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
+        cu_seqlens_q = cu_seqlens_q // world_size
+        cu_seqlens_kv = cu_seqlens_kv // world_size
+    else:
+        assert False, f"{qkv_format} is an unsupported qkv_format!"

     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]

     core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
-    if qkv_format == "bshd" or qkv_format == "sbhd":
-        out_ = core_attn(q_, k_, v_)
-    else:
-        out_ = core_attn(q_, k_, v_, cu_seqlens_q=cu_seqlens_q//world_size, cu_seqlens_kv=cu_seqlens_kv//world_size)
+    out_ = core_attn(q_, k_, v_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
     out_.backward(dout_)

     for x in [out_, q_.grad, k_.grad, v_.grad]:
@@ -137,10 +139,12 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
         dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
             for x in [q_.grad, k_.grad, v_.grad, out_]]
-    else:
+    elif qkv_format == "thd":
         dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
         dq_, dk_, dv_, out_ = [x.view(-1) for x in [q_.grad, k_.grad, v_.grad, out_]]
+    else:
+        assert False, f"{qkv_format} is an unsupported qkv_format!"

     if qkv_format == "bshd":
         torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
@@ -160,11 +164,13 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
         torch.testing.assert_close(dq_[1], dq[1], **tols)
         torch.testing.assert_close(dk_[1], dk[1], **tols)
         torch.testing.assert_close(dv_[1], dv[1], **tols)
-    else:
+    elif qkv_format == "thd":
         torch.testing.assert_close(out_, out, **tols)
         torch.testing.assert_close(dq_, dq, **tols)
         torch.testing.assert_close(dk_, dk, **tols)
         torch.testing.assert_close(dv_, dv, **tols)
+    else:
+        assert False, f"{qkv_format} is an unsupported qkv_format!"

 def main(**kwargs):
     run_dpa_with_cp(**kwargs)
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index c7ed629aaa..2401047522 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -624,7 +624,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                                 q_inputs[i%2] = q.view(-1, *q.shape[-2:])
                                 if thd:
                                     # [2, t, np, hn] -> [2, t/2, np, hn]
-                                    kv_inputs[i%2] = tex.thd_read_half_tensor(kv_inputs[i%2], cu_seqlens_k, 0)
+                                    kv_inputs[i%2] = tex.thd_read_half_tensor(
+                                        kv_inputs[i%2], cu_seqlens_k, 0)
                                 else:
                                     # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
                                     kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
@@ -660,8 +661,9 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                                     # [t, np, hn] -> [t/2, np, hn]
                                     q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1)
                                 else:
-                                    # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn]
-                                    q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
+                                    # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn]
+                                    q_inputs[i%2] = \
+                                        q[:, 1, ...].contiguous().view(-1, *q.shape[-2:])
                                 # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn]
                                 kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:])
                                 if _flash_attn_2_3_plus:
@@ -720,10 +722,10 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                                                               softmax_lse_per_step[i-1])
                 else:
                     if thd:
-                        tex.thd_lse_correction(softmax_lse,
-                                               softmax_lse_per_step[i-1],
-                                               cu_seqlens_q,
-                                               q.size(0))
+                        tex.thd_second_half_lse_correction(softmax_lse,
+                                                           softmax_lse_per_step[i-1],
+                                                           cu_seqlens_q,
+                                                           q.size(0))
                     else:
                         flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :],
                                                               softmax_lse_per_step[i-1])
@@ -801,10 +803,11 @@ def backward(ctx, dout):

         if ctx.causal:
             if ctx.thd:
-                softmax_lse_ = tex.thd_read_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
+                softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0))
             else:
                 # [b, np, sq] -> [b, np, 2, sq//2]
-                softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
+                softmax_lse_ = \
+                    softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2)
                 softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
             if ctx.use_fused_attention:
                 # [b, np, sq//2] -> [b, np, sq//2, 1]
@@ -922,7 +925,7 @@ def backward(ctx, dout):
                         # [2, t, np, hn] -> [2, t/2, np, hn]
                         kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0)
                     else:
-                        # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn]
+                        # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn]
                         kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:])
                     dkv_ = torch.empty_like(kv_)
                     # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 26a2ca3b04..1217754937 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -648,15 +648,15 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor,
                                 int half_idx
 );

-void thd_lse_correction(at::Tensor &lse,
-                        const at::Tensor &lse_per_step,
-                        const at::Tensor &cu_seqlens,
-                        int total_tokens
+void thd_second_half_lse_correction(at::Tensor &lse,
+                                    const at::Tensor &lse_per_step,
+                                    const at::Tensor &cu_seqlens,
+                                    int total_tokens
 );

-at::Tensor thd_read_half_lse(const at::Tensor &lse,
-                             const at::Tensor &cu_seqlens,
-                             int total_tokens
+at::Tensor thd_read_second_half_lse(const at::Tensor &lse,
+                                    const at::Tensor &cu_seqlens,
+                                    int total_tokens
 );

 void thd_out_correction(at::Tensor &out,
@@ -664,7 +664,7 @@ void thd_out_correction(at::Tensor &out,
                         const at::Tensor &lse,
                         const at::Tensor &lse_per_step,
                         const at::Tensor &cu_seqlens,
-                        bool is_half
+                        bool only_second_half
 );

 void thd_grad_correction(at::Tensor &grad,
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index ade4345264..dd2a515227 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1463,8 +1463,13 @@ __device__ int binary_search(int target, int *array, int len) {
 * Support THD format for Context Parallel: Read the half of a THD tensor
 **************************************************************************************************/

-__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens,
-                                            int batch, int hidden_size_in_bytes, int half_idx) {
+__global__ void thd_read_half_tensor_kernel(void *half,
+                                            void *tensor,
+                                            int *cu_seqlens,
+                                            int batch,
+                                            int hidden_size_in_bytes,
+                                            int half_idx,
+                                            int dim_size_of_token) {
   extern __shared__ int cu_seqlens_s[];
   for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
     cu_seqlens_s[i] = cu_seqlens[i] / 2;
@@ -1477,9 +1482,9 @@ __global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_se

   int num_total_tokens = cu_seqlens_s[batch];
   int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);
-  size_t offset = num_total_tokens * (size_t)hidden_size_in_bytes;
-  half = (void*)((char*)half + offset * blockIdx.y);
-  tensor = (void*)((char*)tensor + 2 * offset * blockIdx.y);
+  size_t offset = (size_t)dim_size_of_token * hidden_size_in_bytes;
+  half = (void*)((char*)half + offset/2 * blockIdx.y);
+  tensor = (void*)((char*)tensor + offset * blockIdx.y);

   for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
     int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);
@@ -1511,8 +1516,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor,
   int batch = cu_seqlens.size(0) - 1;
   int num_heads = tensor.size(seq_dim + 1);
   int dim_per_head = tensor.size(seq_dim + 2);
-  size_t hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type());
+  int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type());

+  // For 128-bits load/store
   NVTE_CHECK(hidden_size_in_bytes % 16 == 0);

   // Generate output
@@ -1538,7 +1544,8 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor,
     cu_seqlens.data_ptr(),
     batch,
     hidden_size_in_bytes,
-    half_idx);
+    half_idx,
+    tensor.size(seq_dim));

   return half;
 }
@@ -1563,12 +1570,12 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
   for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
     int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
     for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
-      int row = seq_id * num_heads + head_id;
+      size_t row = (size_t)seq_id * num_heads + head_id;
       int col = token_id - cu_seqlens_s[seq_id];
       int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];

-      size_t idx = (size_t)row * max_seqlen + col + seq_len;
-      size_t half_idx = (size_t)row * max_seqlen / 2 + col;
+      size_t idx = row * max_seqlen + col + seq_len;
+      size_t half_idx = row * max_seqlen / 2 + col;
       Functor::run(lse, half_lse, idx, half_idx);
     }

@@ -1586,10 +1593,10 @@ struct LseCorrectionFunctor {
   }
 };

-void thd_lse_correction(at::Tensor &lse,
-                        const at::Tensor &lse_per_step,
-                        const at::Tensor &cu_seqlens,
-                        int total_tokens) {
+void thd_second_half_lse_correction(at::Tensor &lse,
+                                    const at::Tensor &lse_per_step,
+                                    const at::Tensor &cu_seqlens,
+                                    int total_tokens) {
   NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double);
   NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
@@ -1628,9 +1635,9 @@ struct ReadLseFunctor {
   }
 };

-at::Tensor thd_read_half_lse(const at::Tensor &lse,
-                             const at::Tensor &cu_seqlens,
-                             int total_tokens) {
+at::Tensor thd_read_second_half_lse(const at::Tensor &lse,
+                                    const at::Tensor &cu_seqlens,
+                                    int total_tokens) {
   NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
   NVTE_CHECK(lse.dim() == 3);
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
@@ -1665,7 +1672,7 @@ at::Tensor thd_read_half_lse,
 * Support THD format for Context Parallel: Out correction in forward
 **************************************************************************************************/

-template
+template
 __global__ void thd_out_correction_kernel(dtype *out,
                                           dtype *out_per_step,
                                           float *lse,
@@ -1677,7 +1684,7 @@ __global__ void thd_out_correction_kernel(dtype *out,
                                           int max_seqlen) {
   extern __shared__ int cu_seqlens_s[];
   for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
-    cu_seqlens_s[i] = cu_seqlens[i] / (is_half + 1);
+    cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
   }
   __syncthreads();

@@ -1692,14 +1699,14 @@ __global__ void thd_out_correction_kernel(dtype *out,
     for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
       size_t idx, idx_per_step;

-      int row = seq_id * num_heads + head_id;
+      size_t row = (size_t)seq_id * num_heads + head_id;
       int col = token_id - cu_seqlens_s[seq_id];
       int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
-      idx = (size_t)row * max_seqlen + col + seq_len * is_half;
-      idx_per_step = (size_t)row * max_seqlen / (is_half + 1) + col;
+      idx = row * max_seqlen + col + seq_len * only_second_half;
+      idx_per_step = row * max_seqlen / (only_second_half + 1) + col;
       float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);

-      idx = (size_t)token_id + cu_seqlens_s[seq_id + 1] * is_half;
+      idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
       idx = (idx * num_heads + head_id) * dim_per_head;
       idx_per_step = ((size_t)token_id * num_heads + head_id) * dim_per_head;
       dtype *cur_out = out + idx;
@@ -1710,8 +1717,8 @@ __global__ void thd_out_correction_kernel(dtype *out,
         float4 data = ((float4*)cur_out)[j];
         dtype *p_per_step = (dtype*)&data_per_step;
         dtype *p = (dtype*)&data;
-        for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
-          p[i] += p_per_step[i] * lse_corrected_exp;
+        for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
+          p[k] += p_per_step[k] * lse_corrected_exp;
         }
         ((float4*)cur_out)[j] = data;
       }
@@ -1719,7 +1726,7 @@
   }
 }

-template
+template
 static void thd_out_correction_helper(at::Tensor &out,
                                       const at::Tensor &out_per_step,
                                       const at::Tensor &lse,
@@ -1736,22 +1743,22 @@ static void thd_out_correction_helper(at::Tensor &out,

   int batch = lse.size(0);
   int max_seqlen = lse.size(2);
-  NVTE_CHECK(out_per_step.size(0) == total_tokens / (is_half + 1));
+  NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1));
   NVTE_CHECK(out_per_step.size(1) == num_heads);
   NVTE_CHECK(out_per_step.size(2) == dim_per_head);

   NVTE_CHECK(lse.size(1) == num_heads);
   NVTE_CHECK(lse_per_step.size(0) == batch);
   NVTE_CHECK(lse_per_step.size(1) == num_heads);
-  NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (is_half + 1));
+  NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1));
   NVTE_CHECK(cu_seqlens.size(0) == batch + 1);

   constexpr int tile = 16;
   constexpr int block = 512;
-  unsigned int grid_x = min((total_tokens / (is_half + 1) * tile + block - 1) / block, 256);
+  unsigned int grid_x = ((size_t)total_tokens / (only_second_half + 1) * tile + block - 1) / block;
   dim3 grid = {grid_x, (unsigned int)num_heads};

-  thd_out_correction_kernel<<>>(
+  thd_out_correction_kernel<<>>(
     out.data_ptr(),
     out_per_step.data_ptr(),
     lse.data_ptr(),
@@ -1768,8 +1775,8 @@ void thd_out_correction(at::Tensor &out,
                         const at::Tensor &lse,
                         const at::Tensor &lse_per_step,
                         const at::Tensor &cu_seqlens,
-                        bool is_half) {
-  if (is_half) {
+                        bool only_second_half) {
+  if (only_second_half) {
     if (out.scalar_type() == at::ScalarType::Half) {
       using dtype = at::Half;
       thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens);
@@ -1803,8 +1810,12 @@ void thd_out_correction(at::Tensor &out,
 **************************************************************************************************/

 template
-__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens,
-                                           int batch, int hidden_size) {
+__global__ void thd_grad_correction_kernel(dtype *grad,
+                                           dtype *grad_per_step,
+                                           int *cu_seqlens,
+                                           int batch,
+                                           int hidden_size,
+                                           int dim_size_of_token) {
   extern __shared__ int cu_seqlens_s[];
   for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
     if constexpr (functor_idx < 2) {
@@ -1821,13 +1832,13 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in

   int num_total_tokens = cu_seqlens_s[batch];
   int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4);
-  size_t offset = num_total_tokens * (size_t)hidden_size;
-  grad_per_step = grad_per_step + offset * blockIdx.y;
+  size_t offset = (size_t)dim_size_of_token * hidden_size;
   if constexpr (functor_idx < 2) {
-    grad = grad + offset * blockIdx.y * 2;
+    grad_per_step = grad_per_step + offset / 2 * blockIdx.y;
   } else {
-    grad = grad + offset * blockIdx.y;
+    grad_per_step = grad_per_step + offset * blockIdx.y;
   }
+  grad = grad + offset * blockIdx.y;

   for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) {
     int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
@@ -1857,8 +1868,7 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in

 struct EmptyFunctor {
   __forceinline__
-  __device__ static void run(void *token, void *token_per_step, int idx) {
-  }
+  __device__ static void run(void *token, void *token_per_step, int idx) {}
 };
 struct CopyFunctor {
   __forceinline__
@@ -1872,14 +1882,18 @@ template
 struct AddFunctor {
   __forceinline__
   __device__ static void run(dtype *token, dtype *token_per_step, int idx) {
-    float4 d = ((float4*)token)[idx];
-    dtype *p = (dtype*)(&d);
-    float4 d_ = ((float4*)token_per_step)[idx];
+    float4 d_ = ((float4*)token)[idx];
     dtype *p_ = (dtype*)(&d_);
+
+    float4 d = ((float4*)token_per_step)[idx];
+    dtype *p = (dtype*)(&d);
+
+    #pragma unroll
     for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
-      p[i] += p_[i];
+      p_[i] += p[i];
     }
-    ((float4*)token)[idx] = d;
+
+    ((float4*)token)[idx] = d_;
   }
 };

@@ -1930,7 +1944,8 @@ static void thd_grad_correction_helper(at::Tensor &grad,
     grad_per_step.data_ptr(),
     cu_seqlens.data_ptr(),
     batch,
-    hidden_size);
+    hidden_size,
+    total_tokens);
 }

 template
@@ -1990,7 +2005,11 @@ __global__ void thd_partition_indices_kernel(int *output,
                                              int rank) {
   extern __shared__ int cu_seqlens_s[];
   for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
-    cu_seqlens_s[i] = cu_seqlens[i] / world_size;
+    int seqlen = cu_seqlens[i];
+    // Currently we assume that each sequence length is divisible by (world_size*2) since we have
+    // to distribute each sequence evenly to different GPUs.
+    assert(seqlen % (world_size*2) == 0);
+    cu_seqlens_s[i] = seqlen / world_size;
   }
   __syncthreads();

diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 9110d96e63..b512fac203 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -103,14 +103,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");

   // Support THD format for Context Parallel
-  m.def("thd_read_half_tensor", &thd_read_half_tensor, "Read the half of a THD tensor");
-  m.def("thd_lse_correction", &thd_lse_correction, "softmax_lse correction for THD format");
-  m.def("thd_read_half_lse", &thd_read_half_lse, "Read the half of the softmax_lse");
-  m.def("thd_out_correction", &thd_out_correction, "Out correction for THD format");
-  m.def("thd_grad_correction", &thd_grad_correction, "Gradients correction for THD format");
-  m.def("thd_get_partitioned_indices",
-        &thd_get_partitioned_indices,
-        "Generate partitioned indices for input tokens");
+  m.def("thd_read_half_tensor", &thd_read_half_tensor,
+        "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
+        "tensor");
+  m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
+        "Correct the second half of the softmax_lse");
+  m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
+        "Read the second half of the softmax_lse");
+  m.def("thd_out_correction", &thd_out_correction,
+        "Correct the THD format output of context parallelism in forward pass");
+  m.def("thd_grad_correction", &thd_grad_correction,
+        "Correct the THD format gradients of context parallelism in backward pass");
+  m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
+        "Generate partitioned indices for inputs in THD format");

   // Data structures
   py::class_(m, "FP8TensorMeta")