Merge branch 'main' into databricks/non-reentrant-checkpoint
Signed-off-by: Alp Dener <[email protected]>
denera committed Feb 21, 2024
2 parents 4b6fced + 2187a8f commit b1c4bb2
Showing 17 changed files with 297 additions and 40 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-1.4.0dev
+1.5.0dev
File renamed without changes.
File renamed without changes.
8 changes: 6 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -62,7 +62,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq

all_boolean = [True, False]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

@@ -310,12 +310,16 @@ def forward(self, x, attention_mask=None):
        output = output[0]
        return output

class TorchQuickGELU(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input * torch.sigmoid(1.702 * input)

_supported_act = {'geglu'  : nn.GELU(approximate="tanh"),
                  'gelu'   : nn.GELU(approximate="tanh"),
                  'reglu'  : nn.ReLU(),
                  'relu'   : nn.ReLU(),
-                 'swiglu' : nn.SiLU()}
+                 'swiglu' : nn.SiLU(),
+                 'qgelu'  : TorchQuickGELU()}
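For context, `TorchQuickGELU` above is the framework-native reference that the fused implementation is tested against. As a quick sanity check (hypothetical, not part of this commit), the approximation can be compared directly against PyTorch's exact GELU:

```python
import torch

# Hypothetical spot check: QuickGELU, x * sigmoid(1.702 * x), is a cheap
# approximation of the exact (erf-based) GELU; the absolute deviation
# stays on the order of 1e-2 over typical activation ranges.
x = torch.linspace(-6.0, 6.0, steps=2001)
quick = x * torch.sigmoid(1.702 * x)
exact = torch.nn.functional.gelu(x)  # default is the exact erf form
print(f"max abs deviation: {(quick - exact).abs().max().item():.4f}")
```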


class TorchGLU(nn.Module):
73 changes: 73 additions & 0 deletions transformer_engine/common/activation/gelu.cu
@@ -127,6 +127,57 @@ void dgeglu(const Tensor &grad,
  );  // NOLINT(*)
}

void qgelu(const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "qgelu_input");
  CheckOutputTensor(*output, "qgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryKernelLauncher<nvec, Empty, qgelu<fp32, fp32> >(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

void dqgelu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(input, "dqgelu_input");
  CheckInputTensor(grad, "dqgelu_input_grad");
  CheckOutputTensor(*output, "dqgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  NVTE_CHECK(input.data.dtype == grad.data.dtype,
             "Input and incoming gradient types must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryGradKernelLauncher<nvec, Empty, dqgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

} // namespace transformer_engine

void nvte_gelu(const NVTETensor input,
@@ -172,3 +223,25 @@ void nvte_dgeglu(const NVTETensor grad,
         reinterpret_cast<Tensor*>(output),
         stream);
}

void nvte_qgelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  qgelu(*reinterpret_cast<const Tensor*>(input),
        reinterpret_cast<Tensor*>(output),
        stream);
}

void nvte_dqgelu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  dqgelu(*reinterpret_cast<const Tensor*>(grad),
         *reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
}
36 changes: 29 additions & 7 deletions transformer_engine/common/include/transformer_engine/activation.h
@@ -47,8 +47,8 @@ void nvte_dgelu(const NVTETensor grad,
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_geglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream);

/*! \brief Compute GeGLU gradient.
 *  \param[in]     grad      Incoming gradient of shape [N, H].
@@ -113,8 +113,8 @@ void nvte_dswiglu(const NVTETensor grad,
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_reglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream);

/*! \brief Compute ReGLU gradient.
 *  \param[in]     grad      Incoming gradient of shape [N, H].
@@ -123,9 +123,31 @@ void nvte_reglu(const NVTETensor input,
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_dreglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream);

/*! \brief Compute QuickGELU activation of the input.
 *
 *  \param[in]     input     Input tensor for QuickGELU activation.
 *  \param[in,out] output    Output tensor. Approximates GELU as input * sigmoid(1.702 * input).
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_qgelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream);

/*! \brief Compute QuickGELU activation gradient.
 *
 *  \param[in]     grad      Incoming gradient.
 *  \param[in]     input     Input tensor for QuickGELU activation.
 *  \param[in,out] output    Output tensor.
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_dqgelu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
8 changes: 0 additions & 8 deletions transformer_engine/common/layer_norm/ln.h
@@ -233,14 +233,6 @@ struct BwdGeneralRegistrar{
}
};

-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
-                                           DType itype,
-                                           DType otype,
-                                           DType ctype,
-                                           uint32_t hidden_size);

//////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace layer_norm
57 changes: 40 additions & 17 deletions transformer_engine/common/layer_norm/ln_api.cpp
@@ -5,7 +5,10 @@
************************************************************************/

#include <transformer_engine/layer_norm.h>

+#include <cstdint>
+#include <vector>

#include "ln.h"
#include "../common.h"

@@ -72,11 +75,20 @@ layer_norm::FwdFunction & get_fwd_launcher(DType wtype,
                                           DType itype,
                                           DType otype,
                                           DType ctype,
-                                          uint32_t hidden_size,
-                                          uint32_t batch_size) {
+                                          const layer_norm::FwdParams &params) {
  // Look for tuned kernel
-  auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size);
-  if (batch_size % 4 == 0
+  auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
+  auto is_aligned = [](const void *ptr) -> bool {
+    // Assume vectorized memory accesses are <=16B
+    return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
+  };
+  if (params.rows % 4 == 0
+      && is_aligned(params.x)
+      && is_aligned(params.mu)
+      && is_aligned(params.rs)
+      && is_aligned(params.gamma)
+      && is_aligned(params.beta)
+      && is_aligned(params.z)
      && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) {
    return layer_norm::FWD_TUNED_FUNCS.at(tuned_key);
  }
@@ -87,7 +99,7 @@
NVTE_ERROR("FWD: Unsupported types.");
}
auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
@@ -102,11 +114,24 @@ layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
                                           DType itype,
                                           DType otype,
                                           DType ctype,
-                                          uint32_t hidden_size,
-                                          uint32_t batch_size) {
+                                          const layer_norm::BwdParams &params) {
  // Look for tuned kernel
-  auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, hidden_size);
-  if (batch_size % 4 == 0
+  auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols);
+  auto is_aligned = [](const void *ptr) -> bool {
+    // Assume vectorized memory accesses are <=16B
+    return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
+  };
+  if (params.rows % 4 == 0
+      && is_aligned(params.x)
+      && is_aligned(params.mu)
+      && is_aligned(params.rs)
+      && is_aligned(params.gamma)
+      && is_aligned(params.dz)
+      && is_aligned(params.dx)
+      && is_aligned(params.dbeta)
+      && is_aligned(params.dgamma)
+      && is_aligned(params.dbeta_part)
+      && is_aligned(params.dgamma_part)
      && layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
    return layer_norm::BWD_TUNED_FUNCS.at(tuned_key);
  }
@@ -117,7 +142,7 @@
NVTE_ERROR("BWD: Unsupported types.");
}
auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key);
auto func_iter = general_func_map.lower_bound(hidden_size);
auto func_iter = general_func_map.lower_bound(params.cols);
if (func_iter == general_func_map.end()) {
// Hidden size is too big, need to use multi-CTA
return general_func_map.rbegin()->second;
@@ -183,10 +208,6 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
  launch_params.multiprocessorCount = multiprocessorCount;
  launch_params.stream = stream;

-  // Request the kernel launcher.
-  auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype,
-                                               hidden_size, rows);

  // Set the kernel runtime parameters.
  layer_norm::FwdParams &params = launch_params.params;
  params.rows = rows;
@@ -203,6 +224,9 @@
  params.fp8_out = fp8_out;
  params.zero_centered_gamma = zero_centered_gamma;

+  // Request the kernel launcher.
+  auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params);

  // Query the kernel-specific launch parameters.
  launcher(launch_params, true);
  if (workspace->data.dptr == nullptr) {
@@ -304,9 +328,6 @@ void layernorm_bwd(const Tensor& dz,
  launch_params.stream = stream;
  launch_params.multiprocessorCount = multiprocessorCount;

-  auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype,
-                                               hidden_size, rows);

  // Set the kernel runtime parameters.
  layer_norm::BwdParams &params = launch_params.params;
  params.rows = rows;
@@ -323,6 +344,8 @@
  params.dgamma_part = dgamma_part->data.dptr;
  params.zero_centered_gamma = zero_centered_gamma;

+  auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params);

  // Query the kernel-specific launch parameters.
  launcher(launch_params, true);

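The net effect of the `ln_api.cpp` changes is that the tuned (vectorized) kernels are now selected based on pointer alignment and row count rather than batch size alone. A rough sketch of the selection rule in Python (names are illustrative, not from the codebase):

```python
# Illustrative sketch: the tuned kernel is used only when the row count
# is a multiple of 4, every tensor pointer is 16-byte aligned (the
# widest vectorized access assumed by the kernels), and a tuned kernel
# exists for this type/hidden-size combination.
def pick_launcher(rows, ptrs, key, tuned_funcs, general_func):
    if rows % 4 == 0 and all(p % 16 == 0 for p in ptrs) and key in tuned_funcs:
        return tuned_funcs[key]
    return general_func  # safe fallback for misaligned or odd-shaped inputs
```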
13 changes: 13 additions & 0 deletions transformer_engine/common/util/math.h
@@ -39,6 +39,19 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
  return s * (1.f - s);
}

template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
  const float cval = val;
  return cval * sigmoid<float, float>(1.702f * cval, e);
}

template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
  const float cval = val;
  // Chain rule: d/dx [x * sigmoid(1.702x)] = sigmoid(1.702x) + 1.702x * sigmoid'(1.702x)
  return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
         sigmoid<float, float>(1.702f * cval, e);
}

template <typename OType, typename IType>
__device__ inline OType swish(const IType val, const Empty& e) {
  const float cval = val;
6 changes: 4 additions & 2 deletions transformer_engine/common/utils.cuh
@@ -290,7 +290,8 @@ struct Vec {
                                   size_t idx = 0,
                                   size_t count = NUM_ELT) {
    const Elt_type *elt_ptr = static_cast<const Elt_type *>(base_ptr) + idx;
-    if ( count < NUM_ELT || idx % NUM_ELT != 0 ) {
+    if ( count < NUM_ELT
+         || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
      #pragma unroll
      for ( int it = 0; it < NUM_ELT; it++ ) {
        this->data.elt[it] = (it < count
@@ -308,7 +309,8 @@
                                  size_t idx = 0,
                                  size_t count = NUM_ELT) const {
    Elt_type *elt_ptr = static_cast<Elt_type *>(base_ptr) + idx;
-    if ( count < NUM_ELT || idx % NUM_ELT != 0 ) {
+    if ( count < NUM_ELT
+         || reinterpret_cast<uint64_t>(elt_ptr) % BYTES != 0 ) {
      #pragma unroll
      for ( int it = 0; it < NUM_ELT; it++ ) {
        if ( it < count ) {
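The `utils.cuh` change replaces the index-based guard with a byte-alignment check on the actual pointer: an element-aligned index is not a reliable proxy for vector alignment when the base pointer itself is only element-aligned. A hypothetical illustration of the distinction:

```python
# Hypothetical illustration (not from the codebase): an element-aligned
# index can still yield a byte-misaligned address if the base pointer
# is not itself aligned to the vector width.
def can_vectorize(base_addr, idx, elt_bytes, vec_bytes):
    return (base_addr + idx * elt_bytes) % vec_bytes == 0

print(can_vectorize(0x1000, 8, 2, 16))  # True:  0x1010 is 16B-aligned
print(can_vectorize(0x1002, 8, 2, 16))  # False: 0x1012 is not
```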
28 changes: 27 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/activation.py
@@ -8,7 +8,7 @@
import transformer_engine_extensions as tex


-__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu']
+__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu']


def gelu(
@@ -140,3 +140,29 @@ def swiglu(
        fp8_tensor,
        otype,
    )


def qgelu(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> torch.Tensor:
    """QuickGELU with FP8 output"""
    empty_tensor = torch.Tensor()
    if fp8_meta_tensor is not None:
        scale = fp8_meta_tensor.scale
        amax_history = fp8_meta_tensor.amax_history
        scale_inv = fp8_meta_tensor.scale_inv
    else:
        scale = empty_tensor
        amax_history = empty_tensor
        scale_inv = empty_tensor
    return torch.ops.tex_ts.qgelu_ts(
        inp,
        scale,
        amax_history,
        scale_inv,
        fp8_tensor,
        otype,
    )
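A hypothetical usage sketch for the new extension entry point (the enum values below are illustrative; passing `None` for the FP8 meta tensor takes the non-FP8 path, where the scaling arguments fall back to empty tensors):

```python
import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import qgelu

inp = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
# Non-FP8 path: the output stays in the input precision.
out = qgelu(inp, None, tex.FP8FwdTensors.GEMM1_OUTPUT, tex.DType.kBFloat16)
```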