[PyTorch] Integration test for Megatron-LM (#1329)
* Handle deprecated `hidden_size` arg in norm modules

Signed-off-by: Tim Moon <[email protected]>

* Support initializing norm ops on CPU

Signed-off-by: Tim Moon <[email protected]>

* Add integration test for Megatron-LM

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rename Mcore integration test

Signed-off-by: Tim Moon <[email protected]>

* Handle case in RMSNorm where hidden dim is not provided

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] authored Nov 21, 2024
1 parent b495120 commit 6b98768
Showing 6 changed files with 168 additions and 52 deletions.
58 changes: 58 additions & 0 deletions qa/L1_pytorch_mcore_integration/test.sh
@@ -0,0 +1,58 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

# Paths
: ${TE_PATH:=/opt/transformerengine}
: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM}

# Download Megatron-LM if needed
if [ ! -d "${MCORE_PATH}" ]; then
pushd $(dirname ${MCORE_PATH})
git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
popd
fi

# Megatron-LM invocation
COMMAND="
NVTE_TORCH_COMPILE=0
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
NVTE_FLASH_ATTN=1
NVTE_FWD_LAYERNORM_SM_MARGIN=0
NVTE_BWD_LAYERNORM_SM_MARGIN=0
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_BIAS_GELU_NVFUSION=0
NVTE_BIAS_DROPOUT_FUSION=0
python
-m torch.distributed.launch
--use_env
--nnodes=1
--nproc_per_node=1
${MCORE_PATH}/pretrain_gpt.py
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--use-cpu-initialization
--num-layers 2
--hidden-size 128
--num-attention-heads 8
--seq-length 128
--max-position-embeddings 2048
--micro-batch-size 1
--global-batch-size 8
--train-iters 10
--eval-iters 10
--lr 1e-4
--mock-data
--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json
--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt
--transformer-impl transformer_engine
--fp8-format hybrid
"
COMMAND=$(echo "${COMMAND}" | tr '\n' ' ')

# Launch Megatron-LM
bash -c "${COMMAND}"
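Note: both path variables use the `: ${VAR:=default}` idiom, so the test can be pointed at an existing checkout by exporting them beforehand, e.g. `MCORE_PATH=/path/to/Megatron-LM bash qa/L1_pytorch_mcore_integration/test.sh`.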
19 changes: 18 additions & 1 deletion transformer_engine/pytorch/module/layernorm.py
@@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp):

def __init__(
self,
- normalized_shape: Union[Iterable[int], int],
+ normalized_shape: Union[Iterable[int], int, None] = None,
eps: float = 1e-5,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
+ hidden_size: Optional[int] = None, # deprecated
**kwargs,
) -> None:

+ # Handle deprecated options
+ if normalized_shape is None:
+     if hidden_size is None:
+         raise RuntimeError(
+             "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
+         )
+     warnings.warn(
+         "`hidden_size` arg has been renamed to `normalized_shape` "
+         "for compatibility with `torch.nn.LayerNorm`.",
+         DeprecationWarning,
+         stacklevel=2,
+     )
+     normalized_shape = hidden_size
+ elif hidden_size is not None:
+     raise RuntimeError(
+         "Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
+     )
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
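In practice the shim behaves as sketched below. This is an illustrative snippet, not part of the commit, and assumes a CUDA-capable environment with Transformer Engine installed:

import warnings
from transformer_engine.pytorch import LayerNorm

# Preferred spelling, matching torch.nn.LayerNorm:
ln = LayerNorm(normalized_shape=128)

# Legacy spelling still works but now emits a DeprecationWarning:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ln_legacy = LayerNorm(hidden_size=128)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Passing both args, or neither, raises RuntimeError per the checks above.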
19 changes: 18 additions & 1 deletion transformer_engine/pytorch/module/rmsnorm.py
@@ -65,15 +65,32 @@ class RMSNorm(_RMSNormOp):

def __init__(
self,
- normalized_shape: Union[Iterable[int], int],
+ normalized_shape: Union[Iterable[int], int, None] = None,
eps: float = 1e-5,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
+ hidden_size: Optional[int] = None, # deprecated
**kwargs,
) -> None:

+ # Handle deprecated options
+ if normalized_shape is None:
+     if hidden_size is None:
+         raise RuntimeError(
+             "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
+         )
+     warnings.warn(
+         "`hidden_size` arg has been renamed to `normalized_shape` "
+         "for compatibility with `torch.nn.LayerNorm`.",
+         DeprecationWarning,
+         stacklevel=2,
+     )
+     normalized_shape = hidden_size
+ elif hidden_size is not None:
+     raise RuntimeError(
+         "Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
+     )
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
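RMSNorm gets the identical shim, which is presumably what keeps Megatron-LM's legacy `hidden_size`-style constructor calls working. A minimal sketch under the same assumptions as the LayerNorm example above:

from transformer_engine.pytorch import RMSNorm

norm_new = RMSNorm(normalized_shape=(128,), eps=1e-5)  # preferred
norm_old = RMSNorm(hidden_size=128, eps=1e-5)  # warns, still works
# Passing both args, or neither, raises RuntimeError.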
65 changes: 37 additions & 28 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -20,7 +20,12 @@
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
- from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
+ from ...utils import (
+     canonicalize_device,
+     canonicalize_dtype,
+     clear_tensor_data,
+     devices_match,
+ )
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape

@@ -84,28 +89,23 @@ def __init__(
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
- self._shape: tuple[int, ...] = normalized_shape

# Parameter device
- defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Initialize parameters if needed
dtype = canonicalize_dtype(dtype)
weight = torch.empty(
- self._shape,
- device="meta",
+ normalized_shape,
+ device=device,
dtype=dtype,
)
bias = torch.empty(
- self._shape,
- device="meta",
+ normalized_shape,
+ device=device,
dtype=dtype,
)
weight = torch.nn.Parameter(weight)
@@ -143,17 +143,18 @@ def getenv(name: str) -> int:
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

- # Make sure parameter is initialized
+ # Parameter device
weight = self.weight
bias = self.bias
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
if bias.device.type != "cuda":
bias = torch.empty_like(bias, device=self.device)
else:
bias = bias.to(device=self.device)
device = weight.device
if device.type == "meta":
device = canonicalize_device(None)

# Initialize param buffers
if not devices_match(weight.device, device):
weight = torch.empty_like(weight, device=device)
if not devices_match(bias.device, device):
bias = torch.empty_like(bias, device=device)

# Initialize values
if self.zero_centered_gamma:
@@ -184,17 +185,21 @@ def op_forward(
) -> torch.Tensor:

# Check tensor dims
+ weight = self.weight
+ weight_dims = tuple(weight.size())
input_dims = tuple(input_.size())
- if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
+ if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
f"and weight tensor (shape={weight_dims}) are not compatible"
)

# Check input tensors
- inner_dim = math.prod(self._shape)
- device = self.device
- dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
+ inner_dim = math.prod(weight_dims)
+ device = weight.device
+ if device.type != "cuda":
+     device = canonicalize_device(None)
+ dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
@@ -266,6 +271,7 @@ def op_forward(
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
+ ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None

@@ -282,9 +288,12 @@ def op_backward(
# Saved tensors from forward pass
x, means, rstdevs = ctx.saved_tensors

+ # Tensor dims
+ weight_dims = self.weight.size()
+ inner_dim = math.prod(weight_dims)
+
# Check input tensors
- inner_dim = x.size(-1)
- device = self.device
+ device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -312,6 +321,6 @@ def op_backward(

# Reshape results
grad_input = reshape(dx, grad_output.size())
- grad_weight = reshape(dw, self._shape)
- grad_bias = reshape(db, self._shape)
+ grad_weight = reshape(dw, weight_dims)
+ grad_bias = reshape(db, weight_dims)
return grad_input, (grad_weight, grad_bias)
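The net effect of the device changes in this file is that the op no longer requires CUDA parameters at construction time. A hedged sketch of the two deferred-initialization patterns this enables (it assumes the `device` kwarg reaches this op's constructor and that a GPU is the default device):

import transformer_engine.pytorch as te

# CPU initialization (what Megatron-LM's --use-cpu-initialization exercises):
# parameters are created and initialized on the host, then moved to the GPU.
ln = te.LayerNorm(128, device="cpu")
ln = ln.cuda()

# Meta-device construction: no storage is allocated until reset_parameters(),
# which materializes the parameters on the default CUDA device.
ln_meta = te.LayerNorm(128, device="meta")
ln_meta.reset_parameters()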
53 changes: 32 additions & 21 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -20,7 +20,12 @@
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
- from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
+ from ...utils import (
+     canonicalize_device,
+     canonicalize_dtype,
+     clear_tensor_data,
+     devices_match,
+ )
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape

@@ -83,22 +88,17 @@ def __init__(
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
- self._shape: tuple[int, ...] = normalized_shape

# Parameter device
- defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Initialize parameters if needed
weight = torch.empty(
- self._shape,
- device="meta",
+ normalized_shape,
+ device=device,
dtype=canonicalize_dtype(dtype),
)
weight = torch.nn.Parameter(weight)
@@ -133,12 +133,15 @@ def getenv(name: str) -> int:
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

- # Make sure parameter is initialized
+ # Parameter device
weight = self.weight
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
device = weight.device
if device.type == "meta":
device = canonicalize_device(None)

# Initialize param buffers
if not devices_match(weight.device, device):
weight = torch.empty_like(weight, device=device)

# Initialize values
if self.zero_centered_gamma:
@@ -165,17 +168,21 @@ def op_forward(
) -> torch.Tensor:

# Check tensor dims
+ weight = self.weight
+ weight_dims = tuple(weight.size())
input_dims = tuple(input_.size())
- if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
+ if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
f"and weight tensor (shape={weight_dims}) are not compatible"
)

# Check input tensors
- inner_dim = math.prod(self._shape)
- device = self.device
- dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
+ inner_dim = math.prod(weight_dims)
+ device = weight.device
+ if device.type != "cuda":
+     device = canonicalize_device(None)
+ dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
@@ -241,6 +248,7 @@ def op_forward(
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rstdevs)
+ ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None

@@ -257,9 +265,12 @@ def op_backward(
# Saved tensors from forward pass
x, rstdevs = ctx.saved_tensors

+ # Tensor dims
+ weight_dims = self.weight.size()
+ inner_dim = math.prod(weight_dims)
+
# Check input tensors
- inner_dim = x.size(-1)
- device = self.device
+ device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
Expand All @@ -285,5 +296,5 @@ def op_backward(

# Reshape results
grad_input = reshape(dx, grad_output.size())
- grad_weight = reshape(dw, self._shape)
+ grad_weight = reshape(dw, weight_dims)
return grad_input, (grad_weight,)
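For reference, the zero_centered_gamma branch in reset_parameters follows Transformer Engine's usual convention: the stored weight is gamma - 1, so zero initialization yields an identity scale. A pure-PyTorch reference of the computation (an assumption-laden sketch, not this op's fused kernel):

import torch

def rmsnorm_ref(x, weight, eps=1e-5, zero_centered_gamma=False):
    # y = x / RMS(x) * gamma, with RMS(x) = sqrt(mean(x^2) + eps).
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    gamma = weight + 1 if zero_centered_gamma else weight
    return x / rms * gamma

x = torch.randn(4, 128)
w = torch.zeros(128)  # zero-centered init -> identity scale
y = rmsnorm_ref(x, w, zero_centered_gamma=True)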
6 changes: 5 additions & 1 deletion transformer_engine/pytorch/ops/fuser.py
@@ -135,7 +135,11 @@ def forward(
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
- x.requires_grad_(requires_grad=requires_grad)
+ if requires_grad != x.requires_grad:
+     if requires_grad:
+         x.requires_grad_()
+     else:
+         x = x.detach()

# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
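The replaced line is the likely failure mode being fixed: `Tensor.requires_grad_(False)` is only legal on autograd leaves, and tensors flowing between fused ops are generally non-leaf, so the flag has to be dropped via `detach()` instead. A standalone demonstration of the underlying PyTorch behavior:

import torch

leaf = torch.ones(4, requires_grad=True)
non_leaf = leaf * 2  # produced by an op, so not a leaf

try:
    non_leaf.requires_grad_(False)  # in-place toggle on a non-leaf
except RuntimeError:
    pass  # "you can only change requires_grad flags of leaf variables"

x = non_leaf.detach()  # same storage, but cut out of the autograd graph
assert not x.requires_grad

# Enabling grad in place is always legal: a tensor with requires_grad=False
# is by definition a leaf.
x.requires_grad_()
assert x.requires_grad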
