Skip to content

Commit

Permalink
Merge branch 'main' into rachitg/DP_carveout
Browse files Browse the repository at this point in the history
  • Loading branch information
timmoon10 authored Mar 12, 2024
2 parents ec88ba4 + a38b291 commit 3f0cf84
Showing 1 changed file with 22 additions and 35 deletions.
57 changes: 22 additions & 35 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,8 @@ static cudnn_frontend::Tensor createDropoutForward(
double probability,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& beforeDropoutTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"Dropout DAG constructed incorrectly as the first one",
CUDNN_STATUS_BAD_PARAM);
NVTE_CHECK(ops->size() > 0,
"Dropout DAG constructed incorrectly as the first one");

int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
Expand Down Expand Up @@ -421,9 +420,8 @@ static cudnn_frontend::Tensor createDropoutBackward(
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& beforeDropoutTensor,
const cudnn_frontend::Tensor& dropoutMaskTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"Dropout DAG constructed incorrectly as the first one",
CUDNN_STATUS_BAD_PARAM);
NVTE_CHECK(ops->size() > 0,
"Dropout DAG constructed incorrectly as the first one");

int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
Expand Down Expand Up @@ -499,9 +497,8 @@ static cudnn_frontend::Tensor createSoftmaxBackward(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
std::vector<cudnn_frontend::Operation>* ops,
const cudnn_frontend::Tensor& dyTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"Softmax backward constructed incorrectly as the first one",
CUDNN_STATUS_BAD_PARAM);
NVTE_CHECK(ops->size() > 0,
"Softmax backward constructed incorrectly as the first one");

int64_t dx_dim[4] = {b, h, s_q, s_kv};
int64_t dx_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1};
Expand Down Expand Up @@ -621,9 +618,8 @@ static cudnn_frontend::Tensor createSVBMM(
const cudnn_frontend::Tensor &softmaxTensor,
const cudnn_frontend::Tensor &mnkOverride,
std::shared_ptr<cudnn_frontend::Tensor> QKVRaggedOffsetTensor) {
cudnn_frontend::throw_if(ops->size() == 0,
"BMM2 op constructed incorrectly as the first one",
CUDNN_STATUS_BAD_PARAM);
NVTE_CHECK(ops->size() > 0,
"BMM2 op constructed incorrectly as the first one");

int64_t v_dim[4] = {b, h, s_kv, d};
int64_t v_stride[4];
Expand Down Expand Up @@ -669,9 +665,8 @@ static cudnn_frontend::Tensor createSdOBMM(
const cudnn_frontend::Tensor &softmaxTensor,
const cudnn_frontend::Tensor &dOTensor,
const cudnn_frontend::Tensor &mnkOverride) {
cudnn_frontend::throw_if(ops->size() == 0,
"BMM2 op constructed incorrectly as the first one",
CUDNN_STATUS_BAD_PARAM);
NVTE_CHECK(ops->size() > 0,
"BMM2 op constructed incorrectly as the first one");

int64_t s_dim_transpose[4] = {b, h, s_kv, s_q};
int64_t s_stride_transpose[4] = {h * s_kv * s_q, s_kv * s_q, 1, s_kv};
Expand Down Expand Up @@ -1028,12 +1023,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;

cudnn_frontend::throw_if(dropoutProbability != 0.0f && !isTraining,
"Dropout probability should be 0.0f for inference mode",
CUDNN_STATUS_BAD_PARAM);
cudnn_frontend::throw_if(dropoutProbability == 1.0f,
"Dropout probability cannot be 1.0",
CUDNN_STATUS_BAD_PARAM);
NVTE_CHECK(dropoutProbability == 0.0f || isTraining,
"Dropout probability should be 0.0f for inference mode");
NVTE_CHECK(dropoutProbability != 1.0f,
"Dropout probability cannot be 1.0");

int64_t raggedDim[4] = {b + 1, 1, 1, 1};
int64_t raggedStride[4] = {1, 1, 1, 1};
Expand Down Expand Up @@ -1283,12 +1276,10 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
.setWorkspacePointer(workspace_ptr)
.setDataPointers(data_ptrs)
.build();
cudnnStatus_t status = cudnnBackendExecute(
handle_, plan.get_raw_desc(), variantPack.get_raw_desc());

cudnn_frontend::throw_if(
[status]() { return (status != CUDNN_STATUS_SUCCESS); },
"Plan execute error", status);
NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_,
plan.get_raw_desc(),
variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException& e) {
struct cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0));
Expand Down Expand Up @@ -1347,9 +1338,8 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
std::vector<cudnn_frontend::Operation const*> all_ops;
std::vector<cudnn_frontend::Operation> ops;

cudnn_frontend::throw_if(dropoutProbability == 1.0f,
"Dropout probability cannot be 1.0",
CUDNN_STATUS_BAD_PARAM);
NVTE_CHECK(dropoutProbability != 1.0f,
"Dropout probability cannot be 1.0");

int64_t raggedDim[4] = {b + 1, 1, 1, 1};
int64_t raggedStride[4] = {1, 1, 1, 1};
Expand Down Expand Up @@ -1838,12 +1828,9 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
.setWorkspacePointer(workspace_ptr)
.setDataPointers(data_ptrs)
.build();
cudnnStatus_t status = cudnnBackendExecute(
handle_, plan.get_raw_desc(), variantPack.get_raw_desc());

cudnn_frontend::throw_if(
[status]() { return (status != CUDNN_STATUS_SUCCESS); },
"Plan execute error", status);
NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_,
plan.get_raw_desc(),
variantPack.get_raw_desc()));
} catch (cudnn_frontend::cudnnException& e) {
struct cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0));
Expand Down

0 comments on commit 3f0cf84

Please sign in to comment.