[JAX] Consolidate FFI and old descriptor implementation for fused attention. (#1295)

Consolidate FFI and old descriptor implementation for fused attention.

Signed-off-by: Michael Goldfarb <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
mgoldfarb-nvidia and phu0ngng authored Oct 30, 2024
1 parent ed1e85c commit c036765
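
The refactor follows a standard consolidation pattern: the body of the forward launcher is factored into a single implementation function (FusedAttnForwardImpl) that takes every parameter explicitly, and both the legacy opaque-descriptor custom call (FusedAttnForward) and the new XLA FFI handler (FusedAttnForwardFFI) become thin wrappers that only unpack their inputs and forward them. The standalone C++ sketch below illustrates that pattern in miniature; Descriptor, RunImpl, RunLegacy, and RunFFI are illustrative names with toy logic, not the TransformerEngine API.

// Toy illustration of the consolidation pattern used in this commit:
// one explicit-argument implementation, two thin entry points.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Stand-in for an opaque, trivially copyable descriptor struct.
struct Descriptor {
  int batch;
  float scale;
  bool is_training;
};

// Single implementation taking every parameter explicitly, so it does not
// care whether the caller got them from an opaque blob or typed attributes.
void RunImpl(const float* input, float* output, int batch, float scale,
             bool is_training) {
  for (int i = 0; i < batch; ++i) {
    output[i] = input[i] * scale * (is_training ? 1.0f : 0.5f);
  }
}

// Legacy-style entry point: parameters arrive packed in an opaque byte blob.
void RunLegacy(const float* input, float* output, const char* opaque,
               std::size_t opaque_len) {
  Descriptor desc{};
  std::memcpy(&desc, opaque, std::min(opaque_len, sizeof(desc)));
  RunImpl(input, output, desc.batch, desc.scale, desc.is_training);
}

// FFI-style entry point: parameters arrive as individually typed attributes.
void RunFFI(const float* input, float* output, std::int64_t batch,
            double scale, bool is_training) {
  RunImpl(input, output, static_cast<int>(batch), static_cast<float>(scale),
          is_training);
}

int main() {
  std::vector<float> in{1.0f, 2.0f, 3.0f};
  std::vector<float> out(3);

  Descriptor desc{3, 2.0f, true};
  RunLegacy(in.data(), out.data(), reinterpret_cast<const char*>(&desc),
            sizeof(desc));
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // 2 4 6

  RunFFI(in.data(), out.data(), 3, 2.0, false);
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // 1 2 3
  return 0;
}

Keeping the shared body free of descriptor- and FFI-specific types is what lets both call paths stay in sync; only the argument unpacking differs per entry point, which is exactly how the diff below reduces the FFI handler to a single call into the shared implementation.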
Showing 1 changed file with 72 additions and 198 deletions.
270 changes: 72 additions & 198 deletions transformer_engine/jax/csrc/extensions/attention.cpp
@@ -185,46 +185,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
static void FusedAttnForwardImpl(
cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens,
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output,
void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *seed = buffers[8];

/* Output buffer from XLA */
void *output = buffers[9];
void *softmax_aux = buffers[10];
void *rng_state = buffers[11];
void *workspace = buffers[12];

/* Descriptor */
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto is_training = descriptor.is_training;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
auto window_size_left = descriptor.window_size_left;
auto window_size_right = descriptor.window_size_right;

/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
@@ -247,8 +218,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
cudaMemsetAsync(output, 0,
input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
}

auto q_cu_seqlens_tensor =
@@ -281,43 +252,37 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
backend, softmax_aux);

/* cuDNN workspace */
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
auto workspace_tensor =
TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);

/* Call the underly NVTE API */
/* Call the underlying NVTE API */
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv = buffers[1];
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k = buffers[1];
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v = buffers[2];
auto v_shape = k_shape;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
@@ -333,6 +298,37 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
nvte_tensor_pack_destroy(&aux_output_tensors);
}

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD;

/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *seed = buffers[8];

/* Output buffer from XLA */
void *output = buffers[9];
void *softmax_aux = buffers[10];
void *rng_state = buffers[11];
void *workspace = buffers[12];

FusedAttnForwardImpl(
stream, buffers[0], buffers[1], buffers[2], bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets,
k_seq_offsets, seed, output, softmax_aux, rng_state, workspace, descriptor.input_batch,
descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen,
descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
}

Error_Type FusedAttnForwardFFI(
cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
@@ -344,147 +340,25 @@ Error_Type FusedAttnForwardFFI(
double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_,
int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
/* Descriptor data type conversion */
size_t input_batch = static_cast<size_t>(input_batch_);
size_t bias_batch = static_cast<size_t>(bias_batch_);
size_t q_max_seqlen = static_cast<size_t>(q_max_seqlen_);
size_t kv_max_seqlen = static_cast<size_t>(kv_max_seqlen_);
size_t attn_heads = static_cast<size_t>(attn_heads_);
size_t num_gqa_groups = static_cast<size_t>(num_gqa_groups_);
size_t bias_heads = static_cast<size_t>(bias_heads_);
size_t head_dim = static_cast<size_t>(head_dim_);
size_t max_segments_per_seq = static_cast<size_t>(max_segments_per_seq_);
size_t wkspace_size = static_cast<size_t>(wkspace_size_);
float scaling_factor = static_cast<float>(scaling_factor_);
float dropout_probability = static_cast<float>(dropout_probability_);
NVTE_Bias_Type bias_type = static_cast<NVTE_Bias_Type>(bias_type_);
NVTE_Mask_Type mask_type = static_cast<NVTE_Mask_Type>(mask_type_);
NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_);
DType dtype = static_cast<DType>(dtype_);
DType wkspace_dtype = static_cast<DType>(wkspace_dtype_);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

/* Input buffers from XLA */
/* q, k, v are parsed later for different qkv_layout */
void *bias = bias_buf.untyped_data();
void *q_cu_seqlens = q_cu_seqlens_buf.untyped_data();
void *kv_cu_seqlens = kv_cu_seqlens_buf.untyped_data();
void *q_seq_offsets = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr;
void *k_seq_offsets = is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr;
void *seed = seed_buf.untyped_data();

/* Output buffer from XLA */
void *output = output_buf->untyped_data();
void *softmax_aux = softmax_aux_buf->untyped_data();
void *rng_state = rng_state_buf->untyped_data();
void *workspace = workspace_buf->untyped_data();

/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) {
auto cudnn_runtime_version = cudnnGetVersion();
if (cudnn_runtime_version >= 90300) {
num_segments = input_batch * max_segments_per_seq;
} else {
// workspace can be reused here as it is not used with cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
}

auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);

/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);

/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
backend, softmax_aux);

/* cuDNN workspace */
auto workspace_tensor =
TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);

/* Call the underlying NVTE API */
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = q_buf.untyped_data();
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = q_buf.untyped_data();
auto kv = k_buf.untyped_data();
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = q_buf.untyped_data();
auto k = k_buf.untyped_data();
auto v = v_buf.untyped_data();
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}

nvte_tensor_pack_destroy(&aux_output_tensors);
FusedAttnForwardImpl(
stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(),
output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
workspace_buf->untyped_data(), static_cast<size_t>(input_batch_),
static_cast<size_t>(bias_batch_), static_cast<size_t>(q_max_seqlen_),
static_cast<size_t>(kv_max_seqlen_), static_cast<size_t>(attn_heads_),
static_cast<size_t>(num_gqa_groups_), static_cast<size_t>(bias_heads_),
static_cast<size_t>(head_dim_), static_cast<size_t>(max_segments_per_seq_),
static_cast<size_t>(wkspace_size_), static_cast<float>(scaling_factor_),
static_cast<float>(dropout_probability_), static_cast<NVTE_Bias_Type>(bias_type_),
static_cast<NVTE_Mask_Type>(mask_type_), static_cast<NVTE_QKV_Layout>(qkv_layout_),
static_cast<DType>(dtype_), static_cast<DType>(wkspace_dtype_), is_training, deterministic,
window_size_left, window_size_right);

return ffi_with_cuda_error_check();
}