Describe the bug
I need to use ModelOpt's FP8 quantization to perform per-channel weight quantization and then export the quantized model to ONNX. Quantization and calibration run, but the ONNX export fails with the error below.
Error message:
torch.onnx.errors.SymbolicValueError: Expected node type 'onnx::Constant' for argument 'amax' of node 'symbolic', got 'onnx::Max'. [Caused by the value '12 defined in (%12 : Float(8, 1, strides=[1, 1], requires_grad=0, device=cuda:0) = onnx::Max(%10, %11), scope: __main__.TinyLinear::/modelopt.torch.opt.dynamic.QuantLinear::linear/modelopt.torch.quantization.nn.modules.tensor_quantizer.TensorQuantizer::input_quantizer # /home/yuanxinwei/quant-tool/TensorRT-Model-Optimizer/modelopt/torch/quantization/utils.py:184:0 )' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Max'.] (node defined in /home/yuanxinwei/quant-tool/TensorRT-Model-Optimizer/modelopt/torch/quantization/utils.py(184): reduce_amax
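My reading of the error (I may be misreading the internals) is that with this config the input_quantizer's amax is computed from the activations at trace time, so in the TorchScript graph it is produced by an onnx::Max node instead of an onnx::Constant, and the FP8 symbolic function rejects a non-constant amax. The standalone sketch below does not use ModelOpt at all (the module name is purely illustrative); it only shows how an amax computed inside forward() ends up as graph ops rather than a Constant when exported:

import torch
import torch.nn as nn


class DynamicAmax(nn.Module):
    """Illustrative module only: computes a per-row amax at runtime,
    similar in spirit to per-token dynamic quantization."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Traced as reduction nodes in the ONNX graph, not as a Constant.
        amax = x.abs().amax(dim=-1, keepdim=True)
        return x / amax


torch.onnx.export(
    DynamicAmax(),
    (torch.randn(8, 32),),
    "dynamic_amax.onnx",
    opset_version=19,
    dynamo=False,
)
# Inspecting dynamic_amax.onnx shows the amax coming out of graph ops, which is
# the kind of value the error above complains about for the 'amax' argument.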
Steps/Code to reproduce bug
Here is a minimal reproduction:
import copy
import sys
import tempfile
import unittest
from contextlib import nullcontext
from pathlib import Path

import torch
import torch.nn as nn

PROJECT_ROOT = Path(__file__).resolve().parents[2]
MODEL_OPT_PATH = PROJECT_ROOT / "TensorRT-Model-Optimizer"
if MODEL_OPT_PATH.exists():
    sys.path.insert(0, str(MODEL_OPT_PATH))

try:
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode

    MODEL_OPT_AVAILABLE = True
except Exception:  # pragma: no cover - optional dependency
    MODEL_OPT_AVAILABLE = False
    mtq = None  # type: ignore
    export_torch_mode = nullcontext  # type: ignore


class TinyLinear(nn.Module):
    """Toy network with a single Linear block quantized to FP8."""

    def __init__(self, in_features: int = 32, out_features: int = 16):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def build_calib_data(batch_size: int = 8, num_batches: int = 16, in_features: int = 32):
    torch.manual_seed(0)
    return [torch.randn(batch_size, in_features) for _ in range(num_batches)]


def _forward_loop_factory(dataset, device):
    def _forward_loop(model):
        model.eval()
        with torch.no_grad():
            for batch in dataset:
                model(batch.to(device))

    return _forward_loop


class TestStandaloneModelOptFP8(unittest.TestCase):
    @unittest.skipUnless(MODEL_OPT_AVAILABLE, "ModelOpt backend not available")
    def test_fp8_linear_without_smk(self):
        batch_size = 8
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = TinyLinear().to(device).eval()
        calib_data = build_calib_data(batch_size=batch_size)

        fp8_cfg = copy.deepcopy(mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG)
        model = mtq.quantize(
            model,
            fp8_cfg,
            forward_loop=_forward_loop_factory(calib_data, device),
        )

        sample_input = calib_data[0].to(device)
        with torch.no_grad():
            output = model(sample_input)
        self.assertEqual(tuple(output.shape), (batch_size, 16))

        with tempfile.TemporaryDirectory() as tmpdir:
            onnx_path = Path(tmpdir) / "fp8_linear.onnx"
            with export_torch_mode():
                torch.onnx.export(
                    model,
                    (sample_input,),
                    onnx_path,
                    opset_version=19,
                    input_names=["input"],
                    output_names=["output"],
                    keep_initializers_as_inputs=False,
                    dynamo=False,
                )
            self.assertTrue(onnx_path.exists())


if __name__ == "__main__":  # pragma: no cover - direct execution helper
    unittest.main()
Expected behavior
The FP8 per-channel quantized model can be exported to ONNX without errors.
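For reference, a sketch of how the export could be sanity-checked once it succeeds (not part of the failing repro: it reuses onnx_path, sample_input, and output from the test above, and assumes ONNX Runtime can execute the FP8 Q/DQ nodes with the chosen execution provider):

import numpy as np
import onnxruntime as ort

# Post-export check; onnx_path, sample_input, and output refer to the variables
# in the test case above. Tolerances are loose because of quantization error.
sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
(onnx_out,) = sess.run(["output"], {"input": sample_input.cpu().numpy()})
np.testing.assert_allclose(onnx_out, output.cpu().numpy(), rtol=1e-1, atol=1e-1)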
System information
- Container used (if applicable): ?
- OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 20.04.6 LTS
- CPU architecture (x86_64, aarch64): x86_64
- GPU name (e.g. H100, A100, L40S): NVIDIA GeForce RTX 3090
- GPU memory size: 24.0 GB
- Number of GPUs: 1
- Library versions (if applicable):
  - Python: 3.13.2
  - ModelOpt version or commit hash: be64f6b
  - CUDA: 11.8
  - PyTorch: 2.8.0+cu128
  - Transformers: 4.52.4
  - TensorRT-LLM: ?
  - ONNXRuntime: 1.21.0
  - TensorRT: ?
- Any other details that may help: ?