diff --git a/megatron/core/fusions/fused_layer_norm.py b/megatron/core/fusions/fused_layer_norm.py
index d02ae7aa4d..4024638a56 100644
--- a/megatron/core/fusions/fused_layer_norm.py
+++ b/megatron/core/fusions/fused_layer_norm.py
@@ -26,6 +26,13 @@ except ImportError:
     HAVE_FUSED_LAYER_NORM = False
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNormAffineFunction
+
+    HAVE_FUSED_RMS_NORM = True
+except ImportError:
+    HAVE_FUSED_RMS_NORM = False
+
 
 class FusedLayerNorm(torch.nn.Module):
     """Layer Norm, fused into a single CUDA kernel.
 
@@ -63,9 +70,13 @@ def __init__(
         self.config = config
 
         self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
-        assert (
-            self.config.normalization == "LayerNorm"
-        ), f'({self.config.normalization}) is not supported in FusedLayerNorm'
+
+        # If someone tries to instantiate FusedLayerNorm directly but has
+        # specified another normalization in the config, raise an error
+        assert self.config.normalization == "LayerNorm", (
+            f'({self.config.normalization}) was specified in the config, but '
+            'FusedLayerNorm is being instantiated here'
+        )
 
         # List of hiddens sizes supported in the persistent layer norm kernel
         # If the hidden size is not supported, fall back to the non-persistent
@@ -167,3 +178,117 @@ def forward(self, input: Tensor) -> Tensor:
             )
 
         return output
+
+
+class FusedRMSNorm(torch.nn.Module):
+    """RMS Norm, fused into a single CUDA kernel. Note: so far there is no
+    persistent kernel for RMSNorm in Apex, so we use a non-persistent one.
+
+    Args:
+        hidden_size (int): Transformer hidden dimension.
+
+        eps (float): Epsilon added to denominator, for numerical stability.
+
+        zero_centered_gamma (bool): Adjust the RMSNorm weight such that it is
+        centered around zero. This improves numerical stability.
+
+        config (TransformerConfig): Transformer config. Included to match custom
+        layer norm interfaces.
+
+        normalization (str): Normalization type, used for Transformer Engine.
+        Must equal 'RMSNorm' here.
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        hidden_size: int,
+        eps: float = 1e-5,
+        zero_centered_gamma: bool = False,
+        normalization: str = "RMSNorm",  # included to match TE interface
+    ):
+        super().__init__()
+
+        self.config = config
+
+        self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
+
+        # If someone tries to instantiate FusedRMSNorm directly but has
+        # specified another normalization in the config, raise an error
+        assert self.config.normalization == "RMSNorm", (
+            f'({self.config.normalization}) was specified in the config, but '
+            'FusedRMSNorm is being instantiated here'
+        )
+
+        if not HAVE_FUSED_RMS_NORM:
+            raise ValueError('Apex must be installed to use FusedRMSNorm.')
+
+        if isinstance(hidden_size, numbers.Integral):
+            hidden_size = (hidden_size,)
+        self.hidden_size = torch.Size(hidden_size)
+        self.eps = eps
+        # Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2.
+        self.weight = Parameter(torch.empty(*hidden_size))
+        self.reset_parameters()
+        self.sequence_parallel = self.config.sequence_parallel
+
+        # set sequence parallelism flag on weight parameters
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+
+    def reset_parameters(self):
+
+        if self.zero_centered_gamma:
+            init.zeros_(self.weight)
+        else:
+            init.ones_(self.weight)
+
+    def forward(self, input: Tensor) -> Tensor:
+
+        weight = self.weight + 1 if self.zero_centered_gamma else self.weight
+
+        if (
+            'memory_efficient'
+            in inspect.getfullargspec(FusedRMSNormAffineFunction.forward).args
+        ):
+            return FusedRMSNormAffineFunction.apply(
+                input,
+                weight,
+                self.hidden_size,
+                self.eps,
+                self.config.memory_efficient_layer_norm,
+            )
+        else:
+            return FusedRMSNormAffineFunction.apply(
+                input, weight, self.hidden_size, self.eps
+            )
+
+class FusedApexNorm:
+    """
+    A conditional wrapper to initialize an instance of Apex
+    `LayerNorm` or `RMSNorm` based on input.
+    """
+    def __new__(
+        cls,
+        config: TransformerConfig,
+        hidden_size: int,
+        eps: float = 1e-5,
+    ):
+        if config.normalization == "LayerNorm":
+            instance = FusedLayerNorm(
+                config=config,
+                hidden_size=hidden_size,
+                eps=eps,
+                persist_layer_norm=config.persist_layer_norm,
+                zero_centered_gamma=config.layernorm_zero_centered_gamma,
+            )
+        elif config.normalization == "RMSNorm":
+            instance = FusedRMSNorm(
+                config=config,
+                hidden_size=hidden_size,
+                eps=eps,
+                zero_centered_gamma=config.layernorm_zero_centered_gamma,
+            )
+        else:
+            raise Exception('Only LayerNorm and RMSNorm are currently supported')
+
+        return instance
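
Note on the kernel that the new FusedRMSNorm wraps: FusedRMSNormAffineFunction is expected to apply standard RMSNorm, scaling the input by the reciprocal root-mean-square of its elements and then by an affine weight, with the zero-centered-gamma option storing the weight offset by one (initialized to zeros, applied as weight + 1, as in forward() above). The snippet below is a minimal, unfused PyTorch reference of that behaviour for illustration only; it is not part of the patch, it normalizes over the last dimension for simplicity, and the fused Apex kernel's accumulation dtype and backward may differ.

import torch


def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-5,
    zero_centered_gamma: bool = False,
) -> torch.Tensor:
    # Mirrors FusedRMSNorm.forward: with zero-centered gamma the stored
    # parameter starts at zero and the effective scale is weight + 1.
    effective_weight = weight + 1 if zero_centered_gamma else weight
    # RMS statistic over the hidden dimension; unlike LayerNorm there is
    # no mean subtraction and no bias term.
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * inv_rms * effective_weight
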
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index 1db68dc886..ef4e95c91f 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -36,10 +36,10 @@
 try:
     import apex  # pylint: disable=unused-import
 
-    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+    from megatron.core.fusions.fused_layer_norm import FusedApexNorm
 
     HAVE_APEX = True
-    LNImpl = FusedLayerNorm
+    LNImpl = FusedApexNorm
 except ImportError:
     import warnings
 
@@ -110,9 +110,10 @@ def get_gpt_layer_with_transformer_engine_spec(
                     core_attention=TEDotProductAttention,
                     linear_proj=TERowParallelLinear,
                     # TENorm significantly harms convergence when used
-                    # for QKLayerNorm; we instead use the Apex implementation.
-                    q_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
-                    k_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
+                    # for QKLayerNorm; we instead use the Apex implementation (or the
+                    # PyTorch one if Apex is not installed).
+                    q_layernorm=LNImpl if qk_layernorm else IdentityOp,
+                    k_layernorm=LNImpl if qk_layernorm else IdentityOp,
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
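
For orientation, here is a rough usage sketch of the FusedApexNorm wrapper added above; it is illustrative only and not part of the patch. Because the dispatch happens in __new__, the call site stays the same for both normalization types, which is what lets gpt_layer_specs.py bind LNImpl to a single name. The TransformerConfig fields shown are a minimal, hypothetical setup; the exact required fields depend on the Megatron-core version, and Apex must be installed for either branch to work.

from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.fusions.fused_layer_norm import FusedApexNorm, FusedRMSNorm

# Minimal, hypothetical config; real configs carry many more fields.
config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=8,
    normalization="RMSNorm",             # selects the FusedRMSNorm branch
    layernorm_zero_centered_gamma=True,  # weight stored centered around zero
)

norm = FusedApexNorm(config=config, hidden_size=config.hidden_size, eps=1e-5)
assert isinstance(norm, FusedRMSNorm)
# With normalization="LayerNorm" the same call would instead return a
# FusedLayerNorm configured with config.persist_layer_norm.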