**Describe the bug**
ONNX export fails for a standard transformer encoder with FP8 precision. The problem appears to be in MultiheadAttention (MHA).
Simple repro:
```python
import torch
from torch import nn
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention as TF_MHA
from transformer_engine.pytorch.export import te_translation_table
from transformer_engine.common.recipe import Format, DelayedScaling

num_attention_heads = 2
hidden_dim = 128 * num_attention_heads
sequence_length = 8192
batch_dim = 1
num_layers = 3

x = torch.randn([sequence_length, batch_dim, hidden_dim], device="cuda")


class TestMHAModel(torch.nn.Module):
    def __init__(self, hidden_dim, num_attention_heads, num_layers):
        super(TestMHAModel, self).__init__()
        self.layers = nn.Sequential(
            *[TF_MHA(
                hidden_size=hidden_dim,
                num_attention_heads=num_attention_heads,
                attention_dropout=0.0,
                layer_number=i + 1,
                attn_mask_type="no_mask",
                window_size=(-1, -1),
                attention_type="self",
                normalization="LayerNorm",
                seq_length=sequence_length) for i in range(num_layers)]
        )

    def forward(self, x):
        return self.layers(x)


def export(fname, x, recipe):
    with te.autocast(enabled=True, recipe=recipe):
        model = TestMHAModel(hidden_dim, num_attention_heads, num_layers).cuda().eval()
        # Warm-up forward pass to initialize the FP8 scaling state.
        with torch.inference_mode():
            model(x)
        # Forward pass in ONNX-export mode, then the actual export.
        with te.onnx_export(enabled=True):
            model(x)
        with te.onnx_export(enabled=True):
            torch.onnx.export(
                model,
                x,
                fname,
                output_names=["output"],
                dynamo=True,
                custom_translation_table=te_translation_table,
            )


# This works, but doesn't introduce the right Q/DQ operators, so MHA runs in
# fp32 precision with TensorRT.
recipe = DelayedScaling(fp8_mha=False, fp8_dpa=False)
export("mha1.onnx", x, recipe)

# Fails in FP8EmulationFunc.apply.
recipe = DelayedScaling(fp8_mha=True, fp8_dpa=True)
export("mha2.onnx", x, recipe)
```

I had to run these with `NVTE_UnfusedDPA_Emulate_FP8=1 python export_mha_bug.py`, otherwise I get:
"No dot product attention backend is available for the provided inputs"
The second export fails with the following stack trace:
```
    query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
                                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 582, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/backends.py", line 171, in forward
    q_fp8, k_fp8, v_fp8 = combine_and_quantize(
                          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/utils.py", line 2190, in combine_and_quantize
    qkv_fp8 = qkv_quantizer(qkv)
              ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/quantized_tensor.py", line 262, in __call__
    return self.quantize(tensor)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/quantized_tensor.py", line 245, in quantize
    return _QuantizeFunc.forward(None, tensor, self.quantize_impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/_quantization_helpers.py", line 29, in forward
    return quantize_impl(tensor)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 109, in quantize_impl
    return tex.quantize(tensor, self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
```
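
For context, the error message itself names the fix pattern: the FP8 quantize kernel would need to be wrapped in an opaque custom op so the dynamo exporter sees a fake (shape-only) implementation instead of tracing into code that dereferences a real data pointer. A minimal sketch of that pattern, with a hypothetical `te_repro::quantize_fp8` op (illustrative only, not TE's actual code):

```python
import torch

@torch.library.custom_op("te_repro::quantize_fp8", mutates_args=())
def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Eager path: the real quantize kernel (e.g. tex.quantize) would run here.
    # This stand-in just round-trips through fp8; hypothetical, for illustration.
    return (tensor * scale).to(torch.float8_e4m3fn).to(tensor.dtype)

@quantize_fp8.register_fake
def _(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Export/compile path: shape and dtype propagation only, so FakeTensors
    # never reach the kernel's data-pointer access.
    return torch.empty_like(tensor)
```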
**Expected behavior**
An ONNX file with Q/DQ operators around the matmuls, so that TensorRT can fuse MHA and run it in fp8 precision.
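
To check whether an export actually contains that pattern, a quick inspection of the graph for QuantizeLinear/DequantizeLinear nodes (the ONNX Q/DQ ops TensorRT matches) works; `mha1.onnx` is the file produced by the repro above:

```python
# Count the Q/DQ nodes in the exported graph; without QuantizeLinear /
# DequantizeLinear pairs around the matmuls, TensorRT falls back to fp32.
import onnx

model = onnx.load("mha1.onnx")
qdq = [n.op_type for n in model.graph.node
       if n.op_type in ("QuantizeLinear", "DequantizeLinear")]
print(f"{len(qdq)} Q/DQ nodes in the graph")
```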
**Environment overview (please complete the following information)**
```bash
docker run --gpus all -it --rm \
    -v $(pwd)/mount:/mount \
    nvcr.io/nvidia/pytorch:25.11-py3
```
I also tried pulling the latest TransformerEngine (built with pip install inside the container above).