
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 24, 2024
1 parent 18ad713 commit ed446e7
Showing 3 changed files with 18 additions and 21 deletions.
24 changes: 12 additions & 12 deletions transformer_engine/pytorch/attention.py
@@ -2431,19 +2431,19 @@ def forward(
        elif qkv_format == "sbhd":
            out = out.view(-1, *out.shape[-3:])
            ctx.batch_size = out.shape[1]

        tex.fused_out_correction(
            out,
            out_per_step,
            softmax_lse,
            softmax_lse_per_step,
            cu_seqlens_q_padded,
            qkv_format,
            cp_size,
            rank,
            causal,
            softmax_lse_in_packed_format,
        )

        if cp_size_a2a > 1:
            chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
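For reference, the correction that `tex.fused_out_correction` performs on the per-step partial outputs can be written as a short pure-PyTorch sketch. This is illustrative only, assuming the standard log-sum-exp merge used when attention over a sequence is computed in several steps (e.g. under context parallelism); the tensor shapes and the name `out_correction_reference` are assumptions, not part of this commit.

```python
import torch

def out_correction_reference(out_per_step, lse_per_step):
    """Merge per-step partial attention outputs using their softmax LSE.

    Assumed shapes (illustration only):
      out_per_step[i]: [tokens, heads, dim] partial output of step i
      lse_per_step[i]: [heads, tokens] log-sum-exp of that step's scores
    """
    # Combined log-sum-exp over all steps.
    lse = torch.logsumexp(torch.stack(lse_per_step), dim=0)  # [heads, tokens]
    out = torch.zeros_like(out_per_step[0])
    for o, l in zip(out_per_step, lse_per_step):
        # Each step's output is weighted by exp(lse_i - lse), i.e. the share of
        # the total softmax normalizer that this step contributed.
        w = torch.exp(l - lse).transpose(0, 1).unsqueeze(-1)  # [tokens, heads, 1]
        out = out + w * o
    return out
```

If this reading is right, the fused CUDA kernel dispatched below applies the same weighting directly on `out`, avoiding these Python-level intermediates.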
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/extensions.h
@@ -452,7 +452,7 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,

at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                        int world_size, int rank);

void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
                          const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
                          const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
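The declaration above is the C++ entry point behind the `tex.fused_out_correction` call in attention.py. A hedged sketch of how it is invoked from the Python side, with the meaning of each argument spelled out in comments; the extension module name and the per-argument descriptions are assumptions for illustration, not taken from this commit.

```python
import transformer_engine_torch as tex  # assumed name of the compiled extension

tex.fused_out_correction(
    out,                           # final output tensor, corrected across CP steps
    out_per_step,                  # list[Tensor]: partial output from each step
    softmax_lse,                   # combined softmax log-sum-exp
    softmax_lse_per_step,          # list[Tensor]: per-step log-sum-exp values
    cu_seqlens_q_padded,           # cumulative (padded) query sequence lengths
    qkv_format,                    # layout string, e.g. "sbhd" or "bshd"
    cp_size,                       # context-parallel group size
    rank,                          # this rank's index in the CP group
    causal,                        # whether a causal mask was applied
    softmax_lse_in_packed_format,  # True if lse uses the packed (varlen) layout
)
```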
13 changes: 5 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1256,7 +1256,7 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
    if (softmax_lse_in_packed_format) {
      lse_seqlen = total_tokens;
    } else {
      lse_seqlen = lse.size(2);
    }
  }
  constexpr int tile = 16;
@@ -1277,15 +1277,12 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
      tensors.addresses_lse[j] = lse_per_step[i + j].data_ptr<float>();
    }
    if (qkv_format == "sbhd") {
      NVTE_CHECK(softmax_lse_in_packed_format == false, "Packed lse doesn't support SBHD format.");
      fused_out_correction_kernel<dtype, tile, causal, QKVFormat::SBH, QKVFormat::BHS, max_tensors>
          <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
              out.data_ptr<dtype>(), tensors, lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
              batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);

    } else if (qkv_format == "bshd") {
      if (softmax_lse_in_packed_format) {
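A note on the `lse_seqlen` branch in the helper above: the sequence extent of `lse` depends on whether the log-sum-exp tensor uses the packed (variable-length) layout, and the `NVTE_CHECK` in the SBHD branch rejects the packed layout for that format. A minimal sketch of the distinction, assuming a [batch, heads, max_seqlen] layout for the unpacked case and a flat [heads, total_tokens] layout for the packed case; these shapes are an assumption for illustration, not taken from this commit.

```python
import torch

def lse_sequence_extent(lse: torch.Tensor, packed: bool, total_tokens: int) -> int:
    # Packed layout: one flat token axis across the whole batch, so the extent
    # is the total token count. Unpacked layout: a per-sequence axis of length
    # max_seqlen sits at dim 2, mirroring `lse.size(2)` in the CUDA helper.
    return total_tokens if packed else lse.size(2)

# Example (shapes are placeholders):
unpacked_lse = torch.zeros(2, 8, 512)  # [batch, heads, max_seqlen]
packed_lse = torch.zeros(8, 640)       # [heads, total_tokens]
assert lse_sequence_extent(unpacked_lse, packed=False, total_tokens=640) == 512
assert lse_sequence_extent(packed_lse, packed=True, total_tokens=640) == 640
```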

