Update FP8 scale-inverse in kernels with FP8 output #1083

Merged
merged 20 commits into main from fuse-cast-and-scale-inv-update
Aug 21, 2024
Changes from 16 commits
Commits (20)
711b77e
Perform scale-inv update in cast-transpose kernels
timmoon10 Aug 1, 2024
c53bf0d
Perform scale-inv update in cast and activation kernels
timmoon10 Aug 2, 2024
0a12bff
Perform scale-inv update in LayerNorm and RMSNorm kernels
timmoon10 Aug 5, 2024
20bed0b
Perform scale-inv update after FP8 GEMMs
timmoon10 Aug 5, 2024
f65b3d1
Fuse casts and scale-inv updates in linear module
timmoon10 Aug 6, 2024
fa65672
Fuse casts and scale-inv updates in layernorm-linear module
timmoon10 Aug 6, 2024
a3c00ec
Simplify kernel to update FP8 scale-inv
timmoon10 Aug 6, 2024
1182e3e
Fix typos
timmoon10 Aug 6, 2024
c548e8a
Debug amax update in layernorm kernels
timmoon10 Aug 7, 2024
a265eae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
a47500d
Merge branch 'main' into fuse-cast-and-scale-inv-update
timmoon10 Aug 7, 2024
35b5a1b
Debug test failures
timmoon10 Aug 8, 2024
ad7b1f6
Debug ONNX export
timmoon10 Aug 9, 2024
9798cf8
Merge branch 'main' into fuse-cast-and-scale-inv-update
timmoon10 Aug 9, 2024
e550929
Merge branch 'main' into fuse-cast-and-scale-inv-update
timmoon10 Aug 12, 2024
dca23c5
Merge branch 'main' into fuse-cast-and-scale-inv-update
timmoon10 Aug 14, 2024
4909a50
Review suggestion from @ptrendx
timmoon10 Aug 17, 2024
0d45a2e
Merge branch 'main' into fuse-cast-and-scale-inv-update
timmoon10 Aug 17, 2024
66ff01d
Debug mismatched dtypes
timmoon10 Aug 19, 2024
87a4309
Merge branch 'main' into fuse-cast-and-scale-inv-update
timmoon10 Aug 19, 2024
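
Taken together, the commits above fuse the same small pattern into every kernel that produces an FP8 output: after the kernel's amax reduction, a single designated thread writes the reciprocal of the quantization scale into the tensor's scale-inverse buffer, so the dequantization factor cannot drift out of sync with the scale actually used for the cast. A minimal, self-contained sketch of that pattern (illustrative only; names and launch details are simplified assumptions, not the actual Transformer Engine kernels):

```cuda
#include <cuda_fp8.h>

// Hypothetical FP8 cast kernel with a fused scale-inv update.
__global__ void fp8_cast_with_scale_inv(const float* __restrict__ in,
                                        __nv_fp8_e4m3* __restrict__ out,
                                        const float* __restrict__ scale_ptr,
                                        float* __restrict__ amax_ptr,
                                        float* __restrict__ scale_inv_ptr,
                                        size_t n) {
  const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1.f;
  float thread_amax = 0.f;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    const float x = in[i];
    thread_amax = fmaxf(thread_amax, fabsf(x));
    out[i] = __nv_fp8_e4m3(x * scale);  // quantize with the current scale
  }
  // Block/grid amax reduction omitted in this sketch; the real kernels reduce
  // thread_amax over the block and atomically update *amax_ptr.
  (void)thread_amax;
  (void)amax_ptr;
  // Fused update: one thread writes the dequantization factor.
  if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
    *scale_inv_ptr = 1.f / scale;
  }
}
```

The cast-transpose, activation, LayerNorm, and RMSNorm hunks below all follow this shape; for cuBLASLt GEMMs, where the update cannot be folded into the matmul itself, the PR instead launches the small single-thread helper added in common.cu immediately after the GEMM.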
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_transpose.cu
@@ -69,6 +69,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_transpose_dbias.cu
@@ -116,6 +116,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
@@ -132,6 +132,8 @@ void performTest(const size_t N, const size_t H) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}

auto [atol, rtol] = getTolerances(otype);
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_layernorm.cu
@@ -230,6 +230,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}

auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
4 changes: 4 additions & 0 deletions tests/cpp/operator/test_multi_cast_transpose.cu
@@ -139,6 +139,10 @@ void performTest() {
output_c_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
compareResults("scale_inv",
output_c_list[tensor_id].scale_inv(),
1.f / output_c_list[tensor_id].scale(),
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_rmsnorm.cu
@@ -187,6 +187,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma)
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}

auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
2 changes: 1 addition & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -55,6 +55,7 @@ set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
pycudnn.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
@@ -141,4 +142,3 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")

# Install library
install(TARGETS transformer_engine DESTINATION .)

9 changes: 6 additions & 3 deletions transformer_engine/common/activation/activation_template.h
@@ -27,7 +27,8 @@ void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -50,7 +51,8 @@ void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
@@ -74,7 +76,8 @@ void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr), output->data.shape[0],
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), output->data.shape[0],
output->data.shape[1], {},
stream);); // NOLINT(*)
); // NOLINT(*)
32 changes: 32 additions & 0 deletions transformer_engine/common/common.cu
@@ -0,0 +1,32 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include "./common.h"
#include "./utils.cuh"

namespace transformer_engine {

namespace {

__global__ void __launch_bounds__(1)
update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr,
float* __restrict__ scale_inv_ptr) {
const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
reciprocal<float>(scale_inv_ptr, scale);
}

} // namespace

void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) {
if (t->scale_inv.dptr != nullptr) {
update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float*>(t->scale.dptr), reinterpret_cast<float*>(t->scale_inv.dptr));
}
}

} // namespace transformer_engine
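
The kernel above relies on a reciprocal<float> helper from utils.cuh, called as reciprocal<float>(scale_inv_ptr, scale) both here and in the fused norm and cast-transpose kernels below. Its definition is not part of this diff; a plausible shape, inferred from the call sites (an assumption — the real helper may use a hardware intrinsic such as __frcp_rn), is:

```cuda
// Assumed shape of the reciprocal helper from utils.cuh; illustrative only.
template <typename T>
__device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  *value_inv = static_cast<T>(1) / value;
}
```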
7 changes: 7 additions & 0 deletions transformer_engine/common/common.h
@@ -262,6 +262,13 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt

bool is_fp8_dtype(const DType t);

/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
* with the reciprocal of the FP8 scale (quantization scaling factor).
*/
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);

#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);

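
For context on the doc comment above: the scale is the quantization factor applied when casting to FP8, and the scale-inverse is the matching dequantization factor that consumers multiply by to approximately recover the original values, which is why the two must always be updated together. A small host-side illustration with hypothetical numbers (not Transformer Engine API):

```cuda
#include <cmath>
#include <cstdio>

int main() {
  const float x = 3.7f;
  const float scale = 16.f;             // hypothetical quantization scale
  const float scale_inv = 1.f / scale;  // what update_tensor_scale_inv maintains
  // Stand-in for the FP8 rounding step (the real kernels cast to an FP8 type).
  const float x_fp8 = std::round(x * scale);
  const float x_recovered = x_fp8 * scale_inv;  // ~3.6875, close to the input
  std::printf("original=%.4f recovered=%.4f\n", x, x_recovered);
  return 0;
}
```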
5 changes: 5 additions & 0 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -269,6 +269,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
workspace, /* workspace */
workspaceSize, stream)); /* stream */

// Update FP8 scale-inv in output tensor
if (is_fp8_dtype(outputD->data.dtype)) {
update_tensor_scale_inv(outputD, stream);
}

NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
3 changes: 3 additions & 0 deletions transformer_engine/common/layer_norm/ln.h
@@ -89,6 +89,9 @@ struct FwdParams : public ParamsBase {
// AMax output
void *amax;

// Inverse of scaling factor
void *scale_inv;

// Whether to compute scale and amax
bool fp8_out;
};
1 change: 1 addition & 0 deletions transformer_engine/common/layer_norm/ln_api.cpp
@@ -196,6 +196,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;

32 changes: 24 additions & 8 deletions transformer_engine/common/layer_norm/ln_fwd_kernels.cuh
@@ -132,10 +132,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
@@ -291,10 +299,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne

// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
1 change: 1 addition & 0 deletions transformer_engine/common/rmsnorm/rmsnorm_api.cpp
@@ -159,6 +159,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.epsilon = epsilon;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;

32 changes: 24 additions & 8 deletions transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh
@@ -125,10 +125,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
@@ -267,10 +275,18 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_

// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
// Reduce amax over block
if (params.amax != nullptr) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
}
}
}
28 changes: 18 additions & 10 deletions transformer_engine/common/transpose/cast_transpose.cu
@@ -101,14 +101,11 @@ struct KernelConfig {
};

template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size)
cast_transpose_general_kernel(const IType *__restrict__ const input,
const CType *__restrict__ const noop,
OType *__restrict__ const output_c,
OType *__restrict__ const output_t,
const CType *__restrict__ const scale_ptr,
CType *__restrict__ const amax_ptr, const size_t row_length,
const size_t num_rows) {
__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
const IType *__restrict__ const input, const CType *__restrict__ const noop,
OType *__restrict__ const output_c, OType *__restrict__ const output_t,
const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr,
CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;

// Vectorized load/store sizes
@@ -207,9 +204,15 @@
if (amax_ptr != nullptr) {
amax = reduce_max<warps_per_tile>(amax, tidy);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax_ptr, amax);
}
}

// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv_ptr != nullptr) {
reciprocal<CType>(scale_inv_ptr, scale);
}
}

} // namespace
@@ -255,6 +258,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
"Cast and transposed outputs need to share amax tensor.");
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor.");
NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr,
"Cast and transposed outputs need to share scale-inverse tensor.");

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
@@ -324,7 +329,9 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length,
num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
@@ -339,7 +346,8 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length, num_rows);
}); // NOLINT(*)
); // NOLINT(*)
}