
Multihead Attention fails fp8 ONNX export #2588

@victoroliv2

Description

Describe the bug

ONNX export fails for a standard transformer encoder with FP8 precision; the failure appears to be in MultiheadAttention (MHA).

Simple repro:

import torch
from torch import nn

import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention as TF_MHA
from transformer_engine.pytorch.export import te_translation_table
from transformer_engine.common.recipe import Format, DelayedScaling

num_attention_heads = 2
hidden_dim = 128 * num_attention_heads
sequence_length = 8192
batch_dim = 1
num_layers = 3

x = torch.randn([sequence_length, batch_dim, hidden_dim], device="cuda")

class TestMHAModel(torch.nn.Module):
    def __init__(self, hidden_dim, num_attention_heads, num_layers):
        super(TestMHAModel, self).__init__()
        self.layers = nn.Sequential(
            *[TF_MHA(
                hidden_size=hidden_dim,
                num_attention_heads=num_attention_heads,
                attention_dropout=0.0,
                layer_number=i+1,
                attn_mask_type="no_mask",
                window_size=(-1,-1),
                attention_type="self",
                normalization="LayerNorm",
                seq_length=sequence_length) for i in range(num_layers)]
        )

    def forward(self, x):
        return self.layers(x)

def export(fname, x, recipe):
    with te.autocast(enabled=True, recipe=recipe):
        model = TestMHAModel(hidden_dim, num_attention_heads, num_layers).cuda().eval()

        with torch.inference_mode():
            model(x)

            with te.onnx_export(enabled=True):
                model(x)

            with te.onnx_export(enabled=True):
                torch.onnx.export(
                    model,
                    x,
                    fname,
                    output_names=["output"],
                    dynamo=True,
                    custom_translation_table=te_translation_table
                )

# This works, but it doesn't insert the expected Q/DQ operators, so MHA runs in fp32 precision with TensorRT
recipe = DelayedScaling(fp8_mha=False, fp8_dpa=False)
export("mha1.onnx", x, recipe)

# Fails in FP8EmulationFunc.apply
recipe = DelayedScaling(fp8_mha=True, fp8_dpa=True)
export("mha2.onnx", x, recipe)

I had to run the script with NVTE_UnfusedDPA_Emulate_FP8=1 python export_mha_bug.py; otherwise I get:
"No dot product attention backend is available for the provided inputs"

It fails with the following stack trace:

    query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
                                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 582, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/backends.py", line 171, in forward
    q_fp8, k_fp8, v_fp8 = combine_and_quantize(
                          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/utils.py", line 2190, in combine_and_quantize
    qkv_fp8 = qkv_quantizer(qkv)
              ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/quantized_tensor.py", line 262, in __call__
    return self.quantize(tensor)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/quantized_tensor.py", line 245, in quantize
    return _QuantizeFunc.forward(None, tensor, self.quantize_impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/_quantization_helpers.py", line 29, in forward
    return quantize_impl(tensor)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 109, in quantize_impl
    return tex.quantize(tensor, self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
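
For reference, the RuntimeError above suggests wrapping the offending kernel as an opaque custom op so the exporter's FakeTensor tracing never reaches a real data pointer. A minimal sketch of that pattern (not TE code; the op name my_ext::quantize_fp8 and the placeholder scaling math are made up for illustration):

import torch

@torch.library.custom_op("my_ext::quantize_fp8", mutates_args=())
def quantize_fp8(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real CUDA quantization kernel; torch.export does not
    # trace into this body because the op is registered as opaque.
    scale = x.abs().amax().clamp(min=1e-12) / 448.0
    return (x / scale).to(torch.float8_e4m3fn)

@quantize_fp8.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during FakeTensor tracing, so the
    # exporter never needs a real data pointer.
    return torch.empty_like(x, dtype=torch.float8_e4m3fn)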

Expected behavior

An ONNX file with Q/DQ operators around the matmuls, so that TensorRT can fuse MHA and run it in FP8 precision.
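
To verify whether Q/DQ nodes actually made it into an export, a hypothetical check (assuming the onnx Python package is available in the container) is to count the quantize/dequantize node types in the exported graph:

import onnx
from collections import Counter

model = onnx.load("mha1.onnx")
ops = Counter(node.op_type for node in model.graph.node)
# An FP8 export should contain quantize/dequantize pairs around the matmuls.
print({op: n for op, n in ops.items() if "Quantize" in op or "Dequantize" in op})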

Environment overview

docker run --gpus all -it --rm \
    -v $(pwd)/mount:/mount \
    nvcr.io/nvidia/pytorch:25.11-py3

I also tried the latest TransformerEngine, built with pip install inside the container above.

