@ksivaman released this on 11 Dec 21:29

Release v2.10

Key Features and Enhancements

  • [PyTorch] Added support for the NVFP4 training recipe for the GroupedLinear module (a usage sketch follows this list).
  • [PyTorch] Added support for CUDA graphs when using quantized weights with Tensor Parallelism.
  • [PyTorch] Added support for CUDA graphs when using delay_wgrad_compute.
  • [PyTorch] Expanded debug tools to support more statistics.
  • [PyTorch] Reduced the overhead of using debug tools.
  • [PyTorch] Added support for clamped SwiGLU in the TransformerLayer module.
  • [PyTorch] Added backwards compatibility for older Megatron-Core versions by introducing a keep_columnwise parameter to cast_master_weights_to_fp8 and related helper functions.
  • [PyTorch] Added a reset interface to make_graphed_callables that clears internal CUDA graphs before distributed process group cleanup, preventing hangs.
  • [PyTorch] Added support for FSDP2 with quantized weights.
  • [PyTorch] Added support for Sliding Window Attention (SWA) with Context Parallelism and the THD input format.
  • [PyTorch] Integrated Flash Attention's num_splits parameter into the attention backend.
  • [PyTorch] Made various improvements to mitigate CPU overhead, especially for the GroupedLinear module.
  • [C][PyTorch] Enabled RoPE (Rotary Position Embedding) application with position offsets during training, removing the previous restriction that start_positions could only be used with cp_size=1 (context parallelism disabled).
  • [Jax] Added options to disable Stochastic Rounding, Randomized Hadamard Transform, and 2D weight quantization in the NVFP4 training recipe.
  • [Jax] Improved performance by using Transformer Engine quantization when fused normalization or fused activation is disabled.
  • [Jax] Improved NVFP4 performance by using Transformer Engine kernels for scaling-factor swizzles.
  • [Jax] Added support for checkpointing quantization operations.
  • [Jax] Added support for sink attention.
  • [Jax] Added support for concurrent use of Data Parallelism (DP) and Fully-Sharded Data Parallelism (FSDP).
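
The snippet below is a minimal usage sketch for the first item above: a GroupedLinear forward and backward pass run under an NVFP4 training recipe. It is a sketch under assumptions, not the definitive API: the recipe class name NVFP4BlockScaling, its default constructor arguments, and whether fp8_autocast accepts an NVFP4 recipe directly are assumptions based on the usual Transformer Engine recipe pattern, and may differ in your installed version.

```python
# Minimal sketch: GroupedLinear under an NVFP4 training recipe.
# Assumption: the NVFP4 recipe is exposed as recipe.NVFP4BlockScaling;
# check transformer_engine.common.recipe in your installed version.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

grouped = te.GroupedLinear(
    num_gemms=4,
    in_features=1024,
    out_features=1024,
    bias=True,
    params_dtype=torch.bfloat16,
).cuda()

# Tokens routed to each of the 4 GEMMs (must sum to the total token count).
m_splits = [256, 256, 256, 256]
inp = torch.randn(sum(m_splits), 1024, device="cuda", dtype=torch.bfloat16)

nvfp4_recipe = recipe.NVFP4BlockScaling()  # assumed class name for the NVFP4 recipe
with te.fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
    out = grouped(inp, m_splits)
out.sum().backward()
```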

Fixed Issues

  • Fixed an occasional crash when loading the cuDNN library at runtime.
  • [C] Fixed an out-of-bounds access in the NVFP4 dequantization kernel.
  • [C] Fixed a numerical error in the amax computation in normalization kernels.
  • [PyTorch] Fixed a crash in the permute kernel when using Triton v3.5.
  • [PyTorch] Fixed a numerical issue when using gradient accumulation fusion with FSDP.
  • [PyTorch] Fixed a crash when exporting modules that use RMSNorm to ONNX.
  • [Jax] Fixed a partitioning issue for the NVFP4 training recipe with 1D Mesh.
  • [Jax] Fixed a bug where the bias parameter could be added twice when using the unfused attention backend.
  • [Jax] Fixed a sharding bug in ring attention primitives with packed sequences, where segment position tensors were not sharded to match their corresponding segment ID tensors.
  • [PyTorch][Jax] Fixed various logical issues in the backend selection process for attention.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

  • [Jax] Default value for intermediate_dropout changed from 0.1 to 0.0 (see the sketch after this list for pinning the previous defaults).
  • [Jax] Default value for return_layernorm_output changed from True to False.
  • [Jax] Default activation changed from ReLU to GeLU.
  • [Jax] Default input type for DotProductAttention changed to BSHD.
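
For code that relies on the previous behavior, the snippet below sketches how to pin the pre-2.10 defaults explicitly when constructing a layer. It is a minimal sketch under assumptions: intermediate_dropout is quoted from the note above, but the module name TransformerLayer, the mlp_activations parameter, and the remaining constructor arguments are assumptions that may differ in your Transformer Engine version; the previous DotProductAttention input type is not restored here because the old default is not stated in these notes.

```python
# Minimal sketch: explicitly pin pre-2.10 defaults after upgrading.
# Parameter names other than intermediate_dropout (quoted in the notes above)
# are assumptions; verify against transformer_engine.jax.flax in your install.
import transformer_engine.jax.flax as te_flax

layer = te_flax.TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    intermediate_dropout=0.1,       # new default is 0.0
    mlp_activations=("relu",),      # assumed parameter name; new default activation is GeLU
)
```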

Deprecated Features

No features are deprecated in this release.