v1.6

@ptrendx ptrendx released this 13 May 16:36

Release Notes – Release 1.6

Key Features and Enhancements

  • [pyTorch] Added a new make_graphed_callables API call for NVIDIA® CUDA® graph capture, including FP8 support.
  • [pyTorch] Added beta support for two boolean arguments in the DelayedScaling FP8 recipe (fp8_dpa and fp8_mha) to support FP8 attention. Note that the API exposure of this feature may change in future releases.
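
As a rough illustration of the two features above, the following sketch builds a DelayedScaling recipe and captures a single layer with make_graphed_callables. It assumes a CUDA GPU with FP8 support and an installed Transformer Engine 1.6 or later. The helper names (make_recipe, build_graphed_linear), the layer sizes, and the choice of te.Linear are illustrative assumptions, not part of the release. Imports are deferred into the functions so the snippet can be loaded on a machine without a GPU.

```python
def make_recipe(fp8_attention=False):
    """Build a DelayedScaling recipe (hypothetical helper).

    fp8_dpa is one of the two beta flags named in the release note;
    fp8_mha is the other and is left at its default here.
    """
    from transformer_engine.common.recipe import DelayedScaling, Format

    return DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=fp8_attention)


def build_graphed_linear(hidden=768, batch=32):
    """Capture one te.Linear layer in a CUDA graph with FP8 enabled.

    The sizes are arbitrary example values chosen to be multiples of 16.
    """
    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(hidden, hidden)
    # Sample input used only for warmup and graph capture.
    sample = torch.randn(batch, hidden, device="cuda", requires_grad=True)
    graphed = te.make_graphed_callables(
        layer,
        (sample,),
        fp8_enabled=True,
        fp8_recipe=make_recipe(),
    )
    return graphed
```

As with CUDA graphs in general, the returned callable is meant to be called with inputs that have the same shape as the sample used for capture.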

Fixed Issues

  • [pyTorch] Fixed a numerical issue with storing weights in FP8 via the fp8_model_init API call.
  • [pyTorch] Fixed a bug that caused PyTorch modules to store unnecessary activations for the backward pass when training with frozen weights, resulting in excessive memory usage.
  • [JAX] Fixed a bug that caused an incorrect shape to be passed for the LayerNorm gradient.

Known Issues in This Release

These issues are unchanged from the previous release.

FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue by setting the environment variable MAX_JOBS=1 during Transformer Engine installation.
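
The workaround above can also be scripted. The sketch below sets MAX_JOBS=1 in a copy of the environment and passes it to pip. The function names are hypothetical, and the package specifier is left as a parameter because the exact install source depends on your setup.

```python
import os
import subprocess
import sys


def build_install_env(base_env=None):
    """Return a copy of the environment with MAX_JOBS=1 set."""
    env = dict(os.environ if base_env is None else base_env)
    env["MAX_JOBS"] = "1"  # limit parallel build jobs to reduce memory usage
    return env


def install_with_single_job(package_spec):
    """Run `pip install <package_spec>` with MAX_JOBS=1 (hypothetical helper)."""
    cmd = [sys.executable, "-m", "pip", "install", package_spec]
    return subprocess.run(cmd, env=build_install_env(), check=True)
```

Calling install_with_single_job with the specifier you normally pass to pip has the same effect as prefixing the shell command with MAX_JOBS=1.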

[pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine's behavior consistent between versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.
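
To make the behavior change concrete: according to the linked FlashAttention note, the causal mask was aligned to the top-left corner of the attention matrix before v2.1 and is aligned to the bottom-right corner from v2.1 onward. The two only differ when the query and key sequence lengths differ, which is the cross-attention case. The following pure-Python sketch reproduces both masks. The function name and the align argument are illustrative and do not correspond to any Transformer Engine or FlashAttention API.

```python
def causal_mask(seqlen_q, seqlen_k, align="bottom_right"):
    """Return a seqlen_q x seqlen_k mask; True means query i may attend to key j.

    align="top_left" mimics the pre-2.1 behavior, "bottom_right" the 2.1+ behavior.
    """
    if align not in ("top_left", "bottom_right"):
        raise ValueError("align must be 'top_left' or 'bottom_right'")
    # Bottom-right alignment shifts the diagonal by the length difference.
    offset = seqlen_k - seqlen_q if align == "bottom_right" else 0
    return [[j <= i + offset for j in range(seqlen_k)] for i in range(seqlen_q)]


def render(mask):
    """Format a mask as rows of 1 (keep) and 0 (masked out)."""
    return "\n".join(" ".join("1" if keep else "0" for keep in row) for row in mask)


print(render(causal_mask(2, 5, align="top_left")))
# 1 0 0 0 0
# 1 1 0 0 0
print(render(causal_mask(2, 5, align="bottom_right")))
# 1 1 1 1 0
# 1 1 1 1 1
```

For self-attention, where both lengths are equal, the two alignments produce the same mask, so only cross-attention with causal masking is affected.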

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.