Commit c293d1f (parent 98d4521)

Added Segmentation / Export End2End SegModels #16

5 files changed (+1110, -17 lines)

README.md (+44 −2)
```diff
@@ -1,4 +1,4 @@
-# YOLOv9 QAT for TensorRT
+# YOLOv9 QAT for TensorRT Detection / Segmentation
 
 This repository contains an implementation of YOLOv9 with Quantization-Aware Training (QAT), specifically designed for deployment on platforms utilizing TensorRT for hardware-accelerated inference. <br>
 This implementation aims to provide an efficient, low-latency version of YOLOv9 for real-time detection applications.<br>
```
```diff
@@ -15,7 +15,8 @@ We use [TensorRT's pytorch quantization tool](https://github.com/NVIDIA/TensorRT/
 For those who are not familiar with QAT, I highly recommend watching this video:<br> [Quantization explained with PyTorch - Post-Training Quantization, Quantization-Aware Training](https://www.youtube.com/watch?v=0VdNflU08yA)
 
 **Important**<br>
-Currently, quantization is only available for object detection models. However, since quantization primarily affects the backbone of the YOLOv9 model and the backbone remains consistent across all YOLOv9 variants, quantization is effectively prepared for all YOLOv9-based models, regardless of whether they are used for detection or segmentation tasks. Quantization support for segmentation models has not yet been released, as it necessitates the development of evaluation criteria and the validation of quantization for the final layers of the model. <br>
+Evaluation of the segmentation model using TensorRT is currently under development. Once I have more time available, I will complete and release this work.
+
 🌟 We still have plenty of nodes to improve Q/DQ, and we rely on the community's contribution to enhance this project, benefiting us all. Let's collaborate and make it even better! 🚀
 
 ## Release Highlights
```
```diff
@@ -35,6 +36,7 @@
 
 ### Evaluation Results
 
+## Detection
 #### Activation SiLU
 
 | Eval Model | AP | AP50 | Precision | Recall |
```
```diff
@@ -66,6 +68,14 @@
 | **INT8 (TensorRT)** vs **Origin (Pytorch)** | | | | |
 | | -0.002 | -0.005 | +0.004 | -0.003 |
 
+## Segmentation
+| Model  | Box P | Box R | Box mAP50 | Box mAP50-95 | Mask P | Mask R | Mask mAP50 | Mask mAP50-95 |
+|--------|-------|-------|-----------|--------------|--------|--------|------------|---------------|
+| Origin | 0.729 | 0.632 | 0.691     | 0.521        | 0.717  | 0.611  | 0.657      | 0.423         |
+| PTQ    | 0.729 | 0.626 | 0.688     | 0.520        | 0.717  | 0.604  | 0.654      | 0.421         |
+| QAT    | 0.725 | 0.631 | 0.689     | 0.521        | 0.714  | 0.609  | 0.655      | 0.421         |
+
 
 ## Latency/Throughput Report - TensorRT
```

````diff
@@ -530,3 +540,35 @@ D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%)
 Total Host Walltime: 10.0286 s
 Total GPU Compute Time: 10.0269 s
 ```
+
+
+# Segmentation
+
+## FP16
+### Batch Size 8
+
+```bash
+=== Performance summary ===
+Throughput: 124.055 qps
+Latency: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms
+Enqueue Time: min = 0.00219727 ms, max = 0.0200653 ms, mean = 0.00271174 ms, median = 0.00256348 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00317383 ms, percentile(99%) = 0.00466919 ms
+H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
+GPU Compute Time: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms
+D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
+Total Host Walltime: 3.01478 s
+Total GPU Compute Time: 3.01415 s
+```
+
+## INT8 / FP16
+### Batch Size 8
+```bash
+=== Performance summary ===
+Throughput: 223.63 qps
+Latency: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms
+Enqueue Time: min = 0.00219727 ms, max = 0.00854492 ms, mean = 0.00258152 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00305176 ms, percentile(99%) = 0.00439453 ms
+H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
+GPU Compute Time: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms
+D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
+Total Host Walltime: 3.00944 s
+Total GPU Compute Time: 3.00836 s
+```
````
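The summaries above are trtexec performance reports. A representative invocation for building and timing the INT8/FP16 segmentation engine might look like the sketch below; the file names and the 8x3x640x640 input shape are assumptions, not part of the commit:

```bash
# Hypothetical file names and input shape; adjust to your export.
# A QAT ONNX graph already carries Q/DQ nodes, so --int8 --fp16 yields the
# mixed-precision engine. Per the exported metadata, the segmentation graph
# additionally needs the EfficientNMSX_TRT and ROIAlign plugins to be
# visible to TensorRT at build time.
trtexec --onnx=yolov9-c-seg-end2end.onnx \
        --int8 --fp16 \
        --shapes=images:8x3x640x640 \
        --saveEngine=yolov9-c-seg-int8.engine
```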

export_qat.py (+56 −15)
```diff
@@ -26,7 +26,8 @@
 if platform.system() != 'Windows':
     ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 
-from models.experimental import attempt_load, End2End
+from models.experimental_trt import End2End_TRT
+from models.experimental import attempt_load
 from models.yolo import ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel
 from utils.dataloaders import LoadImages
 from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
```
```diff
@@ -175,12 +176,22 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
         remove_redundant_qdq_model(model_onnx, f)
         model_onnx = onnx.load(f)
     return f, model_onnx
-
+
+
 
 @try_export
-def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
-    if not isinstance(model, DetectionModel) or isinstance(model, SegmentationModel):
+def export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, labels, mask_resolution, pooler_scale, sampling_ratio, prefix=colorstr('ONNX END2END:')):
+    if not isinstance(model, (DetectionModel, SegmentationModel)):
         raise RuntimeError("Model not supported. Only Detection Models can be exported with End2End functionality.")
+
+    is_det_model = True
+    if isinstance(model, SegmentationModel):
+        is_det_model = False
+
+    env_is_det_model = os.getenv("MODEL_DET")
+    if env_is_det_model == "0":
+        is_det_model = False
+
     # YOLO ONNX export
     check_requirements('onnx')
     import onnx
```
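The export path is chosen automatically from the model class, with MODEL_DET acting as a manual override: setting it to "0" forces the segmentation-style End2End export even when the checkpoint is recognized as a plain DetectionModel. A minimal sketch of the override from the shell (the weights file name is an assumption):

```bash
# Hypothetical invocation; the weights file name is an assumption.
# MODEL_DET=0 forces is_det_model=False, i.e. the segmentation export branch.
MODEL_DET=0 python export_qat.py --weights yolov9-c.pt --include onnx_end2end
```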
```diff
@@ -195,6 +206,14 @@ def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
     LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
     f = os.path.splitext(file)[0] + "-end2end.onnx"
     batch_size = 'batch'
+    d = {
+        'stride': int(max(model.stride)),
+        'names': model.names,
+        'model type': 'Detection' if is_det_model else 'Segmentation',
+        'TRT Compatibility': '8.6 or above' if class_agnostic else '8.5 or above',
+        'TRT Plugins': 'EfficientNMS_TRT' if is_det_model else 'EfficientNMSX_TRT, ROIAlign'
+    }
+
 
     dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, }  # variable length axes
 
```
```diff
@@ -204,30 +223,40 @@ def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
         'det_scores': {0: 'batch'},
         'det_classes': {0: 'batch'},
     }
-    dynamic_axes.update(output_axes)
-    model = End2End(model, topk_all, iou_thres, conf_thres, None, device, labels)
+    if is_det_model:
+        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
+        shapes = [batch_size, 1,
+                  batch_size, topk_all, 4,
+                  batch_size, topk_all,
+                  batch_size, topk_all]
+
+    else:
+        output_axes['det_masks'] = {0: 'batch'}
+        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes', 'det_masks']
+        shapes = [batch_size, 1,
+                  batch_size, topk_all, 4,
+                  batch_size, topk_all,
+                  batch_size, topk_all,
+                  batch_size, topk_all, mask_resolution * mask_resolution]
 
-    output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
-    shapes = [batch_size, 1, batch_size, topk_all, 4,
-              batch_size, topk_all, batch_size, topk_all]
-
+    dynamic_axes.update(output_axes)
+    model = End2End_TRT(model, class_agnostic, topk_all, iou_thres, conf_thres, mask_resolution, pooler_scale, sampling_ratio, None, device, labels, is_det_model)
 
-
+
     if is_model_qat:
         warnings.filterwarnings("ignore")
         LOGGER.info(f'{prefix} Model QAT Detected ...')
         quant_nn.TensorQuantizer.use_fb_fake_quant = True
         model.eval()
         quantize.initialize()
-        quantize.replace_custom_module_forward(model)
 
         with torch.no_grad():
             torch.onnx.export(model,
                               im,
                               f,
                               verbose=False,
                               export_params=True,  # store the trained parameter weights inside the model file
-                              opset_version=13,
+                              opset_version=16,
                               do_constant_folding=True,  # whether to execute constant folding for optimization
                               input_names=['images'],
                               output_names=output_names,
```
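For the segmentation branch the exported graph therefore has five outputs instead of four. A small sketch of the tensor shapes a consumer should expect, using illustrative values (B, topk_all and mask_resolution below are examples, not fixed by the commit):

```python
# Illustrative output shapes of the segmentation end-to-end ONNX graph.
# B, topk_all and mask_resolution are example values.
B, topk_all, mask_resolution = 8, 100, 56

expected_shapes = {
    'num_dets':    (B, 1),                               # valid detections per image
    'det_boxes':   (B, topk_all, 4),                     # boxes for the kept detections
    'det_scores':  (B, topk_all),                        # confidence per detection
    'det_classes': (B, topk_all),                        # class index per detection
    'det_masks':   (B, topk_all, mask_resolution ** 2),  # flattened per-detection masks
}
```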
```diff
@@ -239,7 +268,7 @@ def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
                           f,
                           verbose=False,
                           export_params=True,  # store the trained parameter weights inside the model file
-                          opset_version=12,
+                          opset_version=16,
                           do_constant_folding=True,  # whether to execute constant folding for optimization
                           input_names=['images'],
                           output_names=output_names,
```
```diff
@@ -248,6 +277,10 @@ def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
     # Checks
     model_onnx = onnx.load(f)  # load onnx model
     onnx.checker.check_model(model_onnx)  # check onnx model
+    for k, v in d.items():
+        meta = model_onnx.metadata_props.add()
+        meta.key, meta.value = k, str(v)
+
     for i in model_onnx.graph.output:
         for j in i.type.tensor_type.shape.dim:
             j.dim_param = str(shapes.pop(0))
```
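Because the dictionary d is written into the ONNX file's metadata_props in the loop above, downstream tooling can discover the model type and the required TensorRT plugins without running the model. A minimal sketch of reading the metadata back (the file name is an assumption):

```python
import onnx

# Hypothetical file name; any model produced by export_onnx_end2end works.
model_onnx = onnx.load('yolov9-c-seg-end2end.onnx')

# metadata_props is a repeated key/value field (StringStringEntryProto).
meta = {p.key: p.value for p in model_onnx.metadata_props}
print(meta.get('model type'))   # 'Detection' or 'Segmentation'
print(meta.get('TRT Plugins'))  # e.g. 'EfficientNMSX_TRT, ROIAlign'
```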
```diff
@@ -586,6 +619,7 @@ def run(
         batch_size=1,  # batch size
         device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
         include=('torchscript', 'onnx'),  # include formats
+        class_agnostic=False,
         half=False,  # FP16 half-precision export
         inplace=False,  # set YOLO Detect() inplace=True
         keras=False,  # use Keras
```
```diff
@@ -602,6 +636,9 @@ def run(
         topk_all=100,  # TF.js NMS: topk for all classes to keep
         iou_thres=0.45,  # TF.js NMS: IoU threshold
         conf_thres=0.25,  # TF.js NMS: confidence threshold
+        mask_resolution=56,
+        pooler_scale=0.25,
+        sampling_ratio=0,
 ):
     t = time.time()
     include = [x.lower() for x in include]  # to lowercase
```
```diff
@@ -655,7 +692,7 @@ def run(
         f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
     if onnx_end2end:
         labels = model.names
-        f[2], _ = export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, len(labels))
+        f[2], _ = export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, len(labels), mask_resolution, pooler_scale, sampling_ratio)
     if xml:  # OpenVINO
         f[3], _ = export_openvino(file, metadata, half)
     if coreml:  # CoreML
```
```diff
@@ -731,6 +768,10 @@ def parse_opt():
     parser.add_argument('--topk-all', type=int, default=100, help='ONNX END2END/TF.js NMS: topk for all classes to keep')
     parser.add_argument('--iou-thres', type=float, default=0.45, help='ONNX END2END/TF.js NMS: IoU threshold')
     parser.add_argument('--conf-thres', type=float, default=0.25, help='ONNX END2END/TF.js NMS: confidence threshold')
+    parser.add_argument('--class-agnostic', action='store_true', help='ONNX END2END: class-agnostic NMS (single class)')
+    parser.add_argument('--mask-resolution', type=int, default=160, help='ONNX END2END: resolution of the pooled output masks')
+    parser.add_argument('--pooler-scale', type=float, default=0.25, help='ONNX END2END: multiplicative factor used to translate the ROI coordinates')
+    parser.add_argument('--sampling-ratio', type=int, default=0, help='ONNX END2END: number of sampling points in the interpolation; non-negative integers allowed')
     parser.add_argument(
         '--include',
         nargs='+',
```
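Putting the new options together, a segmentation End2End export could be invoked as sketched below; the weights file name is an assumption, and the ROIAlign-related values shown are simply the parser defaults above:

```bash
# Hypothetical invocation; the weights file name is an assumption.
# --mask-resolution / --pooler-scale / --sampling-ratio parameterize the
# ROIAlign step that pools per-detection masks; --class-agnostic switches
# NMS to single-class mode (TensorRT 8.6+ per the exported metadata).
python export_qat.py \
    --weights yolov9-c-seg.pt \
    --include onnx_end2end \
    --topk-all 100 --iou-thres 0.45 --conf-thres 0.25 \
    --mask-resolution 160 --pooler-scale 0.25 --sampling-ratio 0
```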
