Skip to content

Commit

Permalink
FP8 Support for MCore MoE (#648)
Browse files Browse the repository at this point in the history
* Add support for MoE with FP8.

Signed-off-by: Dennis Liu <[email protected]>

* Fix unittest.

Signed-off-by: Dennis Liu <[email protected]>

* Fix error in linear backward.

Signed-off-by: Dennis Liu <[email protected]>

---------

Signed-off-by: Dennis Liu <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
  • Loading branch information
Victarry and ptrendx authored Apr 29, 2024
1 parent 9709147 commit 32d1eb1
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 24 deletions.
46 changes: 45 additions & 1 deletion tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import torch
import pytest

from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
FP8GlobalStateManager,
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
Expand Down Expand Up @@ -107,6 +111,7 @@ def is_fp8_supported(self):
param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
Expand Down Expand Up @@ -456,6 +461,45 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs*config.seq_len

if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")

use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params):
te_linear = (
Linear(
config.hidden_size,
ffn_hidden_size,
bias=use_bias,
params_dtype=dtype
)
.cuda()
)

inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
out = te_linear(inp_hidden_states)
loss = out.sum()
loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
Expand Down
21 changes: 11 additions & 10 deletions transformer_engine/pytorch/cpp_extensions/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ def cast_to_fp8(
"""Cast input to FP8"""

if out is not None:
torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype
)
if inp.nelement() > 0:
torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype
)
return None

return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
Expand All @@ -41,7 +43,6 @@ def cast_to_fp8(
otype,
)


def cast_from_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def fp8_gemm(
bias_dtype = TE_DType[bias_dtype]

out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
if A.nelement() == 0 or B.nelement() == 0:
return out, gelu_input

args = (
A,
Expand Down Expand Up @@ -191,6 +193,8 @@ def gemm(
grad_bias = empty_tensor

bias = bias if use_bias else empty_tensor
if A.nelement() == 0 or B.nelement() == 0:
return out, grad_bias, gelu_input

assert A.dtype == dtype and B.dtype == dtype, \
f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
Expand Down
21 changes: 11 additions & 10 deletions transformer_engine/pytorch/cpp_extensions/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ def fp8_cast_transpose_fused(
if noop_flag is None:
noop_flag = torch.Tensor()

tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
cast_out,
transpose_out,
otype,
)
if inp.nelement() > 0:
tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
cast_out,
transpose_out,
otype,
)

if return_outputs:
return cast_out, transpose_out
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ at::Tensor cast_to_fp8(const at::Tensor &input,

auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));

if (input.numel() == 0)
return output;

auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
amax.data_ptr(), scale.data_ptr(),
Expand Down
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
grad_output.size(0),
DType::kByte);

if (M == 0 || N == 0)
return {grad_bias, grad_output_cast, grad_output_transpose};

auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
Expand Down Expand Up @@ -335,6 +338,8 @@ at::Tensor fp8_transpose(at::Tensor input,

size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
if (M == 0 || N == 0)
return input;

auto output =
allocateTorchTensor(input.size(1),
Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,12 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
if get_rng_state_tracker is None:
init_fn(param)
else:
with get_rng_state_tracker().fork():
init_fn(param)
if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
with get_rng_state_tracker().fork(self.rng_tracker_name):
init_fn(param)
else:
with get_rng_state_tracker().fork():
init_fn(param)

# If primary weights are in fp8, wrap the parameter as Float8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
Expand Down
8 changes: 7 additions & 1 deletion transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def forward(
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat

if fp8:
if _NVTE_DEBUG:
print('[Linear]: using FP8 forward')
Expand Down Expand Up @@ -664,6 +663,10 @@ class Linear(TransformerEngineBaseModule):
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initilizeing weights.
rng_tracker_name : str, default = `None`
the param passed to get_rng_state_tracker to get the specific rng tracker.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Configuration for splitting the weight and bias tensors along dim 0 into
multiple PyTorch parameters. If a list or tuple of strings is provided,
Expand Down Expand Up @@ -723,6 +726,7 @@ def __init__(
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
rng_tracker_name: Optional[str] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
Expand Down Expand Up @@ -753,6 +757,8 @@ def __init__(
), "Userbuffer communication backend not available."
self.ub_name = ub_name
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name

if device == 'meta':
assert parameters_split is None, ("Cannot split module parameters "
"on 'meta' device.")
Expand Down

0 comments on commit 32d1eb1

Please sign in to comment.