
FP8 model per-channel weight quantization export #621

@IVANYAUN

Description

Describe the bug

I need to use ModelOpt's FP8 quantization to perform per-channel weight quantization and export the quantized model to ONNX format. Quantization and inference work, but the ONNX export fails.
Error message:
torch.onnx.errors.SymbolicValueError: Expected node type 'onnx::Constant' for argument 'amax' of node 'symbolic', got 'onnx::Max'. [Caused by the value '12 defined in (%12 : Float(8, 1, strides=[1, 1], requires_grad=0, device=cuda:0) = onnx::Max(%10, %11), scope: __main__.TinyLinear::/modelopt.torch.opt.dynamic.QuantLinear::linear/modelopt.torch.quantization.nn.modules.tensor_quantizer.TensorQuantizer::input_quantizer # /home/yuanxinwei/quant-tool/TensorRT-Model-Optimizer/modelopt/torch/quantization/utils.py:184:0 )' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Max'.] (node defined in /home/yuanxinwei/quant-tool/TensorRT-Model-Optimizer/modelopt/torch/quantization/utils.py(184): reduce_amax
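
Judging from the stack in the error, the node that trips the check lives under TensorQuantizer::input_quantizer and is produced by reduce_amax: the activation amax is computed from the input at trace time (per token), so it appears in the graph as onnx::Max rather than the onnx::Constant the FP8 symbolic expects, while the per-channel weight amax is a stored constant. Below is a minimal, hypothetical diagnostic sketch (assuming TensorQuantizer exposes an amax attribute; names may differ) that I would run after mtq.quantize() on the repro model to confirm which quantizers hold a stored amax:

# Hypothetical diagnostic sketch (not ModelOpt's own tooling): print which
# quantizers hold a stored (constant) amax and which compute it at runtime.
def dump_quantizer_amax(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if name.endswith(("input_quantizer", "weight_quantizer")):
            amax = getattr(module, "amax", None)  # assumption: TensorQuantizer.amax
            if amax is None:
                print(f"{name}: amax is computed at runtime (dynamic)")
            else:
                print(f"{name}: stored amax of shape {tuple(amax.shape)}")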

Steps/Code to reproduce bug

Here is a minimal reproducible example:

import copy
import sys
import tempfile
import unittest
from contextlib import nullcontext
from pathlib import Path

import torch
import torch.nn as nn

PROJECT_ROOT = Path(__file__).resolve().parents[2]
MODEL_OPT_PATH = PROJECT_ROOT / "TensorRT-Model-Optimizer"
if MODEL_OPT_PATH.exists():
    sys.path.insert(0, str(MODEL_OPT_PATH))

try:
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    MODEL_OPT_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    MODEL_OPT_AVAILABLE = False
    mtq = None  # type: ignore
    export_torch_mode = nullcontext  # type: ignore


class TinyLinear(nn.Module):
    """Toy network with a single Linear block quantized to FP8."""

    def __init__(self, in_features: int = 32, out_features: int = 16):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def build_calib_data(batch_size: int = 8, num_batches: int = 16, in_features: int = 32):
    torch.manual_seed(0)
    return [torch.randn(batch_size, in_features) for _ in range(num_batches)]


def _forward_loop_factory(dataset, device):
    def _forward_loop(model):
        model.eval()
        with torch.no_grad():
            for batch in dataset:
                model(batch.to(device))

    return _forward_loop


class TestStandaloneModelOptFP8(unittest.TestCase):
    @unittest.skipUnless(MODEL_OPT_AVAILABLE, "ModelOpt backend not available")
    def test_fp8_linear_without_smk(self):
        batch_size = 8
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = TinyLinear().to(device).eval()

        calib_data = build_calib_data(batch_size=batch_size)
        fp8_cfg = copy.deepcopy(mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG)
        model = mtq.quantize(
            model,
            fp8_cfg,
            forward_loop=_forward_loop_factory(calib_data, device),
        )

        sample_input = calib_data[0].to(device)
        with torch.no_grad():
            output = model(sample_input)

        self.assertEqual(tuple(output.shape), (batch_size, 16))

        with tempfile.TemporaryDirectory() as tmpdir:
            onnx_path = Path(tmpdir) / "fp8_linear.onnx"
            with export_torch_mode():
                torch.onnx.export(
                    model,
                    (sample_input,),
                    onnx_path,
                    opset_version=19,
                    input_names=["input"],
                    output_names=["output"],
                    keep_initializers_as_inputs=False,
                    dynamo=False,
                )
            self.assertTrue(onnx_path.exists())


if __name__ == "__main__":  # pragma: no cover - direct execution helper
    unittest.main()
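
For comparison, the same export path appears to go through when the activation amax is a stored constant. The sketch below, assuming mtq.FP8_DEFAULT_CFG (a static per-tensor FP8 config) exists at this commit, reuses the helpers from the test above as a sanity check; it is not the per-channel/per-token configuration I actually need:

# Sketch only (assumption: mtq.FP8_DEFAULT_CFG is available at this commit).
# With a stored-constant amax, the same torch.onnx.export call should succeed.
def export_fp8_per_tensor_sanity_check() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyLinear().to(device).eval()
    calib_data = build_calib_data()
    model = mtq.quantize(
        model,
        copy.deepcopy(mtq.FP8_DEFAULT_CFG),
        forward_loop=_forward_loop_factory(calib_data, device),
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        onnx_path = Path(tmpdir) / "fp8_linear_per_tensor.onnx"
        with export_torch_mode():
            torch.onnx.export(
                model,
                (calib_data[0].to(device),),
                str(onnx_path),
                opset_version=19,
                input_names=["input"],
                output_names=["output"],
                dynamo=False,
            )
        assert onnx_path.exists()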

Expected behavior

The FP8 per-channel quantized model can be exported to ONNX successfully.

System information

  • Container used (if applicable): ?
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 20.04.6 LTS
  • CPU architecture (x86_64, aarch64): x86_64
  • GPU name (e.g. H100, A100, L40S): NVIDIA GeForce RTX 3090
  • GPU memory size: 24.0 GB
  • Number of GPUs: 1
  • Library versions (if applicable):
    • Python: 3.13.2
    • ModelOpt version or commit hash: be64f6b
    • CUDA: 11.8
    • PyTorch: 2.8.0+cu128
    • Transformers: 4.52.4
    • TensorRT-LLM: ?
    • ONNXRuntime: 1.21.0
    • TensorRT: ?
  • Any other details that may help: ?
