Description
How would you like to use ModelOpt
I'm using ModelOpt for FP8 quantization of a CNN-based model (UNet architecture with Conv2d and ConvTranspose2d), followed by TensorRT deployment using DeviceModel.
Issue Description
Problem: The FP8-quantized model produces correct results with PyTorch inference (`mto.restore`), but incorrect results when running TensorRT inference through DeviceModel. Both FP32 and FP8 precision modes fail (using `stronglyTyped`).
Key observation:
- The issue occurs with both TensorRT FP32 and FP8, which suggests the problem is not precision-related but lies in the ONNX export or the DeviceModel inference flow.
Expected behavior: DeviceModel inference results should match PyTorch inference.
Actual behavior: DeviceModel produces incorrect/degraded outputs regardless of precision mode.
Steps to Reproduce
### Step 1: FP8 QAT Training
```python
import torch

import modelopt.torch.quantization as mtq
import modelopt.torch.opt as mto
import re

# Model: UNet with Conv2d and ConvTranspose2d layers
# Input: multi-channel tensor (variable channels, H, W)
# Output: RGB image (3, H, W)

# FP8 quantization configuration
quant_cfg = mtq.FP8_DEFAULT_CFG.copy()

# Calibration loop
def calibration_loop(qat_model):
    qat_model.eval()
    with torch.no_grad():
        for batch_data in calibration_dataloader:
            inputs = batch_data['input'].to(device)
            _ = qat_model(inputs, None)

# Apply quantization
quantized_model = mtq.quantize(model, quant_cfg, forward_loop=calibration_loop)

# Disable quantization for ConvTranspose and Embedding layers
def filter_func(name):
    pattern = re.compile(r".*(ConvTranspose|embedding).*", re.IGNORECASE)
    return pattern.match(name) is not None

mtq.disable_quantizer(quantized_model, filter_func)

# -------- Train code (QAT training loop omitted) --------

# Save checkpoint
mto.save(quantized_model, checkpoint_path)
```
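For what it's worth, after `mtq.quantize()` and `disable_quantizer()` I also dump the quantizer state to confirm the ConvTranspose/embedding quantizers really ended up disabled. A minimal sketch; the class-based spot-check at the end is my own addition, since the filter above only sees quantizer names:

```python
import modelopt.torch.quantization as mtq

# Per-layer quantizer summary (shows which quantizers are enabled and their amax values)
mtq.print_quant_summary(quantized_model)

# Spot-check: the Step 1 filter matches quantizer *names*, so also look at the module
# class to make sure ConvTranspose layers are actually unquantized.
for name, module in quantized_model.named_modules():
    if "ConvTranspose" in type(module).__name__ and hasattr(module, "input_quantizer"):
        print(name, module.input_quantizer, module.weight_quantizer)
```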
### Step 2: ONNX Export (following the diffusers example)
I followed the ONNX export approach from the TensorRT-Model-Optimizer diffusers example:
```python
import torch
from torch.onnx import export as onnx_export
import onnx
import onnx_graphsurgeon as gs

# Step 2.1: Generate FP8 scales (following the diffusers approach)
def generate_fp8_scales(model):
    """Temporary solution due to a known bug in torch.onnx._dynamo_export."""
    for _, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)) and (
            hasattr(module, 'input_quantizer') and module.input_quantizer is not None
        ):
            module.input_quantizer._num_bits = 8
            module.weight_quantizer._num_bits = 8
            module.input_quantizer._amax = module.input_quantizer._amax * (127 / 448.0)
            module.weight_quantizer._amax = module.weight_quantizer._amax * (127 / 448.0)

generate_fp8_scales(quantized_model)

# Step 2.2: Export to ONNX
dummy_input = torch.randn(1, in_channels, H, W).to(device)
dummy_kwargs = {"x": dummy_input}
with torch.inference_mode():
    onnx_export(
        quantized_model,
        (),
        kwargs=dummy_kwargs,
        f=tmp_onnx_path,
        input_names=['x'],
        output_names=['output'],
        do_constant_folding=True,
        opset_version=20,
        dynamo=False,
    )

# Step 2.3: Convert zero-point datatype from INT8 to FP8 (following diffusers)
def convert_zp_fp8(onnx_graph):
    """
    Convert Q/DQ zero-point datatype from INT8 to FP8.
    Workaround: FP8 Conv cannot be exported to ONNX directly.
    """
    qdq_zero_nodes = set()
    for node in onnx_graph.graph.node:
        if node.op_type == "QuantizeLinear" and len(node.input) > 2:
            qdq_zero_nodes.add(node.input[2])
    for node in onnx_graph.graph.node:
        if node.output[0] in qdq_zero_nodes:
            node.attribute[0].t.data_type = onnx.TensorProto.FLOAT8E4M3FN
    return onnx_graph

# Apply graph transformations
onnx_model = onnx.load(tmp_onnx_path, load_external_data=True)
graph = gs.import_onnx(onnx_model)
graph.cleanup().toposort()
onnx_model = gs.export_onnx(graph)
onnx_model = convert_zp_fp8(onnx_model)
graph = gs.import_onnx(onnx_model)
onnx_model = gs.export_onnx(graph.cleanup())
onnx.save(onnx_model, final_onnx_path)
```
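Before compiling anything, I also sanity-check the exported graph itself. A minimal sketch using only the `onnx` API: it runs the ONNX checker and verifies that the Q/DQ zero-points actually became FLOAT8E4M3FN after `convert_zp_fp8()`:

```python
import onnx

m = onnx.load(final_onnx_path)
onnx.checker.check_model(final_onnx_path)  # structural validation only

# Zero-point tensor names feeding QuantizeLinear nodes
zp_names = {
    node.input[2]
    for node in m.graph.node
    if node.op_type == "QuantizeLinear" and len(node.input) > 2
}

# Zero-points can live either in Constant nodes or in graph initializers depending on
# how the graph was round-tripped, so check both places for the FP8 dtype.
for node in m.graph.node:
    if node.op_type == "Constant" and node.output[0] in zp_names:
        print(node.output[0], node.attribute[0].t.data_type == onnx.TensorProto.FLOAT8E4M3FN)
for init in m.graph.initializer:
    if init.name in zp_names:
        print(init.name, init.data_type == onnx.TensorProto.FLOAT8E4M3FN)
```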
### Step 3: PyTorch Inference (Works Correctly)
```python
import torch

import modelopt.torch.opt as mto

# Load model
model = create_unet_model(...)
model = mto.restore(model, checkpoint_path)
model.eval()

# Run inference
with torch.no_grad():
    output = model(input_tensor, None)

# Output is correct and matches expected results
```
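To separate "bad ONNX export" from "bad DeviceModel/TensorRT flow", I also compare the restored PyTorch model against plain ONNX Runtime on the same input. A rough sketch, assuming onnxruntime-gpu is installed; FP8 Q/DQ support depends on the ORT build, so if the FP8 graph fails to load I compare against `tmp_onnx_path` (the graph before `convert_zp_fp8()`) instead:

```python
import numpy as np
import onnxruntime as ort
import torch

# Reference output from the restored PyTorch model
model.eval()
with torch.no_grad():
    ref = model(input_tensor, None).detach().cpu().numpy()

# ONNX Runtime output on the same input
sess = ort.InferenceSession(
    final_onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
out = sess.run(None, {"x": input_tensor.detach().cpu().numpy()})[0]

print("max abs diff:", np.max(np.abs(ref - out)))
print("mean abs diff:", np.mean(np.abs(ref - out)))
```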
### Step 4: TensorRT Inference via DeviceModel (Incorrect Results)
```python
import torch

from modelopt.torch._deploy._runtime import RuntimeRegistry
from modelopt.torch._deploy.device_model import DeviceModel
from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata

# TensorRT deployment configuration
deployment = {
    "runtime": "TRT",
    "precision": "stronglyTyped",
    "onnx_opset": "17",
    "verbose": "true",
}
client = RuntimeRegistry.get(deployment)

# Load ONNX and get metadata
dummy_input = torch.randn(1, in_channels, H, W).cuda()
dummy_kwargs = {"x": dummy_input}
onnx_bytes, metadata = get_onnx_bytes_and_metadata(
    model=model,
    dummy_input=dummy_kwargs,
    onnx_load_path=onnx_path,
    onnx_opset=20,
    remove_exported_model=False,
    dq_only=False,
)

# Extract IO shapes from metadata
io_shapes = metadata.get('output_shapes', {})

# Compile TensorRT engine and create DeviceModel
compilation_args = {"engine_path": trt_engine_path}
compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args=compilation_args)
device_model = DeviceModel(
    client,
    compiled_model,
    metadata,
    compilation_args,
    io_shapes,
    True,  # ignore_nesting=True
)

# Run inference (self.model is the DeviceModel inside my wrapper)
device_model.eval()
with torch.no_grad():
    output = self.forward(input_tensor)  # wrapper method, defined below

# Wrapper forward function
def forward(self, input_tensor, shading_model_id=None):
    """Forward pass through either the DeviceModel or the PyTorch model."""
    with torch.no_grad():
        if isinstance(self.model, DeviceModel):  # DeviceModel (TensorRT API)
            outputs = self.model(x=input_tensor, shading_model_id=None)
            return outputs[0] if isinstance(outputs, list) else outputs
        else:  # PyTorch + ModelOpt API
            return self.model(input_tensor)

# Output is incorrect (both FP32 and FP8 precision modes fail)
```
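For reference, "incorrect" above means a large numerical gap between the two paths on the same input, roughly measured like this (a sketch; the errors I see are far beyond what FP8 quantization noise would explain):

```python
import torch

# Same input through both paths
with torch.no_grad():
    ref = model(input_tensor, None)                              # restored PyTorch model (correct)
    trt = device_model(x=input_tensor, shading_model_id=None)    # DeviceModel / TensorRT engine
    trt = trt[0] if isinstance(trt, list) else trt

diff = (ref - trt.to(ref.device)).abs()
print("max abs diff:", diff.max().item())
print("mean abs diff:", diff.mean().item())
```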
Questions
- For CNN models with Conv2d and ConvTranspose2d: Do I need additional ONNX graph modifications beyond `generate_fp8_scales()` and `convert_zp_fp8()`?
- ConvTranspose handling: Are there known issues with ConvTranspose2d layers in FP8 quantized ONNX → TensorRT conversion? Should I exclude them from quantization?
- ONNX export verification: How can I verify that the exported ONNX graph is correct before TensorRT compilation? Are there any tools to compare PyTorch vs ONNX inference?
- DeviceModel configuration: Since both FP32 and stronglyTyped fail, is there a potential issue with:
  - the ONNX graph structure after `convert_zp_fp8()`?
  - the `get_onnx_bytes_and_metadata()` parameters?
  - the `DeviceModel` initialization parameters?
Observations
- ✅ PyTorch inference with `mto.restore()` works perfectly
- ❌ TensorRT FP32 stronglyTyped produces incorrect results
- ❌ TensorRT FP8 stronglyTyped produces incorrect results
- ✅ No errors during ONNX export or TensorRT compilation
- ℹ️ I followed the diffusers example for FP8 ONNX export
Model Architecture Details
- Type: UNet-based CNN
- Key operations:
  - Conv2d (quantized)
  - ConvTranspose2d (excluded from quantization)
  - BatchNorm2d
  - ReLU/LeakyReLU
- Input: multi-channel tensor (30 channels, H×W resolution)
- Output: RGB image (3 channels)
- Quantization exclusions: ConvTranspose2d, Embedding layers (if present)
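To make the layer list above concrete, here is a heavily simplified, hypothetical skeleton of the architecture (not the real model; channel counts and layer depths are made up except for the 30-channel input and 3-channel RGB output, and the unused second forward argument mirrors the real model's signature):

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Hypothetical stand-in with the same op mix: Conv2d + BatchNorm + ReLU encoder,
    ConvTranspose2d decoder, 30-channel input, 3-channel output."""
    def __init__(self, in_ch=30, out_ch=3):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, out_ch, 4, stride=2, padding=1),
        )

    def forward(self, x, shading_model_id=None):
        return self.dec(self.enc(x))
```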
Who can help?
@todo: Tag appropriate maintainers for TensorRT/DeviceModel/CNN quantization
System information
- Container used: N/A (native installation)
- OS: Linux (kernel 5.4.250, Ubuntu-based)
- CPU architecture: x86_64
- GPU name: H20
- GPU memory size: 96GB
- Number of GPUs: 1
- Library versions:
  - Python: 3.12
  - ModelOpt: 0.37.0 / 0.39.0 (both)
  - CUDA: 12.9
  - PyTorch: 2.8.0+cu129
  - TensorRT: 10.12.0.36 (ModelOpt 0.37.0) / 10.13.3.9post (ModelOpt 0.39.0)
  - ONNX: 1.19.1
  - onnx-graphsurgeon: 0.5.8
Additional Context
- I verified that the `DeviceModel` configuration (io_shapes, ignore_nesting) works correctly with a standard pre-trained model (ResNet-18); see the sketch after this list
- The issue appears to be specific to CNN architectures with ConvTranspose, or to the ONNX export flow for FP8-quantized models
- I'm using the same FP8 ONNX export approach as the diffusers example
- ONNX file
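Sketch of the ResNet-18 sanity check mentioned above. It mirrors the Step 4 flow with an unquantized torchvision model; `resnet18_trt_engine_path` is a hypothetical output path, and I dropped `onnx_load_path` since there is no pre-exported graph (exact defaults of `get_onnx_bytes_and_metadata()` may differ):

```python
import torch
import torchvision

from modelopt.torch._deploy._runtime import RuntimeRegistry
from modelopt.torch._deploy.device_model import DeviceModel
from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata

# Unquantized pre-trained ResNet-18 through the same deployment path as Step 4
resnet = torchvision.models.resnet18(weights="IMAGENET1K_V1").cuda().eval()
dummy = torch.randn(1, 3, 224, 224).cuda()

deployment = {"runtime": "TRT", "precision": "stronglyTyped", "verbose": "true"}
client = RuntimeRegistry.get(deployment)

onnx_bytes, metadata = get_onnx_bytes_and_metadata(model=resnet, dummy_input=dummy)
io_shapes = metadata.get("output_shapes", {})

compilation_args = {"engine_path": resnet18_trt_engine_path}
compiled = client.ir_to_compiled(onnx_bytes, compilation_args=compilation_args)
device_model = DeviceModel(client, compiled, metadata, compilation_args, io_shapes, True)

# In this case the DeviceModel output matches PyTorch closely
with torch.no_grad():
    ref = resnet(dummy)
    out = device_model(dummy)
    out = out[0] if isinstance(out, list) else out
print("max abs diff:", (ref - out.to(ref.device)).abs().max().item())
```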
References
- Followed export approach from: https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/examples/diffusers/quantization/onnx_utils/export.py
- Followed quantization approach from: https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/examples/diffusers/quantization/quantize.py
- Followed ONNX PTQ examples from: https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/onnx_ptq
Any guidance on correct ONNX export and DeviceModel usage for CNN QAT models would be greatly appreciated!