Releases: NVIDIA/TransformerEngine
v2.10
Release v2.10
Key Features and Enhancements
- [PyTorch] Added support for the NVFP4 training recipe for the GroupedLinear module.
- [PyTorch] Added support for CUDA graphs when using quantized weights with Tensor Parallelism.
- [PyTorch] Added support for CUDA graphs when using delay_wgrad_compute.
- [PyTorch] Expanded debug tools to support more statistics.
- [PyTorch] Reduced the overhead of using debug tools.
- [PyTorch] Added support for clamped SwiGLU in the TransformerLayer module.
- [PyTorch] Added backwards compatibility for older Megatron-Core versions by introducing a keep_columnwise parameter to cast_master_weights_to_fp8 and related helper functions.
- [PyTorch] Added a reset interface to make_graphed_callables that clears internal CUDA graphs before distributed process group cleanup, preventing hangs.
- [PyTorch] Added support for FSDP2 with quantized weights.
- [PyTorch] Added support for Sliding Window Attention (SWA) with Context Parallelism with THD input format.
- [PyTorch] Integrated Flash Attention's num_splits parameter into the attention backend.
- [PyTorch] Made various improvements to mitigate CPU overhead, especially for the GroupedLinear module.
- [C][PyTorch] Enabled RoPE (Rotary Position Embedding) application with position offsets during training, removing the previous restriction that start_positions could only be used with cp_size=1 (context parallelism disabled).
- [Jax] Added options to disable Stochastic Rounding, Randomized Hadamard Transform, and 2D weight quantization in the NVFP4 training recipe.
- [Jax] Improved performance by using Transformer Engine quantization when fused normalization or fused activation are disabled.
- [Jax] Performance Improvement for NVFP4 via TE kernels for scaling factor swizzles.
- [Jax] Added support for checkpointing quantization operations in JAX.
- [Jax] Added support for sink attention.
- [Jax] Added support for concurrent use of Data Parallelism (DP) and Fully-Sharded Data Parallelism (FSDP).
Fixed Issues
- Fixed an occasional crash when loading cuDNN library during runtime.
- [C] Fixed an out of bounds access in the NVFP4 dequantization kernel.
- [C] Fixed a numerical error in the amax computation in normalization kernels.
- [PyTorch] Fixed a crash in the permute kernel when using triton v3.5.
- [PyTorch] Fixed a numerical issue when using gradient accumulation fusion with FSDP.
- [PyTorch] Fixed a crash when exporting modules via ONNX when using RMSNorm.
- [Jax] Fixed a partitioning issue for the NVFP4 training recipe with 1D Mesh.
- [Jax] Fixed a bug where the bias parameter could be added twice when using unfused attention backend.
- [Jax] Fixed a sharding bug in ring attention primitives when using packed sequences where segment position tensors were not properly sharded to match their corresponding segment ID tensors.
- [PyTorch][Jax] Fixed various logical issues in the backend selection process for attention.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
- [Jax] Default value for intermediate_dropout changed from 0.1 to 0.0.
- [Jax] Default value for return_layernorm_output changed from True to False.
- [Jax] Default activation changed from ReLU to GeLU.
- [Jax] Default input type for DotProductAttention changed to BSHD. (A sketch for pinning the previous defaults follows this list.)
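For users who want to keep the previous behavior, the old defaults can be passed explicitly. A minimal sketch, assuming the te.jax.flax module and parameter names shown below; only intermediate_dropout, return_layernorm_output, and the old activation default come from the list above, and the remaining names are assumptions that may differ between Transformer Engine versions:

```python
# Hedged sketch: pin the pre-2.10 JAX defaults explicitly.
# Parameter names other than intermediate_dropout and return_layernorm_output
# are assumptions; check the signatures in your Transformer Engine version.
import transformer_engine.jax.flax as te_flax

layer = te_flax.TransformerLayer(
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    intermediate_dropout=0.1,        # new default is 0.0
    mlp_activations=("relu",),       # assumed name; new default activation is GeLU
)

# return_layernorm_output applies to the lower-level LayerNorm* modules:
mlp = te_flax.LayerNormMLP(
    intermediate_dim=4096,           # assumed name
    return_layernorm_output=True,    # new default is False
)
```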
Deprecated Features
No features are deprecated in this release.
v2.9
Release v2.9
Key Features and Enhancements
- [PyTorch][Jax] Introduced recipe agnostic functions and APIs in order to generalize to non-FP8 recipes. See Deprecated Features section for a comprehensive list of affected APIs.
- [C][PyTorch][Jax] Added support for the clamped SwiGLU activation function.
- [C] Added support for precompiled wheels for cuda13 via PyPI.
- [PyTorch] Added support for custom training recipes in the autocast context. Transformer Engine quantizers, quantized tensor classes, and storage dataclasses are now part of the public API.
- [PyTorch] Added CPU offload support for all attention layouts.
- [PyTorch] Added support for the FP8 block scaling recipe (as used in the DeepSeek v3 Technical Report) on NVIDIA Blackwell architecture (SM100 family).
- [PyTorch] Added support for gradient accumulation fusion when using FSDP.
- [PyTorch] Added support for CPU offloading when using GroupedLinear with the distributed optimizer.
- [PyTorch] Exposed the following utility functions as public API: is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, is_nvfp4_available, is_bf16_available, get_cudnn_version, get_device_compute_capability, and get_default_recipe.
- [PyTorch] Added max_logit support for the MuonClip optimizer.
- [PyTorch][Jax] Improved the logic for selecting the attention backend, addressing various unsupported cases and preventing errors.
- [Jax] Added support for the NVFP4 training recipe.
- [Jax] Improved the performance of the current scaling recipes by enabling fused amax calculation in normalization and activation kernels.
- [Jax] Added support for bottom right causal mask for THD attention.
- Improved documentation and tutorials for the NVFP4 recipe.
Fixed Issues
- [Jax] Fixed a crash when using Context Parallelism with ring attention.
- [Jax] Fixed an issue with incorrect sharding when get_all_mesh_axes is used.
- [Jax] Fixed a numerical issue when using bias along with Tensor Parallelism.
- [PyTorch] Fixed an integer overflow issue in the triton permute kernel.
- [PyTorch] Fixed the known issue from release v2.8 which resulted in worse performance for the FP8 current scaling recipe.
- Fixed a build issue when cuDNN is installed into a custom location or Python virtual environment.
Known Issues in This Release
- [C][PyTorch] The cuDNN attention backend produces NaNs in the forward pass for cases using a non-causal mask with cuDNN 9.13 and cuDNN 9.14. As a workaround, please set the NVTE_FUSED_ATTN environment variable to 0 when using this configuration.
- [C][PyTorch] The backward pass of cuDNN attention is incompatible with CUDA graphs for BSHD inputs where the sequence (S) dimension is not divisible by 128 when used with a non-padding mask. As a workaround, please set the NVTE_FUSED_ATTN environment variable to 0 when using this configuration. (A Python sketch of this workaround follows this list.)
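A minimal Python sketch of the workaround above; exporting NVTE_FUSED_ATTN=0 in the shell before launching is equivalent, and the module construction is illustrative only:

```python
# Work around the cuDNN 9.13/9.14 attention issues described above by disabling
# the fused (cuDNN) attention backend. Set the variable before attention runs,
# ideally before importing Transformer Engine.
import os
os.environ["NVTE_FUSED_ATTN"] = "0"

import transformer_engine.pytorch as te

# Illustrative: any TE module that uses dot-product attention is affected the same way.
layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
```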
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
- [PyTorch] The function fp8_autocast is deprecated in favor of autocast. The new autocast function uses arguments recipe and amax_reduction_group instead of fp8_recipe and fp8_group respectively. (A migration sketch follows this list.)
- [PyTorch] The function fp8_model_init is deprecated in favor of quantized_model_init.
- [PyTorch] The arguments fp8_enabled, fp8_calibrating, fp8_recipe, fp8_group, and fp8_weight_caching in the function make_graphed_callables are deprecated in favor of enabled, calibrating, recipe, amax_reduction_group, and cache_quantized_params respectively.
- [Jax] The function fp8_autocast is deprecated in favor of autocast.
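A hedged migration sketch for the PyTorch deprecations above; the recipe and module are illustrative, and argument names not mentioned above (such as enabled) are assumed unchanged:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling()      # illustrative recipe
model = te.Linear(1024, 1024)
x = torch.randn(32, 1024, device="cuda")

# Deprecated spelling:
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = model(x)

# New spelling: fp8_recipe -> recipe, fp8_group -> amax_reduction_group.
with te.autocast(enabled=True, recipe=fp8_recipe):
    y = model(x)
```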
Miscellaneous
There are no miscellaneous issues in this release.
v2.8
Release Notes – Release 2.8
Key Features and Enhancements
- [C][PyTorch] Added support for the NVFP4 training recipe.
- [C][PyTorch] Added support for FP8 attention with the current scaling recipe.
- [PyTorch] Added support for mixing recipes for different modules when using the make_graphed_callables function.
- [C] Added 8-bit RNG support to the dropout kernel.
- [C][PyTorch] Added the nvte_rmsnorm_bwd_add function to the C API and added support for fusing RMSNorm and add operation in the sequential Transformer Engine operations API.
- [C] Added more robust error checking and handling when calling CUDA and driver APIs.
- [C][PyTorch] Added support for using FP8 and non-FP8 quantization modes in the same model when overlapping tensor parallel communication and GEMM using userbuffers.
- [PyTorch] Added support for the qgeglu and sreglu activations in the Transformer Engine fused operations API and the LayerNormMLP module.
- [C][PyTorch] Added support for FP8 GEMM output for the MXFP8 and current scaling recipe.
- [PyTorch] Added support for FP8 all-gather when using Tensor Parallel with the GroupedLinear module.
- [PyTorch] Added support for the current scaling FP8 recipe for module export via ONNX.
- [PyTorch] Made miscellaneous improvements to MoE workloads to reduce CPU overhead.
- [PyTorch] Improved performance of CUDA graphs using FP8 weight cache in quantization kernels.
- [PyTorch] Added support for FlashAttention v3 for MLA with context parallelism.
- [PyTorch] Added support for activation CPU offloading for the Transformer Engine sequential operations API.
- [PyTorch] Made miscellaneous performance improvements when using RoPE (rotary positional embeddings).
- [C] Added support for BF16 and FP32 inputs to the kernel that calculates auxiliary loss for MoE.
- [C] Added support for sink attention from cuDNN.
- [Jax] Fused swizzling operation for the scaling factor inverse and transpose calculation of the data.
Fixed Issues
- [Jax] Fixed a crash when the user calls global_shard_guard before setting the JAX mesh.
- [Jax] Fixed an issue in the mesh logic such that if an axis is undefined in the mesh, Transformer Engine still applies the sharding constraint for the given tensor on other axes instead of skipping.
- [Jax] Fixed a crash in GroupedScaledTensor due to incorrect arguments being passed.
- [PyTorch] Fixed a bug in the cross entropy loss kernel that resulted in vanishing gradients.
- [C][PyTorch] Fixed incorrect calculation of tensor parallel rank when using userbuffers.
- [PyTorch] Fixed redundant memory overheads when using FP8 all-gather with sequence parallelism.
Known Issues in This Release
- [PyTorch] For distributed workloads using the Float8CurrentScaling recipe without FP8 attention, there are some performance overheads due to redundant amax reductions across the tensor parallel and context parallel groups.
This issue has been fixed (#2234), and will be available in the next release (v2.9).
As a workaround, you can run the workload with export NVTE_DPA_FP8_RECIPE="F16" in the environment (an equivalent Python snippet follows).
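The same workaround from Python, for workloads where editing the launch environment is inconvenient; the variable name and value are taken from the note above:

```python
# Equivalent to `export NVTE_DPA_FP8_RECIPE="F16"`: keep dot-product attention in
# high precision to avoid the redundant amax reductions described above.
# Set before Transformer Engine attention is first used.
import os
os.environ["NVTE_DPA_FP8_RECIPE"] = "F16"
```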
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
There are no deprecated features in this release.
v2.7
Release v2.7
Key Features and Enhancements
- [PyTorch] Added support for applying LayerNorm and RMSNorm to key and query tensors.
- [PyTorch] Improved performance for FP8 per tensor current scaling recipe by fusing amax computation into activation kernel.
- [PyTorch] Added support for multi-tensor swizzle kernels for MXFP8 grouped GEMMs.
- [PyTorch] Fused zero-padding and swizzle operation for MXFP8 scale inverses for improved performance.
- [PyTorch] Expanded the debug API using nvdlfw-inspect in order to log more advanced tensor statistics.
- [PyTorch] Reduced the number of calls to the CUDA driver for improved performance of the core library.
- [Jax] Added new checkpointing policies that allow users to switch to TE GEMMs seamlessly without unnecessary recomputations.
- [Core] Added support for cublasMP backend for overlapping TP communication and GEMM.
Fixed Issues
- [PyTorch] Fixed a potential illegal memory access when using TP overlap.
- [PyTorch] Fixed the logic for choosing the correct attention backend depending on the cuDNN version.
- [PyTorch] Fixed a crash when using CUDA graphs by disabling garbage collection during capture.
- [PyTorch] Fixed a bug when using double buffering for CPU offloading.
- [PyTorch] Fixed a bug when overlapping gradient reduction and fusing weight gradient accumulation simultaneously.
- [PyTorch] Made multiple improvements and fixes to TE sequential API, including expanding supported operations to cover dropout, constant scale, etc.
- [PyTorch] Fixed a bug in the make_graphed_callables function when applied to multiple modules with different input requirements.
- [PyTorch] Fixed a crash in the permute operation when running with the FP8 datatype for input sizes requiring padding.
- [PyTorch] Fixed a bug when using the Triton cross entropy kernel with CUDA graphs.
- [PyTorch] Fixed a bug when exporting an MXFP8 model to ONNX.
- [PyTorch/Core] Disabled the cuDNN attention backend for cuDNN v9.12 onwards on Blackwell if the user requests a deterministic config.
- [Core] Fixed integer overflow in quantization kernels when computing offsets for large tensors.
- [Jax] Fixed partition rules for GEMM to correctly handle sequence parallelism.
- [Jax] Fixed sharding specs for TE GEMM custom call operands when using DP.
- [Jax] Fixed a crash when using GroupedQuantizeFFI with CUDA graphs.
- [Jax] Fixed the fused_attn sharding constraint so that it can be used under the JAX shard_map.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
- The deprecated device_id argument for multi-tensor C APIs has been removed.
Deprecated Features
There are no deprecated features in this release.
v2.6
Release Notes – Release 2.6
Key Features and Enhancements
- [PyTorch] Added support for gradient accumulation fusion when using FSDP from megatron-core.
- [PyTorch] Optimized memory usage when using NVIDIA® CUDA® graphs with TE using the make_graphed_callables function.
- [PyTorch] Optimized performance of permute fusion kernels for MoE.
- [PyTorch] Added support for ONNX export of Transformer Engine modules.
- [PyTorch] Added a save_original_input option to the Linear and GroupedLinear modules to decouple row-wise (forward) and column-wise (backward) quantization. This option saves memory for certain workloads and training recipes (see the sketch after this list).
- [PyTorch] Improved performance of MXFP8 quantization kernels.
- [Core] Improved performance of KV caching kernels.
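A minimal sketch of the save_original_input option noted above; the option name comes from the note, while the module sizes are illustrative:

```python
import transformer_engine.pytorch as te

# Per the note above, save_original_input decouples row-wise (forward) and
# column-wise (backward) quantization, which can save memory for certain
# workloads and training recipes. The same option exists on GroupedLinear.
linear = te.Linear(1024, 4096, save_original_input=True)
```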
Fixed Issues
- [PyTorch] Fixed an issue in the LayerNormLinear module where the returned normalization output was of different shape than the input tensor.
- [PyTorch] Fixed an issue with the align_size calculation in FP8 padding/unpadding modules.
- [PyTorch] Made miscellaneous fixes and enhancements to the fusible ops (te.sequential) API.
- [PyTorch] Reduced CPU overhead in various workloads: DelayedScaling recipe, MXFP8 MoE, and pipeline parallelism.
- [PyTorch] Fixed a bug in the multi-tensor adam kernel that incorrectly downcast an FP32 tensor to BF16.
- [PyTorch] Fixed an issue with caching FP8 weights when running validation steps between training steps.
- [PyTorch] Fixed a logical error that could lead to using a suboptimal attention backend when a better-performing backend is available.
- [PyTorch] Fixed miscellaneous errors during runtime loading of shared libraries by expanding search paths.
- [PyTorch] Fixed a use-after-free in cases where quantization and normalization are unfused.
- [Jax] Fixed a crash with grouped GEMM in CUDA version ≥ 12.9.1.
- [Jax] Fixed build with JAX v0.7.0 that failed due to removal of jax.extend.ffi.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
There are no deprecated features in this release.
Miscellaneous
There are no miscellaneous issues in this release.
v2.5
Release Notes – Release 2.5
Key Features and Enhancements
- Added support for Python 3.12+.
- Added support for head dimension (head_dim) > 128 for attention for all architectures.
- [Jax] Added support for sliding window attention (SWA) in context parallel ring attention using THD format and striped sharding.
- [Jax] Improved performance for per-tensor scaling FP8 recipe.
- [Jax] Added MXFP8 support for the GroupedDense module and handled the case of zero input tokens.
- [PyTorch] Enabled FP8 tensor-parallel communication for FP8 block scaling recipe for Hopper by supporting coalesced gather of FP8 quantized tensors.
- [PyTorch] Optimized the MXFP8 Userbuffers implementation by overlapping the wgrad NCCL all-gather with the dgrad GEMM.
- [PyTorch] Added support for CPU offloading when using FP8 parameters.
- [PyTorch] Added support for Context Parallel for Multi Latent Attention (MLA).
- [PyTorch] Reduced CPU overhead in MoE.
- [C][PyTorch] Improved performance for FP8 padding and unpadding kernels for MoE.
- [PyTorch] Added support for FP8 current scaling in operation-based API.
Fixed Issues
- [Jax] Fixed a numerical error in the scaled masked softmax kernel.
- [Jax] Fixed output dtype for FP8 GEMM.
- [PyTorch] Fixed a bug that appeared when the FP8 recipe is changed in between training steps.
- [PyTorch] Made miscellaneous fixes in TransformerLayer: passed the missing arguments cu_seqlens and max_seqlen to cross-attention and allowed attn_input_format=thd.
- [PyTorch] Fixed a crash when loading checkpoints generated by previous Transformer Engine versions.
- [PyTorch] Made miscellaneous fixes in CPU offloading logic.
- [PyTorch] Fixed a numerical issue in cross-entropy loss.
- [C][PyTorch][Jax] Fixed source installation when using NVTE_FRAMEWORK=all.
- [PyTorch] Fixed a crash in GroupedLinear when using CUDA graphs.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
There are no deprecated features in this release.
Miscellaneous
There are no miscellaneous issues in this release.
v2.4
Release Notes – Release 2.4
Key Features and Enhancements
- [Jax] Added support for Float8CurrentScaling recipe.
- [Jax] Added support for logical partitioning axes in TE Flax modules.
- [Core] Added multiple experimental functions to the C API.
- [PyTorch] Improved performance by caching device properties.
- [PyTorch] Made miscellaneous minor improvements to reduce memory consumption for certain workloads.
- [PyTorch] Added support for MXFP8 recipe when using userbuffers for overlapping TP communication and GEMMs.
- [PyTorch] Reduced the binary size of the framework extension library from 108 MB to 2 MB.
- [PyTorch] Introduced a Boolean parameter, rotary_pos_interleaved, in the MultiheadAttention and TransformerLayer modules for interleaved RoPE (see the sketch after this list).
- [PyTorch] Added support for ignoring tokens in the cross-entropy loss function.
- [PyTorch] Added support for switching among all supported FP8 recipes during training and checkpointing.
- [PyTorch] Added various debugging tools via NVIDIA-DL-Framework-Inspect.
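A minimal sketch of the rotary_pos_interleaved flag noted above; the flag name comes from the note, and the other constructor arguments and sizes are illustrative:

```python
import transformer_engine.pytorch as te

# rotary_pos_interleaved is the new Boolean flag for interleaved RoPE; per the
# note above, it is also accepted by MultiheadAttention.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    rotary_pos_interleaved=True,
)
```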
Fixed Issues
- [PyTorch] Fixed a numerical issue when using activation recompute with FP8.
- [PyTorch] Fixed incorrect output dimensions when using return_layernorm_output in the LayerNormLinear and LayerNormMLP modules.
- [PyTorch] Fixed a numerical bug when using sequence parallelism in the LayerNorm and RMSNorm modules with Megatron-LM.
- [PyTorch/Jax] Fixed miscellaneous crashes at import time due to library loading.
- [Jax] Fixed a crash due to partitioning error when using the LayerNorm or LayerNormMLP module with tensor parallelism.
- [PyTorch] Fixed an issue where the GIL was held during the entirety of C API calls from the framework extensions, including during NVIDIA® CUDA® kernel execution.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
There are no deprecated features in this release.
Miscellaneous
There are no miscellaneous issues in this release.
v2.3
Release Notes – Release 2.3
Key Features and Enhancements
- [PyTorch] Sped up import of transformer_engine module by moving to a lazy compilation of functions using torch.compile.
- [PyTorch] Enabled FP8 weights when using FSDP.
- [C][PyTorch] Added support for Float8 block scaling recipe, as used in the Deepseek v3 paper, for Hopper GPUs.
- [PyTorch] Made miscellaneous fixes to reduce CPU overhead.
- [PyTorch] Added support for CPU offloading for activation tensors when using FP8 attention.
- [PyTorch] Enabled MXFP8 recipe for the GroupedLinear module.
- [PyTorch] Added a feature to support decoupling the weight gradient compute from the backward function of Transformer Engine modules. This allows users to invoke the wgrad backward separately and gives them finer-grained control over when weight gradients are computed, to support certain advanced parallelism/overlap schemes.
- [PyTorch] Added support for staggered application of RoPE embeddings to a sequence of inputs in a batch, depending on their starting positions.
- [All] Added support for RTX 5090.
Fixed Issues
- [PyTorch] Fixed a numerical bug with use of custom DDP from megatron-core.
- [PyTorch] Fixed a crash when using the checkpoint method for activation recompute on non-Transformer Engine modules.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
- [Jax] Praxis layers have been removed, as PAXML is no longer supported.
Deprecated Features
- Installation of Transformer Engine now requires the --no-build-isolation flag when using the PyPI package or building from source. Support for installations with build isolation will be removed in a future release.
- [PyTorch] CPU offloading weight tensors is deprecated.
v2.2
Release Notes – Release 2.2
Key Features and Enhancements
- [PyTorch] Added support for per-tensor current scaling recipe.
- [PyTorch] Implemented cross-entropy loss with support for splitting computation across multiple devices.
- [PyTorch] Added support for CPU offloading with Megatron-Core style distributed optimizers.
- [PyTorch] Added support for KV cache for FusedAttention, FlashAttention, and UnfusedDotProductAttention backends.
- [PyTorch] Improved bulk TP communication overlap by launching GEMMs on lower priority streams.
- [C/PyTorch] Improved performance for P2P-based Tensor Parallel (TP) communication overlap.
- [Jax] Added support for THD format with ring attention.
- [Jax] Improved performance and memory usage for causal mask in the cuDNN attention backend.
- [C] Added multi-node support for NVIDIA® NVLink for TP overlap with userbuffers.
Fixed Issues
- [PyTorch] Fixed convergence when using context parallelism with a fused attention backend.
- [PyTorch] Fixed a crash using GroupedLinear when the last input has no tokens.
- [PyTorch] Made miscellaneous fixes to improve overall performance of the MXFP8 recipe.
- [PyTorch] Reintroduced support for return_bias argument to all modules, which was silently ignored in v2.0 and v2.1.
- [PyTorch] Reintroduced support for FP8 communication for overlapping reduce-scatter and GEMM when using TP overlap with userbuffers.
- [PyTorch] Fixed gradient accumulation fusion in the LayerNormMLP module.
- [C/PyTorch] Made miscellaneous numerical fixes to the fused attention backend.
- [C] Stopped creating a new cublasLtHandle for every GEMM call, avoiding memory leaks.
- [Jax] Fixed shape and sharding inference in fused-attention C++ extension.
- [Jax] Fixed an import error in the encoder example.
Known Issues in This Release
- RTX 5090 is currently unsupported for FP8 execution. Support will be added in v2.3.0.
- Transformer Engine may crash when it is installed via the PyPI registry but is run in an environment with CUDA version < 12.8. A temporary workaround is to install from source until the issue is fixed.
Breaking Changes in This Release
- [PyTorch] The deprecated interval argument for the DelayedScaling recipe has been removed.
- [PyTorch] There are multiple breaking changes in the InferenceParams class (see the sketch after this list).
- New arguments num_heads_kv, head_dim_k, and dtype are required during initialization.
- The user must call a pre_step method to update the InferenceParams state.
- The swap_key_value_dict method has been removed, as the step method now automatically reorders the key/value sequences according to their batch indices.
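A hedged sketch of the updated InferenceParams flow; only num_heads_kv, head_dim_k, dtype, pre_step, and the removal of swap_key_value_dict come from the notes above, and the remaining argument names are assumptions:

```python
import torch
import transformer_engine.pytorch as te

# The three new required arguments come from the notes above; the other argument
# names are assumed from the pre-2.2 constructor and may differ in your version.
inference_params = te.InferenceParams(
    max_batch_size=8,
    max_sequence_length=2048,
    num_heads_kv=16,
    head_dim_k=64,
    dtype=torch.bfloat16,
)

# pre_step must now be called before each generation step to update the cache
# state (its exact arguments are not shown here); swap_key_value_dict is gone,
# since the step method reorders key/value sequences by batch index automatically.
```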
Deprecated Features
There are no deprecated features in this release.
Miscellaneous
- [PyTorch] The minimum required PyTorch version is now 2.1.
v2.1
Release Notes – Release 2.1
Key Features and Enhancements
- [PyTorch] Made the API for fused optimizers (Adam and SGD) consistent with the PyTorch equivalents.
- [PyTorch] Implemented probability permutation and mask-based permutation in MoE.
- [PyTorch] Added the store_param_remainders argument for TE optimizers to save memory when storing FP32 master weights for BF16 model weights.
- [Jax] Added support for THD attention input format for the flax modules.
Fixed Issues
- [PyPI] Fixed an issue when TE is installed from PyPI in an environment where TE has already been installed from source. The wheel installation was incorrect, resulting in an application crash at runtime.
- [PyTorch] Fixed an issue with QuantizedTensor types when executing operations such as chunk or split, which have different shapes for input and output.
- [PyTorch] Made miscellaneous fixes to the attention backend for execution on Blackwell GPUs.
- [PyTorch] Fixed a crash when using Context Parallelism with FP8 weights.
- [PyTorch] Fixed a crash when using fused gradient accumulation with grouped GEMMs (MoE).
- [Jax/Flax] Changed flax modules to use dtype to initialize their parameters while inferring compute type from the input data type.
Known Issues in This Release
- [PyTorch] The return_bias option in LayerNormLinear and LayerNormMLP, used internally in TransformerLayer, is silently ignored in this release, resulting in incorrect results. This issue was resolved in #1569 and the fix will be part of the 2.2 release.
Breaking Changes in This Release
There are no breaking changes in this release.
Deprecated Features
- [Jax] The fused_attn_thd API call is deprecated in favor of fused_attn, which supports THD format.
- [Jax] The mask positional argument is deprecated in favor of sequence_descriptor.