Release v2.10
Key Features and Enhancements
- [PyTorch] Added support for the NVFP4 training recipe for the `GroupedLinear` module (see the `GroupedLinear` sketch after this list).
- [PyTorch] Added support for CUDA graphs when using quantized weights with Tensor Parallelism.
- [PyTorch] Added support for CUDA graphs when using `delay_wgrad_compute`.
- [PyTorch] Expanded debug tools to support more statistics.
- [PyTorch] Reduced the overhead of using debug tools.
- [PyTorch] Added support for clamped SwiGLU in the `TransformerLayer` module.
- [PyTorch] Added backward compatibility with older Megatron-Core versions by introducing a `keep_columnwise` parameter to `cast_master_weights_to_fp8` and related helper functions.
- [PyTorch] Added a `reset` interface to `make_graphed_callables` that clears internal CUDA graphs before distributed process group cleanup, preventing hangs (see the CUDA-graph sketch after this list).
- [PyTorch] Added support for FSDP2 with quantized weights.
- [PyTorch] Added support for Sliding Window Attention (SWA) with Context Parallelism and the THD input format.
- [PyTorch] Integrated Flash Attention's `num_splits` parameter into the attention backend.
- [PyTorch] Made various improvements to mitigate CPU overhead, especially for the `GroupedLinear` module.
- [C][PyTorch] Enabled RoPE (Rotary Position Embedding) application with position offsets during training, removing the previous restriction that `start_positions` could only be used with `cp_size=1` (context parallelism disabled).
- [Jax] Added options to disable Stochastic Rounding, Randomized Hadamard Transform, and 2D weight quantization in the NVFP4 training recipe.
- [Jax] Improved performance by using Transformer Engine quantization when fused normalization or fused activation are disabled.
- [Jax] Improved NVFP4 performance by using TE kernels for scaling-factor swizzles.
- [Jax] Added support for checkpointing quantization operations in JAX.
- [Jax] Added support for sink attention.
- [Jax] Added support for concurrent use of Data Parallelism (DP) and Fully-Sharded Data Parallelism (FSDP).
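For the `GroupedLinear` items above, the following is a minimal sketch of running `GroupedLinear` under Transformer Engine's quantization autocast in PyTorch. It uses the `DelayedScaling` FP8 recipe for concreteness; per this release, the NVFP4 training recipe can be selected through the same `fp8_recipe` argument (its recipe class is not shown here). The shapes, split sizes, and training step are illustrative assumptions.

```python
# Minimal sketch (assumptions noted): a GroupedLinear layer run under
# te.fp8_autocast. DelayedScaling stands in for the quantization recipe;
# the NVFP4 recipe added in this release plugs into the same fp8_recipe slot.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

num_gemms = 4
layer = te.GroupedLinear(num_gemms, 256, 256, bias=False).cuda()

# A single call fuses num_gemms GEMMs; m_splits partitions the leading
# (token) dimension of the input across the grouped weights.
inp = torch.randn(512, 256, device="cuda", requires_grad=True)
m_splits = [128, 128, 128, 128]

recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp, m_splits)

out.sum().backward()
```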
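For the CUDA-graph items, below is a hedged sketch of capturing a module with `make_graphed_callables`. The shapes, the toy training loop, and the `torchrun`-style distributed setup are assumptions; the new `reset` interface is only referenced in a comment because these notes do not spell out its exact invocation.

```python
# Minimal sketch (assumptions noted): capture a Transformer Engine module into
# CUDA graphs, run a few replayed iterations, then shut down the process group.
import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")  # assumes a torchrun-style launch
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

model = te.Linear(1024, 1024).cuda()
sample = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Warm up and capture the forward/backward passes; later calls replay the graphs.
graphed_model = te.make_graphed_callables(model, (sample,))

for _ in range(10):
    out = graphed_model(sample)
    out.sum().backward()

# New in v2.10: the reset interface on make_graphed_callables clears the
# captured CUDA graphs so that the cleanup below does not hang. Its exact
# spelling is not reproduced here; see the make_graphed_callables API reference.
dist.destroy_process_group()
```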
Fixed Issues
- Fixed an occasional crash when loading the cuDNN library at runtime.
- [C] Fixed an out of bounds access in the NVFP4 dequantization kernel.
- [C] Fixed a numerical error in the amax computation in normalization kernels.
- [PyTorch] Fixed a crash in the permute kernel when using Triton v3.5.
- [PyTorch] Fixed a numerical issue when using gradient accumulation fusion with FSDP.
- [PyTorch] Fixed a crash when exporting modules that use RMSNorm via ONNX.
- [Jax] Fixed a partitioning issue for the NVFP4 training recipe with 1D Mesh.
- [Jax] Fixed a bug where the bias parameter could be added twice when using the unfused attention backend.
- [Jax] Fixed a sharding bug in ring attention primitives when using packed sequences where segment position tensors were not properly sharded to match their corresponding segment ID tensors.
- [PyTorch][Jax] Fixed various logical issues in the backend selection process for attention.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
- [Jax] Default value for `intermediate_dropout` changed from 0.1 to 0.0.
- [Jax] Default value for `return_layernorm_output` changed from `True` to `False`.
- [Jax] Default activation changed from ReLU to GeLU.
- [Jax] Default input type for `DotProductAttention` changed to BSHD.
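The following is a minimal sketch of pinning the previous defaults explicitly after these changes, assuming the Flax modules under `transformer_engine.jax.flax`. Only `intermediate_dropout`, `return_layernorm_output`, and `DotProductAttention` are named in these notes; the remaining constructor arguments (hidden sizes, head count, the `mlp_activations` activation argument) and the `LayerNormDenseGeneral` module are assumptions to be checked against the JAX API reference.

```python
# Minimal sketch (assumptions noted): restore pre-2.10 defaults explicitly.
import transformer_engine.jax.flax as te_flax

# intermediate_dropout now defaults to 0.0 and the activation to GeLU;
# pass the old values if the previous behavior is required.
layer = te_flax.TransformerLayer(
    hidden_size=1024,              # assumed sizes for illustration
    mlp_hidden_size=4096,
    num_attention_heads=16,
    intermediate_dropout=0.1,      # old default, per the notes above
    mlp_activations=("relu",),     # old default activation; argument name assumed
)

# return_layernorm_output now defaults to False; request it explicitly
# where the normalized input is still needed downstream.
proj = te_flax.LayerNormDenseGeneral(   # module name assumed for illustration
    features=1024,
    return_layernorm_output=True,
)

# DotProductAttention now expects BSHD-formatted inputs by default; reshape
# SBHD tensors (or pass the layout argument, not shown here) accordingly.
```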
Deprecated Features
No features are deprecated in this release.