
ONNX export failure with mtq.FP8_DEFAULT_CFG and mtq.NVFP4_DEFAULT_CFG #614

@Hyubo

Description


Describe the bug

I encountered a failure when exporting a quantized model (using NVFP4_DEFAULT_CFG) to ONNX.
The quantization and training steps run normally, but torch.onnx.export crashes with a long stack trace containing shape-inference warnings and an internal error triggered by tensorrt::dynamic_block_quantize_op and DynamicBlockQuantizationFunction. The same failure occurs with mtq.FP8_DEFAULT_CFG.

The issue seems related to exporting dynamic FP4 quantization ops in a simple Conv2d + Linear model. The same model exports correctly before quantization.

This blocks ONNX export when using ModelOpt NVFP4 quantization.

Steps/Code to reproduce bug

The issue can be reproduced with the following minimal, self-contained script, which can be run directly:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg

class TwoLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv2d with a 1×3 kernel, i.e. a convolution along the length dimension
        self.conv = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=(1, 3),
            padding=(0, 1)  # keep the output width at 128
        )

        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        # original input: (B, 128)
        x = x.unsqueeze(1).unsqueeze(2)   # (B,1,1,128)
        x = torch.relu(self.conv(x))      # (B,1,1,128)
        x = x.squeeze(2).squeeze(1)       # (B,128)

        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x



def get_model():
    return TwoLinearModel()


def get_dataloader(num_samples=100):
    X = torch.randn(num_samples, 128)
    y = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(X, y)
    return DataLoader(dataset, batch_size=16, shuffle=False)


# -----------------------------------------------------
# Quantization + training + ONNX export
# -----------------------------------------------------
device = torch.device("cuda")

model = get_model().to(device)

config = mtq.NVFP4_DEFAULT_CFG
# config = mtq.FP8_DEFAULT_CFG

calib_size = 100
data_loader = get_dataloader(calib_size)


def forward_loop(model):
    for batch in data_loader:
        x, _ = batch
        x = x.to(device)
        model(x)


model = mtq.quantize(model, config, forward_loop)

mtq.print_quant_summary(model)


# Train for a few steps (model and data on the GPU)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

train_loader = get_dataloader(200)

model.train()
for step, (x, y) in enumerate(train_loader):
    if step >= 5:
        break

    x = x.to(device)
    y = y.to(device)

    optimizer.zero_grad()
    out = model(x)
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()

    print(f"Training step {step}, loss = {loss.item():.4f}")


# Export to ONNX (the sample input must also be on the GPU)
sample_input = torch.randn(2, 128).to(device)

torch.onnx.export(
    model,
    sample_input,
    "2.onnx",
    opset_version=19,
    do_constant_folding=True
)

print("Quantization + Training + ONNX export done.")

The export fails with warnings such as:

- Missing shape inference for trt::TRT_FP4QDQ
- DynamicBlockQuantizationFunction tracing errors
- Error triggered inside tensorrt::dynamic_block_quantize_op

The full error log ends with ONNX export failure.
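I have not yet tried the TorchDynamo-based ONNX exporter, so I cannot say whether it changes the outcome; in case it is a useful data point, the variant I would try (same model and sample_input as above, output filename arbitrary) is:

torch.onnx.export(
    model,
    sample_input,
    "2_dynamo.onnx",   # arbitrary output path
    dynamo=True,       # route through the dynamo/torch.export-based exporter
)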

Expected behavior

The quantized model should export to ONNX successfully, or ModelOpt should provide a supported export path for NVFP4 quantized Conv/Linear layers. At minimum, the exporter should fail with a clear error message instead of a low-level internal crash.
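As a possible interim isolation step I considered (untested, and assuming the wildcard-key pattern documented for quant_cfg entries applies here), disabling the quantizers on the conv layer before calibration would look roughly like the sketch below. Since I do not know which layer's quantizer actually triggers the failure, and the same crash also happens with mtq.FP8_DEFAULT_CFG, I only see this as a way to narrow down the failing op, not as a fix:

import copy

# Untested sketch: turn off quantization for the conv layer only, using the
# wildcard-key / {"enable": False} pattern for quant_cfg entries.
config = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
config["quant_cfg"]["*conv*"] = {"enable": False}
model = mtq.quantize(model, config, forward_loop)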

Who can help?

Not sure, but likely someone from the ModelOpt quantization or ONNX export team.

System information

OS: Ubuntu 22.04

GPU: NVIDIA H100

Number of GPUs: 1

Python: 3.12

ModelOpt: 0.39.0

CUDA: 12.9

PyTorch: 2.8.0

TensorRT / TensorRT-LLM / ONNXRuntime: default versions in ModelOpt 0.39.0 container


Container: nvidia-modelopt official container


Other: FlashAttention 2.8.3 is installed (ModelOpt warns it requires <= 2.8.2, but I ignored the warning)
