
[Help wanted] CNN QAT: How to correctly export ONNX and use DeviceModel for inference #604

@yy6768

Description


How would you like to use ModelOpt

I'm using ModelOpt for FP8 quantization of a CNN-based model (UNet architecture with Conv2d and ConvTranspose2d), followed by TensorRT deployment using DeviceModel.

Issue Description

Problem: The FP8 quantized model produces correct results with PyTorch inference (mto.restore), but incorrect results when using DeviceModel for TensorRT inference; both FP32 and FP8 precision modes fail (using "stronglyTyped").

Key observation:

  • The issue occurs with both TensorRT FP32 and FP8, suggesting the problem is not precision-related but lies in the ONNX export or the DeviceModel inference flow.

Expected behavior: DeviceModel inference results should match PyTorch inference.

Actual behavior: DeviceModel produces incorrect/degraded outputs regardless of precision mode.

Steps to Reproduce

Step 1: FP8 QAT Training

import torch
import modelopt.torch.quantization as mtq
import modelopt.torch.opt as mto
import re

# Model: UNet with Conv2d and ConvTranspose2d layers
# Input: Multi-channel tensor (variable channels, H, W)
# Output: RGB image (3, H, W)

# FP8 quantization configuration
quant_cfg = mtq.FP8_DEFAULT_CFG.copy()

# Calibration loop
def calibration_loop(qat_model):
    qat_model.eval()
    with torch.no_grad():
        for batch_data in calibration_dataloader:
            inputs = batch_data['input'].to(device)
            _ = qat_model(inputs, None)

# Apply quantization
quantized_model = mtq.quantize(model, quant_cfg, forward_loop=calibration_loop)

# Disable quantization for ConvTranspose and Embedding layers
def filter_func(name):
    pattern = re.compile(r".*(ConvTranspose|embedding).*", re.IGNORECASE)
    return pattern.match(name) is not None

mtq.disable_quantizer(quantized_model, filter_func)

# --------   Train code (QAT loop, omitted)   --------


# Save checkpoint
mto.save(quantized_model, checkpoint_path)
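
For reference, this is the sanity check I run after mtq.disable_quantizer() to see which layers still carry active quantizers (my own debugging sketch; the input_quantizer/weight_quantizer attribute names are the same ones used by generate_fp8_scales() in Step 2 below):

# List layers whose quantizers are still attached after disable_quantizer()
for name, module in quantized_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear)):
        iq = getattr(module, "input_quantizer", None)
        wq = getattr(module, "weight_quantizer", None)
        if iq is not None or wq is not None:
            # The quantizer repr shows whether it is enabled or disabled
            print(f"{name}: input={iq}, weight={wq}")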

Step 2: ONNX Export (Following diffusers example)

I followed the ONNX export approach from TensorRT-Model-Optimizer diffusers example:

import torch
from torch.onnx import export as onnx_export
import onnx
import onnx_graphsurgeon as gs

# Step 2.1: Generate FP8 scales (following diffusers approach)
def generate_fp8_scales(model):
    """Temporary solution due to known bug in torch.onnx._dynamo_export"""
    for _, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)) and (
            hasattr(module, 'input_quantizer') and module.input_quantizer is not None
        ):
            module.input_quantizer._num_bits = 8
            module.weight_quantizer._num_bits = 8
            # Rescale amax from the FP8 E4M3 dynamic range (max 448) to the INT8 range (max 127)
            module.input_quantizer._amax = module.input_quantizer._amax * (127 / 448.0)
            module.weight_quantizer._amax = module.weight_quantizer._amax * (127 / 448.0)

generate_fp8_scales(quantized_model)
# Step 2.2: Export to ONNX
dummy_input = torch.randn(1, in_channels, H, W).to(device)
dummy_kwargs = {"x": dummy_input}

with torch.inference_mode():
    onnx_export(
        quantized_model,
        (),
        kwargs=dummy_kwargs,
        f=tmp_onnx_path,
        input_names=['x'],
        output_names=['output'],
        do_constant_folding=True,
        opset_version=20,
        dynamo=False
    )
# Step 2.3: Convert zero-point datatype from INT8 to FP8 (following diffusers)
def convert_zp_fp8(onnx_graph):
    """
    Convert Q/DQ zero datatype from INT8 to FP8.
    Workaround: FP8 Conv cannot be exported to ONNX directly.
    """
    qdq_zero_nodes = set()
    for node in onnx_graph.graph.node:
        if node.op_type == "QuantizeLinear" and len(node.input) > 2:
            qdq_zero_nodes.add(node.input[2])
    
    for node in onnx_graph.graph.node:
        if node.output[0] in qdq_zero_nodes:
            node.attribute[0].t.data_type = onnx.TensorProto.FLOAT8E4M3FN
    
    return onnx_graph

# Apply graph transformations
onnx_model = onnx.load(tmp_onnx_path, load_external_data=True)
graph = gs.import_onnx(onnx_model)
graph.cleanup().toposort()
onnx_model = gs.export_onnx(graph)
onnx_model = convert_zp_fp8(onnx_model)
graph = gs.import_onnx(onnx_model)
onnx_model = gs.export_onnx(graph.cleanup())
onnx.save(onnx_model, final_onnx_path)
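
After the graph surgery I also run a quick structural check on the final ONNX file (my own sketch using only standard onnx APIs; the node types and the FLOAT8E4M3FN dtype are the same ones handled by convert_zp_fp8() above):

# Structural sanity check on the final ONNX file
final_model = onnx.load(final_onnx_path)
onnx.checker.check_model(final_model)

q_nodes = [n for n in final_model.graph.node if n.op_type == "QuantizeLinear"]
dq_nodes = [n for n in final_model.graph.node if n.op_type == "DequantizeLinear"]
print(f"QuantizeLinear: {len(q_nodes)}, DequantizeLinear: {len(dq_nodes)}")

# Count zero-point constants that were converted to FP8 by convert_zp_fp8()
fp8_zp = sum(
    1 for n in final_model.graph.node
    if n.op_type == "Constant" and n.attribute
    and n.attribute[0].t.data_type == onnx.TensorProto.FLOAT8E4M3FN
)
print(f"FP8 zero-point constants: {fp8_zp}")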

Step 3: PyTorch Inference (Works Correctly)

import modelopt.torch.opt as mto

# Load model
model = create_unet_model(...)
model = mto.restore(model, checkpoint_path)
model.eval()

# Run inference
with torch.no_grad():
    output = model(input_tensor, None)

# Output is correct and matches expected results
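
To have something to compare the TensorRT result against in Step 4, I also save this PyTorch output as a reference (ref_output.pt is just a file name I chose):

# Keep the PyTorch result as the reference for the TensorRT comparison below
torch.save(output.detach().cpu(), "ref_output.pt")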

Step 4: TensorRT Inference via DeviceModel (Incorrect Results)

from modelopt.torch._deploy._runtime import RuntimeRegistry
from modelopt.torch._deploy.device_model import DeviceModel
from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata

# TensorRT deployment configuration
deployment = {
    "runtime": "TRT",
    "precision": "stronglyTyped",  
    "onnx_opset": "17",
    "verbose": "true",
}

client = RuntimeRegistry.get(deployment)
# Load ONNX and get metadata
dummy_input = torch.randn(1, in_channels, H, W).cuda()
dummy_kwargs = {"x": dummy_input}

onnx_bytes, metadata = get_onnx_bytes_and_metadata(
    model=model,
    dummy_input=dummy_kwargs,
    onnx_load_path=onnx_path,
    onnx_opset=20,
    remove_exported_model=False,
    dq_only=False
)

# Extract IO shapes from metadata
io_shapes = metadata.get('output_shapes', {})
# Compile TensorRT engine and create DeviceModel
compilation_args = {"engine_path": trt_engine_path}
compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args=compilation_args)

device_model = DeviceModel(
    client,
    compiled_model,
    metadata,
    compilation_args,
    io_shapes,
    True  # ignore_nesting=True
)
# Run inference
device_model.eval()

with torch.no_grad():
    output = self.forward(input_tensor)  # wrapper method shown below

# Forward function (method on the wrapper class that holds self.model)
def forward(self, input_tensor, shading_model_id=None):
    """Forward pass through the model"""
    with torch.no_grad():
        if isinstance(self.model, DeviceModel):  # DeviceModel (TensorRT path)
            outputs = self.model(x=input_tensor, shading_model_id=None)
            return outputs[0] if isinstance(outputs, list) else outputs
        else:  # PyTorch + ModelOpt path
            return self.model(input_tensor)

# Output is incorrect (both FP32 and FP8 precision modes fail)
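
This is how I quantify the mismatch against the PyTorch reference saved in Step 3 (my own sketch; the tolerance is arbitrary):

# Compare the DeviceModel output with the saved PyTorch reference
ref = torch.load("ref_output.pt").float()
trt = output.detach().cpu().float()
abs_err = (trt - ref).abs()
print(f"max abs err:  {abs_err.max().item():.6f}")
print(f"mean abs err: {abs_err.mean().item():.6f}")
print(f"allclose(atol=1e-2): {torch.allclose(trt, ref, atol=1e-2)}")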

Questions

  1. For CNN models with Conv2d and ConvTranspose2d: Do I need additional ONNX graph modifications beyond generate_fp8_scales() and convert_zp_fp8()?

  2. ConvTranspose handling: Are there known issues with ConvTranspose2d layers in FP8 quantized ONNX → TensorRT conversion? Should I exclude them from quantization?

  3. ONNX export verification: How can I verify that the exported ONNX graph is correct before TensorRT compilation? Are there any tools to compare PyTorch vs ONNX inference? (See the parity-check sketch after this list for what I plan to try.)

  4. DeviceModel configuration: Since both FP32 and stronglyTyped fail, is there a potential issue with:

    • ONNX graph structure after convert_zp_fp8()?
    • get_onnx_bytes_and_metadata() parameters?
    • DeviceModel initialization parameters?
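
For question 3, this is the PyTorch-vs-ONNX parity check I plan to run with onnxruntime (a minimal sketch; FP8 Q/DQ ops may not be supported by every onnxruntime build, so I would first validate a non-quantized export of the same model this way):

import numpy as np
import onnxruntime as ort

# Run the exported ONNX model and the PyTorch model on the same input
sess = ort.InferenceSession(
    final_onnx_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
np_input = input_tensor.detach().cpu().numpy()
(ort_out,) = sess.run(None, {"x": np_input})

with torch.no_grad():
    torch_out = model(input_tensor, None).detach().cpu().numpy()

print("max abs diff (PyTorch vs ORT):", np.abs(ort_out - torch_out).max())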

Observations

  1. ✅ PyTorch inference with mto.restore() works perfectly
  2. ❌ TensorRT FP32 stronglyTyped produces incorrect results
  3. ❌ TensorRT FP8 stronglyTyped produces incorrect results
  4. ✅ No errors during ONNX export or TensorRT compilation
  5. ℹ️ I followed the diffusers example for FP8 ONNX export

Model Architecture Details

  • Type: UNet-based CNN
  • Key operations:
    • Conv2d (quantized)
    • ConvTranspose2d (excluded from quantization)
    • BatchNorm2d
    • ReLU/LeakyReLU
  • Input: Multi-channel tensor (30 channels in this setup, H×W resolution)
  • Output: RGB image (3 channels)
  • Quantization exclusions: ConvTranspose2d, Embedding layers (if present)

Who can help?

@todo: Tag appropriate maintainers for TensorRT/DeviceModel/CNN quantization

System information

  • Container used: N/A (native installation)
  • OS: Linux (kernel 5.4.250, Ubuntu-based)
  • CPU architecture: x86_64
  • GPU name: H20
  • GPU memory size: 96GB
  • Number of GPUs: 1
  • Library versions:
    • Python: 3.12
    • ModelOpt version: 0.37.0 and 0.39.0 (tested with both)
    • CUDA: 12.9
    • PyTorch: 2.8.0+cu129
    • TensorRT: 10.12.0.36 (with ModelOpt 0.37.0) / 10.13.3.9post (with ModelOpt 0.39.0)
    • ONNX: 1.19.1
    • onnx-graphsurgeon: 0.5.8

Additional Context

  • I verified DeviceModel configuration (io_shapes, ignore_nesting) works correctly with standard pre-trained models (ResNet-18)
  • The issue appears to be specific either to CNN architectures with ConvTranspose or to the ONNX export flow for FP8 quantized models
  • I'm using the same FP8 ONNX export approach as the diffusers example
  • ONNX file: screenshot of the exported graph attached (image not reproduced here)

Any guidance on correct ONNX export and DeviceModel usage for CNN QAT models would be greatly appreciated!
