Describe the bug
I encountered a failure when exporting a quantized model (using NVFP4_DEFAULT_CFG) to ONNX.
The quantization and training steps run normally, but torch.onnx.export crashes with a long stack trace containing shape-inference warnings and an internal error triggered by tensorrt::dynamic_block_quantize_op and DynamicBlockQuantizationFunction. I hit the same issue when using mtq.FP8_DEFAULT_CFG.
The issue seems related to exporting dynamic FP4 quantization ops in a simple Conv2d + Linear model. The same model exports correctly before quantization.
This blocks ONNX export when using ModelOpt NVFP4 quantization.
Steps/Code to reproduce bug
The issue can be reproduced with the following minimal, self-contained script, which can be run directly:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg


class TwoLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv2d with a 1x3 kernel, i.e. convolution along the length dimension
        self.conv = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=(1, 3),
            padding=(0, 1),  # keep the final width at 128
        )
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        # original input: (B, 128)
        x = x.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, 128)
        x = torch.relu(self.conv(x))     # (B, 1, 1, 128)
        x = x.squeeze(2).squeeze(1)      # (B, 128)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def get_model():
    return TwoLinearModel()


def get_dataloader(num_samples=100):
    X = torch.randn(num_samples, 128)
    y = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(X, y)
    return DataLoader(dataset, batch_size=16, shuffle=False)


# -----------------------------------------------------
# Quantize + train + export
# -----------------------------------------------------
device = torch.device("cuda")
model = get_model().to(device)

config = mtq.NVFP4_DEFAULT_CFG
# config = mtq.FP8_DEFAULT_CFG

calib_size = 100
data_loader = get_dataloader(calib_size)


def forward_loop(model):
    for batch in data_loader:
        x, _ = batch
        x = x.to(device)
        model(x)


model = mtq.quantize(model, config, forward_loop)
mtq.print_quant_summary(model)

# Train for a few steps (everything needs to be on the GPU)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_loader = get_dataloader(200)

model.train()
for step, (x, y) in enumerate(train_loader):
    if step >= 5:
        break
    x = x.to(device)
    y = y.to(device)
    optimizer.zero_grad()
    out = model(x)
    loss = criterion(out, y)
    loss.backward()
    optimizer.step()
    print(f"Training step {step}, loss = {loss.item():.4f}")

# Export to ONNX (the sample input must also be on the GPU)
sample_input = torch.randn(2, 128).to(device)
torch.onnx.export(
    model,
    sample_input,
    "2.onnx",
    opset_version=19,
    do_constant_folding=True,
)
print("Quantization + Training + ONNX export done.")
The export fails with warnings such as:
Missing shape inference for trt::TRT_FP4QDQ
DynamicBlockQuantizationFunction tracing errors
Error triggered inside tensorrt::dynamic_block_quantize_op
The full error log ends with ONNX export failure.
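A possible way to narrow this down (a sketch I have not verified end to end; it assumes the default configs keep the usual plain-dict layout with a quant_cfg section and wildcard keys such as *input_quantizer) is to repeat the export with the dynamic input quantizers disabled, which would show whether the crash comes from DynamicBlockQuantizationFunction on the activations or from the weight-quantization path:

import copy

# Sketch: keep weight quantization but disable the dynamic (block) input
# quantizers, then retry the export. Assumes NVFP4_DEFAULT_CFG is a plain dict
# with a "quant_cfg" section and wildcard keys like "*input_quantizer".
weight_only_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
weight_only_cfg["quant_cfg"]["*input_quantizer"] = {"enable": False}

wq_model = mtq.quantize(get_model().to(device), weight_only_cfg, forward_loop)
torch.onnx.export(
    wq_model,
    torch.randn(2, 128).to(device),
    "weight_only.onnx",  # placeholder filename
    opset_version=19,
    do_constant_folding=True,
)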
Expected behavior
The quantized model should export to ONNX successfully, or ModelOpt should provide a supported export path for NVFP4 quantized Conv/Linear layers. At minimum, the exporter should fail with a clear error message instead of a low-level internal crash.
Who can help?
Not sure, but likely someone from the ModelOpt quantization or ONNX export team.
System information
OS: Ubuntu 22.04
GPU: NVIDIA H100
Number of GPUs: 1
Python: 3.12
ModelOpt: 0.39.0
CUDA: 12.9
PyTorch: 2.8.0
TensorRT / TensorRT-LLM / ONNXRuntime: default versions in ModelOpt 0.39.0 container
Container: nvidia-modelopt official container
Other: FlashAttention 2.8.3 is installed (ModelOpt warns that it requires <= 2.8.2, but I ignored the warning)