Open
Labels: bug (Something isn't working)
Description
Describe the bug
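layer_norm_weight should be 1.0 at init time, and by default it is. However, calling reset_parameters() on the module re-initializes layer_norm_weight with small, zero-centered random values instead of restoring it to 1.0. This affects users who initialize on a meta device and users of Megatron FSDP, since such workflows typically call reset_parameters() to materialize parameter values.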
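To illustrate why meta-device initialization is exposed to this, here is a minimal sketch in plain PyTorch. ToyLayerNormLinear is a hypothetical stand-in that only mimics the reported behavior (it is not Transformer Engine code), and the std of 0.02 and the materialize helper are assumptions made for the example. On a meta device the parameters carry no data, so whatever reset_parameters() writes after to_empty() is the only initialization the module ever gets.

```python
import torch
import torch.nn as nn


class ToyLayerNormLinear(nn.Module):
    """Hypothetical stand-in mimicking the reported behavior; not Transformer Engine code."""

    def __init__(self, in_features, out_features, device=None):
        super().__init__()
        # The constructor sets the norm weight to ones, as in the report.
        self.layer_norm_weight = nn.Parameter(torch.ones(in_features, device=device))
        self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device))
        if self.weight.device.type != "meta":
            nn.init.normal_(self.weight, mean=0.0, std=0.02)

    def reset_parameters(self):
        # Mirrors the reported bug: every parameter, including the norm weight,
        # receives the same small zero-centered random init (std is an assumption).
        for param in self.parameters():
            nn.init.normal_(param, mean=0.0, std=0.02)


def materialize(module, device="cpu"):
    """Allocate real storage for a meta-device module, then initialize it."""
    module.to_empty(device=device)  # storage is uninitialized at this point
    module.reset_parameters()       # the only step that writes actual values
    return module


meta_module = ToyLayerNormLinear(10, 10, device="meta")
materialized = materialize(meta_module)

# The ones written by the constructor never existed as real data on the meta
# device, so the norm weight now holds whatever reset_parameters() produced.
weight_is_ones = torch.allclose(materialized.layer_norm_weight, torch.ones(10))
print(weight_is_ones)
```

In this toy setup the constructor's 1.0 values are never observed after materialization, which is the same failure mode described above for meta-device users.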
Steps/Code to reproduce bug
(Pdb) from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear
(Pdb) lnl = LayerNormLinear(10,10)
(Pdb) lnl.layer_norm_weight
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
requires_grad=True)
(Pdb) lnl.reset_parameters()
(Pdb) lnl.layer_norm_weight
Parameter containing:
tensor([-0.0124, 0.0245, 0.0220, -0.0057, -0.0060, -0.0016, -0.0040, -0.0054,
-0.0262, -0.0147], device='cuda:0', requires_grad=True)
Expected behavior
The layer_norm_weight should be 1.0 after reset_parameters().
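The expected behavior can be written down as a small check. The sketch below again uses a hypothetical toy module in plain PyTorch, not Transformer Engine itself; its reset_parameters() shows one possible shape of the desired behavior (restore the norm weight to ones, re-draw only the linear weight), and is not a claim about how the library is or will be implemented. The helper layer_norm_weight_is_ones and the std of 0.02 are assumptions for illustration.

```python
import torch
import torch.nn as nn


class ToyLayerNormLinear(nn.Module):
    """Hypothetical toy module showing the expected reset behavior."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.layer_norm_weight = nn.Parameter(torch.empty(in_features))
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        # Norm weight goes back to ones; only the linear weight is random.
        nn.init.ones_(self.layer_norm_weight)
        nn.init.normal_(self.weight, mean=0.0, std=0.02)  # std is an assumption


def layer_norm_weight_is_ones(module):
    """Return True if the module's layer_norm_weight is exactly all ones."""
    w = module.layer_norm_weight.detach()
    return torch.equal(w, torch.ones_like(w))


toy = ToyLayerNormLinear(10, 10)
before = layer_norm_weight_is_ones(toy)
toy.reset_parameters()
after = layer_norm_weight_is_ones(toy)
print(before, after)
```

The same helper could in principle be applied to a real LayerNormLinear instance before and after reset_parameters(); with the behavior reported in this issue, the second check would fail.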
Environment overview (please complete the following information)
- Environment location: Docker
- Method of Transformer Engine install: FROM nvcr.io/nvidia/pytorch:25.11-py3
- If method of install is [Docker], provide docker pull & docker run commands used:
docker run --gpus=all -it nvcr.io/nvidia/pytorch:25.11-py3 bash
ipython
In [1]: from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear
In [2]: lnl = LayerNormLinear(10,10)
In [3]: lnl.layer_norm_weight
Out[3]:
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
requires_grad=True)
In [4]: lnl.reset_parameters()
In [5]: lnl.layer_norm_weight
Out[5]:
Parameter containing:
tensor([ 0.0029, -0.0114, -0.0055, -0.0512, -0.0087, 0.0147, -0.0056, 0.0008,
-0.0116, 0.0079], device='cuda:0', requires_grad=True)
In [6]:
Environment details
If NVIDIA docker image is used you don't need to specify these.
Otherwise, please provide:
- OS version
- PyTorch version
- Python version
- Transformer Engine version
- CUDA version
- CUDNN version
Device details
- GPU model