Describe the bug
Linear.bias is initialized to a zero vector. After a call to reset_parameters(), rather than being restored to its construction-time state, the bias is set to a random init. This is the same flavor of bug as #2528 but much less serious, since the mean is still zero and, with typical initializations, the standard deviation is near zero. It would still be cleaner to do it correctly.
Steps/Code to reproduce bug
docker run --gpus=all -it nvcr.io/nvidia/pytorch:25.11-py3 bash
ipython
In [1]: import transformer_engine as te
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:64: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
In [2]: tel = te.pytorch.Linear(3,2)
In [3]: tel.bias
Out[3]:
Parameter containing:
tensor([0., 0.], device='cuda:0', requires_grad=True)
In [4]: tel.weight
Out[4]:
Parameter containing:
tensor([[-0.0293, -0.0364, 0.0211],
[-0.0061, -0.0232, 0.0171]], device='cuda:0', requires_grad=True)
In [5]: tel.reset_parameters()
In [6]: tel.bias
Out[6]:
Parameter containing:
tensor([-0.0540, 0.0662], device='cuda:0', requires_grad=True)
Expected behavior
bias should be reset to a zero vector after reset_parameters(), matching its value at construction time.
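To illustrate the expected invariant, here is a minimal sketch (a hypothetical toy class, not Transformer Engine's actual implementation): reset_parameters() should reproduce construction-time state, so the bias returns to zeros no matter what it held before the call.

```python
import random

class ToyLinear:
    """Toy stand-in for a linear layer whose bias is zero-initialized."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features
        # Construction delegates to reset_parameters(), so the two paths
        # cannot drift apart -- this is the invariant the bug violates.
        self.reset_parameters()

    def reset_parameters(self):
        # Weight gets a random init; bias is explicitly zeroed, matching
        # the zero vector observed at construction time.
        self.weight = [[random.gauss(0.0, 0.02) for _ in range(self.in_features)]
                       for _ in range(self.out_features)]
        self.bias = [0.0] * self.out_features

layer = ToyLinear(3, 2)
layer.bias = [0.123, -0.456]   # simulate the bias drifting during training
layer.reset_parameters()
print(layer.bias)              # → [0.0, 0.0]
```

Having __init__ call reset_parameters() (as torch.nn.Linear does) is one way to guarantee the two code paths stay in sync.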
Environment overview (please complete the following information)
- Environment location: Docker
- Method of Transformer Engine install: pre-installed in nvidia pytorch docker image
- If method of install is [Docker], provide docker pull & docker run commands used
(see steps to reproduce the bug above)
Environment details
If an NVIDIA docker image is used, you don't need to specify these.
Otherwise, please provide:
- OS version
- PyTorch version
- Python version
- Transformer Engine version
- CUDA version
- CUDNN version
Device details
- GPU model
Additional context
Add any other context about the problem here.