Commit

minor fixes based on PR review
Signed-off-by: xiaoyao0115 <[email protected]>
xiaoyao0115 committed Dec 24, 2024
1 parent d346d9c commit 18ad713
Showing 5 changed files with 30 additions and 63 deletions.
6 changes: 3 additions & 3 deletions transformer_engine/common/fused_attn/thd_utils.h
@@ -102,7 +102,7 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
}

/***************************************************************************************************
* Support THD(including SBHD and BSHD) format for Context Parallel: Out correction in forward
* Support BSHD, SBHD, and THD formats for Context Parallel: Out correction in forward
**************************************************************************************************/

// format of out and lse, ignoring d as it’s always the last dimension.
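For reference, the "out correction" implemented in this header combines the partial attention outputs produced by each context-parallel step using their softmax log-sum-exp (lse) statistics. Below is a minimal PyTorch sketch of that combination, assuming the standard lse rescaling; the helper name out_correction_reference and the tensor shapes are illustrative only, and the real kernel additionally handles THD layouts, causal half-sequences, and padded cu_seqlens.

import torch

def out_correction_reference(out_per_step, lse_per_step):
    """Combine per-CP-step partial outputs using their softmax lse statistics.

    out_per_step: list of [s, b, h, d] tensors, one per context-parallel step.
    lse_per_step: list of [b, h, s] tensors holding the matching log-sum-exp values.
    """
    # Global statistics over all steps: lse = log(sum_i exp(lse_i))
    lse = torch.logsumexp(torch.stack(lse_per_step), dim=0)             # [b, h, s]
    out = torch.zeros_like(out_per_step[0])                             # [s, b, h, d]
    for out_i, lse_i in zip(out_per_step, lse_per_step):
        # Each step's output is weighted by exp(lse_i - lse), broadcast over d.
        weight = torch.exp(lse_i - lse).permute(2, 0, 1).unsqueeze(-1)  # [s, b, h, 1]
        out += out_i * weight
    return out, lse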
@@ -258,12 +258,12 @@ __global__ void fused_out_correction_kernel(dtype *out, TensorList<max_tensors>
}

for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
size_t idx_out;
size_t idx_lse;
float4 data = reinterpret_cast<float4 *>(cur_out)[j];
dtype *p = reinterpret_cast<dtype *>(&data);

for (int i = start; i < end; i++) {
size_t idx_out;
size_t idx_lse;
if (id[1] >= 0 && start + tensors.start_tensor_this_launch > full_num && i > rank) {
idx_out = idx_out_half;
idx_lse = idx_lse_half;
51 changes: 12 additions & 39 deletions transformer_engine/pytorch/attention.py
@@ -2421,34 +2421,18 @@ def forward(

softmax_lse = softmax_lse.to(torch.float)

if qkv_format == "sbhd":
tex.fused_out_correction(
out.view(-1, *out.shape[-3:]),
out_per_step,
softmax_lse,
softmax_lse_per_step,
cu_seqlens_q_padded,
qkv_format,
cp_size,
rank,
causal,
softmax_lse_in_packed_format,
)
elif qkv_format == "bshd":
tex.fused_out_correction(
out.view(out.shape[-4], -1, *out.shape[-2:]),
out_per_step,
softmax_lse,
softmax_lse_per_step,
cu_seqlens_q_padded,
qkv_format,
cp_size,
rank,
causal,
softmax_lse_in_packed_format,
)
else:
tex.fused_out_correction(
if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
kv = p2p_comm_buffers[-1]
if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:])
ctx.batch_size = out.shape[0]
elif qkv_format == "sbhd":
out = out.view(-1, *out.shape[-3:])
ctx.batch_size = out.shape[1]

tex.fused_out_correction(
out,
out_per_step,
softmax_lse,
@@ -2461,17 +2445,6 @@ def forward(
softmax_lse_in_packed_format,
)

if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
kv = p2p_comm_buffers[-1]
if qkv_format == "bshd":
out = out.view(out.shape[0], -1, *out.shape[-2:])
ctx.batch_size = out.shape[0]
elif qkv_format == "sbhd":
out = out.view(-1, *out.shape[-3:])
ctx.batch_size = out.shape[1]

if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
out = flash_attn_a2a_communicate(
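Taken together, the attention.py change above collapses the three near-identical per-format calls into one: only the view taken of out, and for packed lse the [np, b, sq] -> [np, t] reshape, differ by layout before a single tex.fused_out_correction call. A condensed restatement as a standalone helper is sketched below; the function name _dispatch_fused_out_correction is hypothetical, and in the real forward() these statements run inline, with tex being the compiled extension module.

def _dispatch_fused_out_correction(tex, out, out_per_step, softmax_lse,
                                   softmax_lse_per_step, cu_seqlens_q_padded,
                                   qkv_format, cp_size, rank, causal,
                                   softmax_lse_in_packed_format):
    if qkv_format != "thd" and softmax_lse_in_packed_format:
        # [np, b, sq] -> [np, t]
        softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1)
    if qkv_format == "bshd":
        out = out.view(out.shape[0], -1, *out.shape[-2:])
    elif qkv_format == "sbhd":
        out = out.view(-1, *out.shape[-3:])
    # One fused call now covers bshd, sbhd, and thd.
    tex.fused_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step,
                             cu_seqlens_q_padded, qkv_format, cp_size, rank, causal,
                             softmax_lse_in_packed_format)
    return out, softmax_lse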
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -136,11 +136,6 @@ std::vector<at::Tensor> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);

/***************************************************************************************************
* GEMM
**************************************************************************************************/
@@ -457,6 +452,11 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,

at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank);

void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);

/***************************************************************************************************
* multi_tensor_* kernels
19 changes: 7 additions & 12 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1223,15 +1223,14 @@ std::vector<at::Tensor> fused_attn_bwd(
}

/***************************************************************************************************
* Support THD(including SBHD and BSHD) format for Context Parallel: Fused out correction in forward
* Support BSHD, SBHD, and THD formats for Context Parallel: Fused out correction in forward
**************************************************************************************************/

template <typename dtype, bool causal>
void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool softmax_lse_in_packed_format,
const at::Tensor *lse_ = nullptr) {
int rank, bool softmax_lse_in_packed_format) {
int lse_seqlen;
int batch;
int num_heads;
@@ -1257,7 +1256,7 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
if (softmax_lse_in_packed_format) {
lse_seqlen = total_tokens;
} else {
lse_seqlen = lse.size(2);
lse_seqlen = lse.size(2);
}
}
constexpr int tile = 16;
@@ -1278,19 +1277,15 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
tensors.addresses_lse[j] = lse_per_step[i + j].data_ptr<float>();
}
if (qkv_format == "sbhd") {
if (softmax_lse_in_packed_format) {
fused_out_correction_kernel<dtype, tile, causal, QKVFormat::SBH, QKVFormat::HBS,
max_tensors>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
} else {

NVTE_CHECK(softmax_lse_in_packed_format == false, "Packed lse doesn't support SBHD format.");

fused_out_correction_kernel<dtype, tile, causal, QKVFormat::SBH, QKVFormat::BHS,
max_tensors>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
}


} else if (qkv_format == "bshd") {
if (softmax_lse_in_packed_format) {
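The branch removed above used to launch a packed-lse (HBS) kernel for SBHD outputs; the rewrite keeps only the BHS path and turns the unsupported combination into an explicit NVTE_CHECK. A Python-level mirror of that guard, purely for illustration (this function is not part of the TE API):

def check_fused_out_correction_args(qkv_format: str, softmax_lse_in_packed_format: bool) -> None:
    # Mirrors the NVTE_CHECK added in fused_out_correction_helper: packed lse
    # (the [np, t] layout) is rejected when the outputs are in SBHD format.
    if qkv_format == "sbhd" and softmax_lse_in_packed_format:
        raise ValueError("Packed lse doesn't support SBHD format.")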
7 changes: 3 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -41,10 +41,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());

m.def("fused_out_correction", &fused_out_correction,
"fused out correction after qkv calculation without lse_",
py::call_guard<py::gil_scoped_release>());

// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8",
py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
@@ -199,6 +195,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_out_correction", &fused_out_correction,
"fused out correction after qkv calculation without lse_",
py::call_guard<py::gil_scoped_release>());

// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
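The pybind.cpp change only relocates the registration next to the other THD/correction helpers; the Python-visible name and argument order are unchanged. From the Python side the binding looks roughly like the stub below (type hints are interpretive; the authoritative signature is the declaration in extensions.h above).

from typing import List
import torch

def fused_out_correction(out: torch.Tensor,
                         out_per_step: List[torch.Tensor],
                         lse: torch.Tensor,
                         lse_per_step: List[torch.Tensor],
                         cu_seqlens: torch.Tensor,
                         qkv_format: str,
                         cp_size: int,
                         rank: int,
                         causal: bool,
                         softmax_lse_in_packed_format: bool) -> None:
    """fused out correction after qkv calculation without lse_ (see pybind docstring)."""
    ...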
