FP8 Support for MCore MoE #648

Merged: 4 commits, Apr 29, 2024
46 changes: 45 additions & 1 deletion tests/pytorch/test_sanity.py
@@ -9,7 +9,11 @@
import torch
import pytest

from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import (
fp8_autocast,
FP8GlobalStateManager,
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
@@ -107,6 +111,7 @@ def is_fp8_supported(self):
param_types.append(torch.bfloat16)

all_boolean = [True, False]
batch_sizes_with_zero = [0, 1, 2]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -456,6 +461,45 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    num_tokens = bs * config.seq_len

    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")

    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params):
        te_linear = (
            Linear(
                config.hidden_size,
                ffn_hidden_size,
                bias=use_bias,
                params_dtype=dtype,
            )
            .cuda()
        )

    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_linear(inp_hidden_states)
    loss = out.sum()
    loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
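Note (editor's sketch, not part of the diff): the bs == 0 case above mirrors what happens in MCore MoE dispatch, where the router may assign no tokens to a given expert, so each expert's Linear has to accept an empty input. A minimal pure-PyTorch illustration with made-up sizes:

import torch

num_tokens, hidden_size, num_experts = 8, 16, 4
router_logits = torch.randn(num_tokens, num_experts)
expert_idx = router_logits.argmax(dim=-1)          # top-1 routing
hidden_states = torch.randn(num_tokens, hidden_size)

for e in range(num_experts):
    tokens_for_e = hidden_states[expert_idx == e]   # may have shape (0, hidden_size)
    # Each expert's Linear must handle this possibly-empty slice; that is the
    # path test_sanity_linear_with_zero_tokens exercises with bs == 0.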
21 changes: 11 additions & 10 deletions transformer_engine/pytorch/cpp_extensions/cast.py
@@ -22,16 +22,18 @@ def cast_to_fp8(
"""Cast input to FP8"""

if out is not None:
torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype
)
if inp.nelement() > 0:
torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype
)
return None

return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
@@ -41,7 +43,6 @@
otype,
)


def cast_from_fp8(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
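Note (editor's sketch, not from the PR): the new inp.nelement() > 0 guard just skips the kernel launch when there is nothing to cast; the preallocated out tensor is empty in that case, so its shape already matches. Plain PyTorch shows the same no-op behavior for an empty elementwise cast:

import torch

inp = torch.empty(0, 16, dtype=torch.bfloat16)
out = inp.to(torch.uint8)   # stand-in for the FP8 byte cast
assert out.numel() == 0 and out.shape == (0, 16)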
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -64,6 +64,8 @@ def fp8_gemm(
bias_dtype = TE_DType[bias_dtype]

out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
if A.nelement() == 0 or B.nelement() == 0:
return out, gelu_input

args = (
A,
@@ -191,6 +193,8 @@ def gemm(
grad_bias = empty_tensor

bias = bias if use_bias else empty_tensor
if A.nelement() == 0 or B.nelement() == 0:
return out, grad_bias, gelu_input

assert A.dtype == dtype and B.dtype == dtype, \
f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
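Note (editor's sketch, not from the PR): when either GEMM operand has zero elements, the output has a zero-sized dimension, so returning the untouched out (plus gelu_input or grad_bias) is safe; there is nothing to compute or write. The shape logic in plain PyTorch:

import torch

weight = torch.randn(32, 16)   # (out_features, in_features)
tokens = torch.empty(0, 16)    # zero tokens routed to this expert
out = tokens @ weight.t()
assert out.shape == (0, 32)    # empty output, no kernel work needed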
21 changes: 11 additions & 10 deletions transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -39,16 +39,17 @@ def fp8_cast_transpose_fused(
if noop_flag is None:
noop_flag = torch.Tensor()

tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
cast_out,
transpose_out,
otype,
)
if inp.nelement() > 0:
tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
cast_out,
transpose_out,
otype,
)

if return_outputs:
return cast_out, transpose_out
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cu
@@ -19,6 +19,9 @@ at::Tensor cast_to_fp8(const at::Tensor &input,

auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));

if (input.numel() == 0)
return output;

auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
amax.data_ptr(), scale.data_ptr(),
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/transpose.cu
@@ -83,6 +83,9 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
grad_output.size(0),
DType::kByte);

if (M == 0 || N == 0)
return {grad_bias, grad_output_cast, grad_output_transpose};

auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
@@ -335,6 +338,8 @@ at::Tensor fp8_transpose(at::Tensor input,

size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
if (M == 0 || N == 0)
return input;
Review comment from the PR author:

@Victarry This will cause a shape mismatch error between wgrad and weight when gradient accumulation fusion is disabled.


auto output =
allocateTorchTensor(input.size(1),
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/module/base.py
@@ -825,8 +825,12 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
if get_rng_state_tracker is None:
init_fn(param)
else:
with get_rng_state_tracker().fork():
init_fn(param)
if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
with get_rng_state_tracker().fork(self.rng_tracker_name):
init_fn(param)
else:
with get_rng_state_tracker().fork():
init_fn(param)

# If primary weights are in fp8, wrap the parameter as Float8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
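Note (editor's sketch with a hypothetical tracker, not from the PR): reset_parameters now forwards rng_tracker_name to the tracker's fork() so that, for example, expert weights can be initialized from a dedicated RNG stream. A toy tracker sketching the fork(name) contract the new branch assumes:

import contextlib
import torch


class ToyRNGTracker:
    """Minimal stand-in for a Megatron-style CUDA RNG state tracker."""

    def __init__(self):
        self._states = {}

    def add(self, name, seed):
        # Capture a CUDA RNG state seeded for this name, then restore the original.
        orig = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self._states[name] = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(orig)

    @contextlib.contextmanager
    def fork(self, name="model-parallel-rng"):
        # Swap in the named state, run the body, then save and restore.
        orig = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self._states[name])
        try:
            yield
        finally:
            self._states[name] = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(orig)


tracker = ToyRNGTracker()
tracker.add("expert-parallel-rng", seed=1234)
# What the new branch does when rng_tracker_name is set:
# with get_rng_state_tracker().fork(self.rng_tracker_name):
#     init_fn(param)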
8 changes: 7 additions & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -152,7 +152,6 @@ def forward(
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat

if fp8:
if _NVTE_DEBUG:
print('[Linear]: using FP8 forward')
@@ -664,6 +663,10 @@ class Linear(TransformerEngineBaseModule):
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
the name passed to `get_rng_state_tracker().fork(...)` to select a specific RNG state tracker for weight initialization.
parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
Configuration for splitting the weight and bias tensors along dim 0 into
multiple PyTorch parameters. If a list or tuple of strings is provided,
@@ -723,6 +726,7 @@ def __init__(
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
rng_tracker_name: Optional[str] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
@@ -753,6 +757,8 @@
), "Userbuffer communication backend not available."
self.ub_name = ub_name
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name

if device == 'meta':
assert parameters_split is None, ("Cannot split module parameters "
"on 'meta' device.")
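Note (editor's usage sketch, not from the PR): with the new argument, a Linear can be pointed at a named RNG stream at construction time; the tracker object and the name "expert-parallel-rng" below are illustrative only.

import transformer_engine.pytorch as te

# tracker is assumed to expose fork(name), e.g. the toy tracker sketched above,
# with a state already registered under "expert-parallel-rng".
expert_fc1 = te.Linear(
    1024,                                    # in_features
    4096,                                    # out_features
    get_rng_state_tracker=lambda: tracker,
    rng_tracker_name="expert-parallel-rng",  # forwarded to tracker.fork(...)
)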