Commit 4df8488
Removed the unused options from GroupedLinear docs and fixed the bug with offsets (#1220)

* Removing the unused options from GroupedLinear docs and fixing the bug
with offsets

Signed-off-by: Przemyslaw Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* offsets -> fp8_meta_offsets

Signed-off-by: Przemyslaw Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ptrendx and pre-commit-ci[bot] committed Oct 1, 2024
1 parent 458c7de commit 4df8488
Showing 1 changed file with 27 additions and 64 deletions.
91 changes: 27 additions & 64 deletions transformer_engine/pytorch/module/grouped_linear.py
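The diff below removes the module-level offset globals (_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT, _GRAD_OUTPUT) and instead passes a per-instance fp8_meta_offsets dict, built by GroupedLinear in __init__, into the autograd function. The sketch below uses hypothetical class names (it is not the Transformer Engine code) to illustrate the likely failure mode of the global approach: the globals are shared by every GroupedLinear in the process, so constructing a second layer with a different num_gemms rewrites the offsets an earlier layer depends on, while per-instance offsets are unaffected.

# Illustrative sketch, not the Transformer Engine source; class names are hypothetical.

class GroupedLinearWithGlobals:
    """Old pattern: fp8_meta offsets kept in module-level globals."""

    def __init__(self, num_gemms):
        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
        _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms
        self.num_gemms = num_gemms

    def weight_meta_index(self, i):
        # Reads whatever the most recently constructed layer wrote.
        return _GEMM_WEIGHT + i


class GroupedLinearWithOffsets:
    """New pattern: offsets stored per instance, as in this commit."""

    def __init__(self, num_gemms):
        self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms}
        self.num_gemms = num_gemms

    def weight_meta_index(self, i):
        return self._offsets["weight"] + i


a = GroupedLinearWithGlobals(num_gemms=4)
b = GroupedLinearWithGlobals(num_gemms=2)  # overwrites the globals `a` set up
print(a.weight_meta_index(0))  # 2 -- stale offset for a layer built with num_gemms=4

c = GroupedLinearWithOffsets(num_gemms=4)
d = GroupedLinearWithOffsets(num_gemms=2)
print(c.weight_meta_index(0))  # 4 -- unaffected by `d`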
@@ -44,18 +44,6 @@

__all__ = ["GroupedLinear"]

"""
The offset for fp8_meta_index.
_GEMM_INPUT = 0
_GEMM_WEIGHT = num_gemms
_GEMM_OUTPUT = 2 * num_gemms
Must be properly set in GroupedLinear's initialization.
"""
_GEMM_INPUT = 0
_GEMM_WEIGHT = 0
_GEMM_OUTPUT = 0
_GRAD_OUTPUT = 0


class _GroupedLinear(torch.autograd.Function):
"""GroupedLinear semi-top level module
@@ -74,12 +62,9 @@ def forward(
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
fp8_meta_offsets: Dict[str, int],
is_grad_enabled: bool,
weights_fp8: List[Union[Float8Tensor, None]],
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
@@ -103,7 +88,6 @@ def forward(
inputmats_t = []
inputmat_scale_inv = None

global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
@@ -114,7 +98,9 @@
and not sequence_parallel
):
# FP8 input for forward, FP8 input transpose for backward wgrad
indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms))
indices = list(
range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms)
)
inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
inputmats_no_fp8,
fp8_meta["scaling_fwd"],
Expand All @@ -130,7 +116,7 @@ def forward(
cast_to_fp8(
inputmats_no_fp8[i],
fp8_meta["scaling_fwd"],
_GEMM_INPUT + i,
fp8_meta_offsets["input"] + i,
fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
)
@@ -194,14 +180,14 @@ def forward(
for i in range(num_gemms):
# amax of input
amin, amax = inputmats[i].aminmax()
fp8_meta["scaling_fwd"].amax_history[0][_GEMM_INPUT + i] = torch.max(
-amin, amax
).float()
fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = (
torch.max(-amin, amax).float()
)
# amax of weight
amin, amax = weights[i].aminmax()
fp8_meta["scaling_fwd"].amax_history[0][_GEMM_WEIGHT + i] = torch.max(
-amin, amax
).float()
fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = (
torch.max(-amin, amax).float()
)

out = torch.empty(
[sum(m_splits), weights[0].size(0)],
@@ -266,11 +252,8 @@ def forward(
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.fp8_meta_offsets = fp8_meta_offsets
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
@@ -300,7 +283,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
w.main_grad = main_grads[i]
weights[i] = w

global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
# preprocess grad_output
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
@@ -318,13 +300,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
fp8_cast_transpose_bgrad_fused(
grad_output_mats[i],
ctx.fp8_meta["scaling_bwd"],
_GRAD_OUTPUT + i,
ctx.fp8_meta_offsets["grad_output"] + i,
fp8_dtype_backward,
)
)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms))
indices = list(
range(
ctx.fp8_meta_offsets["grad_output"],
ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms,
)
)
grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused(
grad_output_mats,
ctx.fp8_meta["scaling_bwd"],
@@ -338,7 +325,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
grad_output_c[i] = cast_to_fp8(
grad_output_mats[i],
ctx.fp8_meta["scaling_bwd"],
_GRAD_OUTPUT + i,
ctx.fp8_meta_offsets["grad_output"] + i,
fp8_dtype_backward,
)

@@ -363,7 +350,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
weights_fp8[0]._fp8_dtype,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
ctx.fp8_meta_offsets["grad_output"],
fp8_dtype_backward,
[dgrad],
ctx.activation_dtype,
@@ -416,7 +403,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
_GRAD_OUTPUT,
ctx.fp8_meta_offsets["grad_output"],
fp8_dtype_backward,
wgrad_list,
ctx.activation_dtype,
@@ -497,12 +484,9 @@ def handle_custom_ddp_from_mcore(w, wgrad):
None, # fp8_meta
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # fp8_meta_offsets
None, # is_grad_enabled
None, # weights_fp8
*wgrad_list,
@@ -536,23 +520,6 @@ class GroupedLinear(TransformerEngineBaseModule):
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
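For context, a minimal usage sketch of GroupedLinear. The constructor and forward signatures shown are assumptions inferred from the diff (num_gemms, m_splits, sum(m_splits), weights[0].size(0)) rather than taken from the documentation:

# Usage sketch; the te.GroupedLinear signature and the m_splits argument are
# assumptions inferred from the diff, not verified against the docs.
import torch
import transformer_engine.pytorch as te

num_gemms, in_features, out_features = 4, 256, 512
layer = te.GroupedLinear(num_gemms, in_features, out_features)
layer.cuda()  # per the docstring, parameters must be on the GPU before the forward pass

# Assumed: one row group per GEMM, with m_splits summing to the input's row count.
m_splits = [8, 16, 32, 8]
inp = torch.randn(sum(m_splits), in_features, device="cuda")
out = layer(inp, m_splits)  # expected shape: [sum(m_splits), out_features]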
@@ -613,8 +580,7 @@ def __init__(
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name

global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms
self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}

if tp_group is None:
self.tp_size = tp_size
@@ -651,7 +617,7 @@ def __init__(
),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=_GEMM_WEIGHT + i,
fp8_meta_index=self._offsets["weight"] + i,
)

# Construct bias parameters if needed
@@ -774,7 +740,7 @@ def forward(
weight_tensors_fp8[i] = self.get_fp8_workspace(
tensor=weight_tensors[i],
fp8_meta_forward=True,
fp8_meta_index=_GEMM_WEIGHT + i,
fp8_meta_index=self._offsets["weight"] + i,
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
@@ -798,12 +764,9 @@ def forward(
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self._offsets,
torch.is_grad_enabled(),
weight_tensors_fp8,
*weight_tensors,
