Release v2.9

Released by @ptrendx on 11 Nov 01:38

Key Features and Enhancements

  • [PyTorch][Jax] Introduced recipe-agnostic functions and APIs to generalize beyond FP8 recipes. See the Deprecated Features section for a comprehensive list of affected APIs.
  • [C][PyTorch][Jax] Added support for the clamped SwiGLU activation function.
  • [C] Added support for precompiled wheels for CUDA 13 via PyPI.
  • [PyTorch] Added support for custom training recipes in the autocast context. Transformer Engine quantizers, quantized tensor classes, and storage dataclasses are now part of the public API.
  • [PyTorch] Added CPU offload support for all attention layouts.
  • [PyTorch] Added support for the FP8 block scaling recipe (as used in the DeepSeek v3 Technical Report) on NVIDIA Blackwell architecture (SM100 family).
  • [PyTorch] Added support for gradient accumulation fusion when using FSDP.
  • [PyTorch] Added support for CPU offloading when using GroupedLinear with distributed optimizer.
  • [PyTorch] Exposed the following utility functions as public API: is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, is_nvfp4_available, is_bf16_available, get_cudnn_version, get_device_compute_capability, and get_default_recipe (see the sketch after this list).
  • [PyTorch] Added max_logit support for the MuonClip optimizer.
  • [PyTorch][Jax] Improved the attention-backend selection logic to handle previously unsupported cases correctly and avoid errors.
  • [Jax] Added support for the NVFP4 training recipe.
  • [Jax] Improved the performance of the current scaling recipes by enabling fused amax calculation in normalization and activation kernels.
  • [Jax] Added support for a bottom-right causal mask for THD attention.
  • Improved documentation and tutorials for the NVFP4 recipe.
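
The newly public availability utilities make it easy to check what the current device and software stack support before choosing a recipe. Below is a minimal sketch, assuming these helpers are importable from transformer_engine.pytorch and return booleans; the exact import path, signatures, and return types are assumptions, not confirmed by these notes.

```python
# Sketch: query hardware/software support before picking a training recipe.
# Assumptions: helpers live in transformer_engine.pytorch and return booleans.
import transformer_engine.pytorch as te

print("cuDNN version:", te.get_cudnn_version())
print("Device compute capability:", te.get_device_compute_capability())

if te.is_nvfp4_available():
    print("NVFP4 recipe is supported on this device")
elif te.is_mxfp8_available():
    print("MXFP8 recipe is supported on this device")
elif te.is_fp8_available() or te.is_fp8_block_scaling_available():
    print("FP8 recipes are supported on this device")
else:
    print("No low-precision recipe available; BF16 supported:", te.is_bf16_available())

# get_default_recipe() presumably returns a reasonable recipe for the device;
# treat its exact signature and return type as assumptions as well.
recipe = te.get_default_recipe()
```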

Fixed Issues

  • [Jax] Fixed a crash when using Context Parallelism with ring attention.
  • [Jax] Fixed an issue with incorrect sharding when get_all_mesh_axes is used.
  • [Jax] Fixed a numerical issue when using bias along with Tensor Parallelism.
  • [PyTorch] Fixed an integer overflow issue in the Triton permute kernel.
  • [PyTorch] Fixed the known issue from the v2.8 release that degraded performance of the FP8 current scaling recipe.
  • Fixed a build issue that occurred when cuDNN is installed in a custom location or a Python virtual environment.

Known Issues in This Release

  • [C][PyTorch] The cuDNN attention backend produces NaNs in the forward pass when a non-causal mask is used with cuDNN 9.13 or 9.14. As a workaround, set the NVTE_FUSED_ATTN environment variable to 0 for this configuration.
  • [C][PyTorch] The backward pass of cuDNN attention is incompatible with CUDA graphs for BSHD inputs whose sequence (S) dimension is not divisible by 128 when a non-padding mask is used. As a workaround, set the NVTE_FUSED_ATTN environment variable to 0 for this configuration (a sketch of this workaround follows the list).
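
Both known issues share the same workaround: disable the cuDNN fused-attention backend through the NVTE_FUSED_ATTN environment variable. A minimal sketch, assuming the variable is read when Transformer Engine first selects an attention backend, so it must be set before the affected code runs:

```python
# Workaround sketch: force Transformer Engine to skip the cuDNN fused
# attention backend for the configurations listed above.
import os

# Set this before importing/running the affected attention code, e.g. at the
# very top of the training script or in the job's launch environment.
os.environ["NVTE_FUSED_ATTN"] = "0"

import transformer_engine.pytorch as te  # imported after the variable is set
```

Equivalently, the variable can be set in the shell, e.g. `NVTE_FUSED_ATTN=0 python train.py`.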

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

  • [PyTorch] The function fp8_autocast is deprecated in favor of autocast. The new autocast function uses the arguments recipe and amax_reduction_group instead of fp8_recipe and fp8_group, respectively (a migration sketch follows this list).
  • [PyTorch] The function fp8_model_init is deprecated in favor of quantized_model_init.
  • [PyTorch] The arguments fp8_enabled, fp8_calibrating, fp8_recipe, fp8_group, and fp8_weight_caching in the function make_graphed_callables are deprecated in favor of enabled, calibrating, recipe, amax_reduction_group, and cache_quantized_params, respectively.
  • [Jax] The function fp8_autocast is deprecated in favor of autocast.
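
A minimal before/after migration sketch for the PyTorch rename, assuming the new autocast keeps the enabled argument, that both entry points are exposed on transformer_engine.pytorch, and using DelayedScaling only as an example recipe; details not stated in these notes are assumptions.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling  # example recipe

recipe = DelayedScaling()
layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(16, 1024, device="cuda")

# Before (deprecated): FP8-specific argument names.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)

# After: recipe-agnostic entry point with renamed arguments.
# Assumption: `enabled` is unchanged; only fp8_recipe/fp8_group were renamed
# to recipe/amax_reduction_group.
with te.autocast(enabled=True, recipe=recipe):
    out = layer(inp)

# Similarly, fp8_model_init(...) is superseded by quantized_model_init(...),
# and make_graphed_callables now takes enabled/calibrating/recipe/
# amax_reduction_group/cache_quantized_params instead of the fp8_* names.
```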

Miscellaneous

None