LayerNormLinear reset_parameters() leads to the wrong initialization. #2528

@jstjohn

Describe the bug

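layer_norm_weight should be 1.0 at init time, and it is by default. However, once you call reset_parameters() on the module, layer_norm_weight is re-initialized with small zero-centered random values instead of 1.0. This affects users who initialize on a meta device and users of Megatron FSDP, since those workflows typically call reset_parameters() to materialize the parameters.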
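
Until this is fixed, one possible workaround is to restore the norm weight right after the reset. The sketch below is not part of Transformer Engine: `reset_parameters_keep_norm_weight` and its `gamma_init` argument are hypothetical names, and it only assumes the module exposes `reset_parameters()` and a `layer_norm_weight` parameter as shown in this report. It also assumes the default gamma convention, where 1.0 is the correct value; to my knowledge a layer built with `zero_centered_gamma=True` would want `gamma_init=0.0` instead.

```python
def reset_parameters_keep_norm_weight(module, gamma_init=1.0):
    """Call module.reset_parameters(), then restore layer_norm_weight.

    Hypothetical helper, not a Transformer Engine API. It is duck-typed:
    any object with reset_parameters() and an optional layer_norm_weight
    attribute that behaves like a torch Parameter will work.
    """
    module.reset_parameters()

    weight = getattr(module, "layer_norm_weight", None)
    if weight is not None:
        # .data gives the underlying tensor without autograd tracking,
        # so the in-place fill does not trip the leaf-variable check.
        weight.data.fill_(gamma_init)
    return module
```

With the module from the reproduction below, `reset_parameters_keep_norm_weight(lnl)` would be used in place of `lnl.reset_parameters()`. Frameworks that call `reset_parameters()` internally would need the fix in the module itself, so this only helps where you control the call site.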

Steps/Code to reproduce bug

(Pdb) from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear
(Pdb) lnl = LayerNormLinear(10,10)
(Pdb) lnl.layer_norm_weight
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
       requires_grad=True)
(Pdb) lnl.reset_parameters()
(Pdb) lnl.layer_norm_weight
Parameter containing:
tensor([-0.0124,  0.0245,  0.0220, -0.0057, -0.0060, -0.0016, -0.0040, -0.0054,
        -0.0262, -0.0147], device='cuda:0', requires_grad=True)

Expected behavior

The layer_norm_weight should be 1.0 after reset_parameters().
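
The expected behavior could be captured as a small regression check along these lines. This is a sketch, not an existing Transformer Engine test: `norm_weight_survives_reset` is a hypothetical name, and it assumes `layer_norm_weight` is a 1-D tensor-like object supporting `detach()` and `tolist()`, as a PyTorch parameter does.

```python
def norm_weight_survives_reset(module, expected=1.0, tol=1e-6):
    """Return True if layer_norm_weight equals `expected` after a reset.

    Hypothetical check, not part of Transformer Engine. Assumes the weight
    is 1-D, so tolist() yields a flat list of Python floats.
    """
    module.reset_parameters()
    values = module.layer_norm_weight.detach().tolist()
    return all(abs(v - expected) <= tol for v in values)
```

Based on the outputs in this report, `norm_weight_survives_reset(LayerNormLinear(10, 10))` would be expected to return False in the container described below, and True once the initialization is fixed.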

Environment overview

  • Environment location: Docker
  • Method of Transformer Engine install: FROM nvcr.io/nvidia/pytorch:25.11-py3
  • Commands used to start the container and the interpreter, followed by the session:
docker run --gpus=all -it nvcr.io/nvidia/pytorch:25.11-py3 bash
ipython
In [1]: from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear
 
In [2]: lnl = LayerNormLinear(10,10)

In [3]: lnl.layer_norm_weight
Out[3]: 
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
       requires_grad=True)

In [4]: lnl.reset_parameters()

In [5]: lnl.layer_norm_weight
Out[5]: 
Parameter containing:
tensor([ 0.0029, -0.0114, -0.0055, -0.0512, -0.0087,  0.0147, -0.0056,  0.0008,
        -0.0116,  0.0079], device='cuda:0', requires_grad=True)


Metadata

    Labels

    bug: Something isn't working
