[ENHANCEMENT] Add support for Apex RMSNorm for use in qk-norm #1261

Open · wants to merge 1 commit into main
131 changes: 128 additions & 3 deletions megatron/core/fusions/fused_layer_norm.py
@@ -26,6 +26,13 @@
except ImportError:
    HAVE_FUSED_LAYER_NORM = False

try:
    from apex.normalization.fused_layer_norm import FusedRMSNormAffineFunction

    HAVE_FUSED_RMS_NORM = True
except ImportError:
    HAVE_FUSED_RMS_NORM = False


class FusedLayerNorm(torch.nn.Module):
"""Layer Norm, fused into a single CUDA kernel.
@@ -63,9 +70,13 @@ def __init__(
        self.config = config

        self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
        assert (
            self.config.normalization == "LayerNorm"
        ), f'({self.config.normalization}) is not supported in FusedLayerNorm'

        # If someone is trying to instantiate FusedLayerNorm directly but has
        # specified another normalization in the config, raise an error
        assert self.config.normalization == "LayerNorm", (
            f'({self.config.normalization}) was specified in the config, but '
            'FusedLayerNorm is being instantiated here'
        )

        # List of hidden sizes supported in the persistent layer norm kernel
        # If the hidden size is not supported, fall back to the non-persistent
@@ -167,3 +178,117 @@ def forward(self, input: Tensor) -> Tensor:
            )

        return output


class FusedRMSNorm(torch.nn.Module):
    """RMS Norm, fused into a single CUDA kernel. Note: so far there is no
    persistent kernel for RMSNorm in Apex, so we use a non-persistent one.

    Args:
        hidden_size (int): Transformer hidden dimension.

        eps (float): Epsilon added to denominator, for numerical stability.

        zero_centered_gamma (bool): Adjust the RMSNorm weight such that it is
            centered around zero. This improves numerical stability.

        config (TransformerConfig): Transformer config. Include to match custom
            layer norm interfaces.

        normalization (str): Normalization type, used for Transformer Engine.
            Must equal 'RMSNorm' here.
    """

    def __init__(
        self,
        config: TransformerConfig,
        hidden_size: int,
        eps: float = 1e-5,
        zero_centered_gamma: bool = False,
        normalization: str = "RMSNorm",  # included to match TE interface
    ):
        super().__init__()

        self.config = config

        self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma

        # If someone is trying to instantiate FusedRMSNorm directly but has
        # specified another normalization in the config, raise an error
        assert self.config.normalization == "RMSNorm", (
            f'({self.config.normalization}) was specified in the config, but '
            'FusedRMSNorm is being instantiated here'
        )

        if not HAVE_FUSED_RMS_NORM:
            raise ValueError('Apex must be installed to use FusedRMSNorm.')

        if isinstance(hidden_size, numbers.Integral):
            hidden_size = (hidden_size,)
        self.hidden_size = torch.Size(hidden_size)
        self.eps = eps
        # Parameters need to be initialized with torch.empty rather than
        # torch.Tensor for correct device placement with nemo2.
        self.weight = Parameter(torch.empty(*hidden_size))
        self.reset_parameters()
        self.sequence_parallel = self.config.sequence_parallel

        # Set the sequence-parallelism flag on the weight parameter
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)

    def reset_parameters(self):
        if self.zero_centered_gamma:
            init.zeros_(self.weight)
        else:
            init.ones_(self.weight)

    def forward(self, input: Tensor) -> Tensor:
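        # With zero-centered gamma the weight is stored offset by -1
        # (reset_parameters initializes it to zero), so add 1 back here to
        # recover the effective scale.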

        weight = self.weight + 1 if self.zero_centered_gamma else self.weight

        if (
            'memory_efficient'
            in inspect.getfullargspec(FusedRMSNormAffineFunction.forward).args
        ):
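            # This branch is taken when the installed Apex exposes the newer
            # `memory_efficient` argument on FusedRMSNormAffineFunction.forward;
            # older Apex versions fall back to the signature without it.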
            return FusedRMSNormAffineFunction.apply(
                input,
                weight,
                self.hidden_size,
                self.eps,
                self.config.memory_efficient_layer_norm,
            )
        else:
            return FusedRMSNormAffineFunction.apply(
                input, weight, self.hidden_size, self.eps
            )

class FusedApexNorm:
    """
    A conditional wrapper to initialize an instance of Apex
    `LayerNorm` or `RMSNorm` based on input.
    """

    def __new__(
        cls,
        config: TransformerConfig,
        hidden_size: int,
        eps: float = 1e-5,
    ):
        if config.normalization == "LayerNorm":
            instance = FusedLayerNorm(
                config=config,
                hidden_size=hidden_size,
                eps=eps,
                persist_layer_norm=config.persist_layer_norm,
                zero_centered_gamma=config.layernorm_zero_centered_gamma,
            )
        elif config.normalization == "RMSNorm":
            instance = FusedRMSNorm(
                config=config,
                hidden_size=hidden_size,
                eps=eps,
                zero_centered_gamma=config.layernorm_zero_centered_gamma,
            )
        else:
            raise Exception('Only LayerNorm and RMSNorm are currently supported')

        return instance
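A minimal usage sketch of the wrapper above (assuming Apex and Megatron-Core are installed, a CUDA device is available, and the TransformerConfig fields referenced below keep their current names):

```python
import torch

from megatron.core.fusions.fused_layer_norm import FusedApexNorm
from megatron.core.transformer.transformer_config import TransformerConfig

# normalization="RMSNorm" makes FusedApexNorm.__new__ return a FusedRMSNorm
# instance; "LayerNorm" would return a FusedLayerNorm instead.
config = TransformerConfig(
    num_layers=1,
    hidden_size=1024,
    num_attention_heads=8,
    normalization="RMSNorm",
    layernorm_zero_centered_gamma=False,
)

norm = FusedApexNorm(config=config, hidden_size=1024, eps=1e-5).cuda()
x = torch.randn(2, 4, 1024, device="cuda")
y = norm(x)  # output has the same shape as x
print(type(norm).__name__, y.shape)
```

With normalization="LayerNorm" in the config, the same call would instead return a FusedLayerNorm and additionally forward config.persist_layer_norm.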
11 changes: 6 additions & 5 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -36,10 +36,10 @@
try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
    from megatron.core.fusions.fused_layer_norm import FusedApexNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
    LNImpl = FusedApexNorm
except ImportError:
    import warnings

@@ -110,9 +110,10 @@ def get_gpt_layer_with_transformer_engine_spec(
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    # TENorm significantly harms convergence when used
                    # for QKLayerNorm; we instead use the Apex implementation.
                    q_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
                    k_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
                    # for QKLayerNorm; we instead use the Apex implementation (or
                    # the PyTorch one if Apex is not installed).
                    q_layernorm=LNImpl if qk_layernorm else IdentityOp,
Review comments on this line:

Author:
I saw that at a few other places in the code (here, here and here), TENorm is still used for qk-normalization, even though (according to the comment above and this commit) using TENorm for qk-layernorm is unstable.
Let me know if I should also modify these other places 👍

@SeunghyunSEO (Oct 31, 2024):
In my case, I applied exactly the same patch to my own Megatron fork, so the changes in this PR look good to me. But I think we should clarify why this is happening: like you said, some people still use TENorm for qk-norm and their models converge.

Author:
@SeunghyunSEO thanks!
Regarding the clarification on why this is happening, do you mean that we should check why the TE implementation is diverging? (I didn't try it myself; I just assumed it does based on your PR and also based on the comment in this commit.)

@SeunghyunSEO (Nov 6, 2024):
> @SeunghyunSEO thanks!
> Regarding the clarification on why this is happening, do you mean that we should check why the TE implementation is diverging? (I didn't try it myself; I just assumed it does based on your PR and also based on the comment in this commit.)

I mean that when an additional feature is added, we should at least know whether it is necessary. Do any Megatron or TE maintainers know why TENorm for qk-norm sometimes diverges? I'm cc'ing Deepak because he is the only one I communicate with! @deepakn94 (sorry for the wrong tagging, but could you tag an expert on numerical-precision issues?)

Author:
I see, that makes sense, I agree 👍 Thanks for tagging @deepakn94 🙏
Also, I think Mike Chrzanowski and Shanmugam Ramasamy could be tagged if NVIDIA folks know their contacts (I couldn't find their GitHub handles), since they are the ones who created this commit that prevented the use of TENorm, and Mike Chrzanowski also wrote a paper using qk-layernorm 👍
                    k_layernorm=LNImpl if qk_layernorm else IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
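As a quick sanity check of the spec change (assuming Transformer Engine and Apex are installed, and that get_gpt_layer_with_transformer_engine_spec and the submodule layout keep their current names), the qk-norm entries should now resolve to LNImpl rather than TENorm:

```python
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

spec = get_gpt_layer_with_transformer_engine_spec(qk_layernorm=True)
attn = spec.submodules.self_attention.submodules

# With Apex installed these resolve to FusedApexNorm (via LNImpl); with
# qk_layernorm=False they would be IdentityOp instead.
print(attn.q_layernorm)
print(attn.k_layernorm)
```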