diff --git a/yolov9/export_onnxtrt.py b/yolov9/export_onnxtrt.py
new file mode 100644
index 0000000..959f709
--- /dev/null
+++ b/yolov9/export_onnxtrt.py
@@ -0,0 +1,246 @@
+import argparse
+import os
+import platform
+import sys
+import time
+from pathlib import Path
+
+import pandas as pd
+import torch
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[0]  # YOLO root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+if platform.system() != 'Windows':
+    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+from models.experimental import attempt_load
+from models.experimental_trt import End2End_TRT
+from models.yolo import (ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, DetectionModel,
+                         SegmentationModel, DSegment)
+from utils.general import (LOGGER, Profile, check_img_size, check_requirements,
+                           colorstr, file_size, get_default_args, print_args, url2file)
+from utils.torch_utils import select_device, smart_inference_mode
+from torch.jit import TracerWarning
+import warnings
+
+warnings.filterwarnings("ignore", category=TracerWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+MACOS = platform.system() == 'Darwin'  # macOS environment
+
+
+def export_formats():
+    # YOLO export formats
+    x = [
+        ['PyTorch', '-', '.pt', True, True],
+        ['ONNX TRT', 'onnx_trt', '-trt.onnx', True, True],
+    ]
+    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
+
+
+def try_export(inner_func):
+    # YOLO export decorator, i.e. @try_export
+    inner_args = get_default_args(inner_func)
+
+    def outer_func(*args, **kwargs):
+        prefix = inner_args['prefix']
+        try:
+            with Profile() as dt:
+                f, model = inner_func(*args, **kwargs)
+            LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
+            return f, model
+        except Exception as e:
+            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
+            return None, None
+
+    return outer_func
+
+
+@try_export
+def export_onnx_trt(model, im, file, class_agnostic, topk_all, iou_thres, conf_thres, device, labels,
+                    mask_resolution, pooler_scale, sampling_ratio, prefix=colorstr('ONNX TRT:')):
+    # Detection models export NMS only; segmentation models additionally export ROIAlign for masks
+    is_det_model = not isinstance(model, SegmentationModel)
+
+    # force segmentation-style export via environment variable
+    if os.getenv("is_det_model") == "0":
+        is_det_model = False
+
+    # YOLO ONNX export
+    check_requirements('onnx')
+    import onnx
+
+    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
+    f = os.path.splitext(file)[0] + "-trt.onnx"
+    batch_size = 'batch'
+    d = {
+        'stride': int(max(model.stride)),
+        'names': model.names,
+        'model type': 'Detection' if is_det_model else 'Segmentation',
+        'TRT Compatibility': '8.6 or above',
+        'TRT Plugins': 'YoloNMS' if is_det_model else 'YoloNMS, ROIAlign'
+    }
+
+    dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # variable length axes
+
+    output_axes = {
+        'num_dets': {0: 'batch'},
+        'det_boxes': {0: 'batch'},
+        'det_scores': {0: 'batch'},
+        'det_classes': {0: 'batch'},
+    }
+
+    if is_det_model:
+        output_axes['det_indices'] = {0: 'batch'}
+        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes', 'det_indices']
+        shapes = [batch_size, 1,
+                  batch_size, topk_all, 4,
+                  batch_size, topk_all,
+                  batch_size, topk_all,
+                  batch_size, topk_all]
+    else:
+        output_axes['det_masks'] = {0: 'batch'}
+        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes', 'det_masks']
+        shapes = [batch_size, 1,
+                  batch_size, topk_all, 4,
+                  batch_size, topk_all,
+                  batch_size, topk_all,
+                  batch_size, topk_all, mask_resolution * mask_resolution]
+
+    dynamic_axes.update(output_axes)
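+    # For illustration (comment only, not part of the graph): with topk_all=100 the
+    # first four outputs resolve to
+    #   num_dets ('batch', 1), det_boxes ('batch', 100, 4),
+    #   det_scores ('batch', 100), det_classes ('batch', 100),
+    # and the fifth is det_indices ('batch', 100) for detection models or
+    # det_masks ('batch', 100, mask_resolution * mask_resolution) for segmentation.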
+
+    model = End2End_TRT(model, class_agnostic, topk_all, iou_thres, conf_thres, mask_resolution,
+                        pooler_scale, sampling_ratio, None, device, labels, is_det_model)
+
+    torch.onnx.export(model,
+                      im,
+                      f,
+                      verbose=False,
+                      export_params=True,  # store the trained parameter weights inside the model file
+                      opset_version=14,
+                      do_constant_folding=True,  # whether to execute constant folding for optimization
+                      input_names=['images'],
+                      output_names=output_names,
+                      dynamic_axes=dynamic_axes)
+
+    # Checks
+    model_onnx = onnx.load(f)  # load onnx model
+    onnx.checker.check_model(model_onnx)  # check onnx model
+
+    # Embed export metadata
+    for k, v in d.items():
+        meta = model_onnx.metadata_props.add()
+        meta.key, meta.value = k, str(v)
+
+    # Name the output dimensions so downstream tools see the intended shapes
+    for i in model_onnx.graph.output:
+        for j in i.type.tensor_type.shape.dim:
+            j.dim_param = str(shapes.pop(0))
+
+    check_requirements('onnxsim')
+    try:
+        import onnxsim
+        LOGGER.info(f'\n{prefix} Starting to simplify ONNX...')
+        model_onnx, check = onnxsim.simplify(model_onnx)
+        assert check, 'simplified ONNX model could not be validated'
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} Simplifier failure: {e}')
+
+    onnx.save(model_onnx, f)
+
+    check_requirements('onnx_graphsurgeon')
+    LOGGER.info(f'\n{prefix} Starting to cleanup ONNX using onnx_graphsurgeon...')
+    try:
+        import onnx_graphsurgeon as gs
+
+        graph = gs.import_onnx(model_onnx)
+        graph = graph.cleanup().toposort()
+        model_onnx = gs.export_onnx(graph)
+        onnx.save(model_onnx, f)  # persist the cleaned graph
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} Cleanup failure: {e}')
+
+    return f, model_onnx
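+
+# The metadata embedded by export_onnx_trt() can be read back with plain onnx,
+# e.g. (a sketch; the path is illustrative):
+#   m = onnx.load('yolov9-c-trt.onnx')
+#   meta = {p.key: p.value for p in m.metadata_props}
+#   print(meta['TRT Plugins'])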
+
+
+@smart_inference_mode()
+def run(
+        weights=ROOT / 'yolo.pt',  # weights path
+        imgsz=(640, 640),  # image (height, width)
+        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        include=('onnx_trt',),  # include formats
+        class_agnostic=False,  # NMS: class-agnostic
+        topk_all=100,  # NMS: topk for all classes to keep
+        iou_thres=0.45,  # NMS: IoU threshold
+        conf_thres=0.25,  # NMS: confidence threshold
+        mask_resolution=56,  # ROIAlign: output mask resolution
+        pooler_scale=0.25,  # ROIAlign: spatial scale
+        sampling_ratio=0,  # ROIAlign: sampling ratio
+):
+    t = time.time()
+    include = [x.lower() for x in include]  # to lowercase
+    fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
+    flags = [x in include for x in fmts]
+    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
+    onnx_trt, = flags  # export booleans
+    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights
+
+    # Load PyTorch model
+    device = select_device(device)
+    model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model
+
+    # Checks
+    imgsz *= 2 if len(imgsz) == 1 else 1  # expand
+
+    # Input
+    gs = int(max(model.stride))  # grid size (max stride)
+    imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
+    im = torch.zeros(1, 3, *imgsz).to(device)  # input image placeholder, BCHW
+
+    # Update model
+    model.eval()
+    for k, m in model.named_modules():
+        if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
+            m.inplace = True
+            m.dynamic = True
+            m.export = True
+
+    for _ in range(2):
+        y = model(im)  # dry runs
+
+    shape = tuple((y[0] if isinstance(y, (tuple, list)) else y).shape)  # model output shape
+    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
+
+    # Exports
+    f = [''] * len(fmts)  # exported filenames
+    if onnx_trt:
+        labels = model.names
+        f[0], _ = export_onnx_trt(model, im, file, class_agnostic, topk_all, iou_thres, conf_thres, device,
+                                  len(labels), mask_resolution, pooler_scale, sampling_ratio)
+
+    # Finish
+    f = [str(x) for x in f if x]
+    LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
+                f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+                f"\nVisualize: https://netron.app")
+    return f  # return list of exported files/dirs
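+
+# run() can also be invoked programmatically, e.g. (a sketch; the weights path is
+# illustrative):
+#   run(weights='yolov9-c.pt', imgsz=(640, 640), device='0', topk_all=100)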
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model.pt path(s)')
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
+    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--class-agnostic', action='store_true', help='NMS: class-agnostic')
+    parser.add_argument('--topk-all', type=int, default=100, help='NMS: topk for all classes to keep')
+    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS: IoU threshold')
+    parser.add_argument('--conf-thres', type=float, default=0.25, help='NMS: confidence threshold')
+    parser.add_argument('--mask-resolution', type=int, default=160, help='ROIAlign: output mask resolution')
+    parser.add_argument('--pooler-scale', type=float, default=0.25, help='ROIAlign: spatial scale')
+    parser.add_argument('--sampling-ratio', type=int, default=0, help='ROIAlign: sampling ratio')
+    parser.add_argument('--include', nargs='+', default=['onnx_trt'], help='formats to include, i.e. onnx_trt')
+
+    opt = parser.parse_args()
+
+    print_args(vars(opt))
+    return opt
+
+
+def main(opt):
+    for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
+        run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)
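+
+# Usage (a sketch; paths are illustrative):
+#   python export_onnxtrt.py --weights yolov9-c.pt --imgsz 640 640 --device 0
+# Building an engine from the exported *-trt.onnx requires TensorRT 8.6+ with the
+# plugins listed in the model metadata available, e.g. with trtexec:
+#   trtexec --onnx=yolov9-c-trt.onnx --fp16 --saveEngine=yolov9-c.engine
+# (plus --plugins=<plugin.so> if the YoloNMS plugin library is not already loaded).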
diff --git a/yolov9/models/experimental_trt.py b/yolov9/models/experimental_trt.py
new file mode 100644
index 0000000..f437714
--- /dev/null
+++ b/yolov9/models/experimental_trt.py
@@ -0,0 +1,240 @@
+import torch
+import torch.nn as nn
+
+
+class TRT_YOLO_NMS(torch.autograd.Function):
+    '''TensorRT NMS operation'''
+
+    @staticmethod
+    def forward(
+        ctx,
+        boxes,
+        scores,
+        background_class=-1,
+        box_coding=1,
+        iou_threshold=0.45,
+        max_output_boxes=100,
+        plugin_version="1",
+        score_activation=0,
+        score_threshold=0.25,
+        class_agnostic=0,
+    ):
+        # Dummy outputs with the plugin's shapes and dtypes; only used while tracing
+        batch_size, num_boxes, num_classes = scores.shape
+        num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
+        det_boxes = torch.randn(batch_size, max_output_boxes, 4)
+        det_scores = torch.randn(batch_size, max_output_boxes)
+        det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
+        det_indices = torch.randint(0, num_boxes, (batch_size, max_output_boxes), dtype=torch.int32)
+        return num_det, det_boxes, det_scores, det_classes, det_indices
+
+    @staticmethod
+    def symbolic(g,
+                 boxes,
+                 scores,
+                 background_class=-1,
+                 box_coding=1,
+                 iou_threshold=0.45,
+                 max_output_boxes=100,
+                 plugin_version="1",
+                 score_activation=0,
+                 score_threshold=0.25,
+                 class_agnostic=0):
+        # Emit the TensorRT YoloNMS plugin node with its five outputs
+        out = g.op("TRT::YOLO_NMS_TRT",
+                   boxes,
+                   scores,
+                   background_class_i=background_class,
+                   box_coding_i=box_coding,
+                   iou_threshold_f=iou_threshold,
+                   max_output_boxes_i=max_output_boxes,
+                   plugin_version_s=plugin_version,
+                   score_activation_i=score_activation,
+                   class_agnostic_i=class_agnostic,
+                   score_threshold_f=score_threshold,
+                   outputs=5)
+        nums, boxes, scores, classes, det_indices = out
+        return nums, boxes, scores, classes, det_indices
+
+
+class TRT_ROIAlign(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        X,
+        rois,
+        batch_indices,
+        coordinate_transformation_mode=1,
+        mode=1,  # 1 - avg pooling / 0 - max pooling
+        output_height=160,
+        output_width=160,
+        sampling_ratio=0,
+        spatial_scale=0.25,
+    ):
+        # Dummy output with the plugin's shape and dtype; only used while tracing
+        device = rois.device
+        dtype = rois.dtype
+        N, C, H, W = X.shape
+        num_rois = rois.shape[0]
+        return torch.randn((num_rois, C, output_height, output_width), device=device, dtype=dtype)
+
+    @staticmethod
+    def symbolic(
+        g,
+        X,
+        rois,
+        batch_indices,
+        coordinate_transformation_mode=1,
+        mode=1,
+        output_height=160,
+        output_width=160,
+        sampling_ratio=0,
+        spatial_scale=0.25,
+    ):
+        return g.op(
+            "TRT::ROIAlign_TRT",
+            X,
+            rois,
+            batch_indices,
+            coordinate_transformation_mode_i=coordinate_transformation_mode,
+            mode_i=mode,
+            output_height_i=output_height,
+            output_width_i=output_width,
+            sampling_ratio_i=sampling_ratio,
+            spatial_scale_f=spatial_scale,
+        )
+
+
+class ONNX_YOLO_TRT(nn.Module):
+    '''onnx module with TensorRT NMS operation.'''
+
+    def __init__(self, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None,
+                 device=None, n_classes=80):
+        super().__init__()
+        assert max_wh is None
+        self.device = device if device else torch.device('cpu')
+        self.class_agnostic = 1 if class_agnostic else 0
+        self.background_class = -1
+        self.box_coding = 1
+        self.iou_threshold = iou_thres
+        self.max_obj = max_obj
+        self.plugin_version = '1'
+        self.score_activation = 0
+        self.score_threshold = score_thres
+        self.n_classes = n_classes
+
+    def forward(self, x):
+        if isinstance(x, list):  # remove auxiliary branch
+            x = x[1]
+        x = x.permute(0, 2, 1)
+        bboxes_x = x[..., 0:1]
+        bboxes_y = x[..., 1:2]
+        bboxes_w = x[..., 2:3]
+        bboxes_h = x[..., 3:4]
+        bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim=-1)
+        bboxes = bboxes.unsqueeze(2)  # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
+        scores = x[..., 4:]  # YOLOv9 heads have no objectness; columns 4: are class scores
+        num_det, det_boxes, det_scores, det_classes, det_indices = TRT_YOLO_NMS.apply(
+            bboxes, scores, self.background_class, self.box_coding,
+            self.iou_threshold, self.max_obj,
+            self.plugin_version, self.score_activation,
+            self.score_threshold, self.class_agnostic)
+        return num_det, det_boxes, det_scores, det_classes, det_indices
+
+
+class End2End_TRT(nn.Module):
+    '''export onnx or tensorrt model with NMS operation.'''
+
+    def __init__(self, model, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25,
+                 mask_resolution=56, pooler_scale=0.25, sampling_ratio=0, max_wh=None, device=None,
+                 n_classes=80, is_det_model=True):
+        super().__init__()
+        device = device if device else torch.device('cpu')
+        assert isinstance(max_wh, int) or max_wh is None
+        self.model = model.to(device)
+        self.model.model[-1].end2end = True
+        if is_det_model:
+            self.patch_model = ONNX_YOLO_TRT
+            self.end2end = self.patch_model(class_agnostic, max_obj, iou_thres, score_thres, max_wh, device, n_classes)
+        else:
+            self.patch_model = ONNX_YOLO_MASK_TRT
+            self.end2end = self.patch_model(class_agnostic, max_obj, iou_thres, score_thres, mask_resolution,
+                                            pooler_scale, sampling_ratio, max_wh, device, n_classes)
+        self.end2end.eval()
+
+    def forward(self, x):
+        x = self.model(x)
+        x = self.end2end(x)
+        return x
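+
+# Minimal usage sketch (comment only; `m` is assumed to be a loaded DetectionModel
+# in eval mode). The wrapper is meant to be traced by torch.onnx.export; run
+# eagerly, the NMS outputs are just the tracing placeholders above:
+#   wrapper = End2End_TRT(m, max_obj=100, iou_thres=0.45, score_thres=0.25,
+#                         device=torch.device('cpu'), n_classes=80)
+#   torch.onnx.export(wrapper, torch.zeros(1, 3, 640, 640), 'model-trt.onnx', opset_version=14)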
+
+
+class ONNX_YOLO_MASK_TRT(nn.Module):
+    """onnx module with ONNX-TensorRT NMS/ROIAlign operation."""
+
+    def __init__(
+        self,
+        class_agnostic=False,
+        max_obj=100,
+        iou_thres=0.45,
+        score_thres=0.25,
+        mask_resolution=160,
+        pooler_scale=0.25,
+        sampling_ratio=0,
+        max_wh=None,
+        device=None,
+        n_classes=80
+    ):
+        super().__init__()
+        assert isinstance(max_wh, int) or max_wh is None
+        self.device = device if device else torch.device('cpu')
+        self.class_agnostic = 1 if class_agnostic else 0
+        self.background_class = -1
+        self.box_coding = 1
+        self.iou_threshold = iou_thres
+        self.max_obj = max_obj
+        self.plugin_version = '1'
+        self.score_activation = 0
+        self.score_threshold = score_thres
+        self.n_classes = n_classes
+        self.mask_resolution = mask_resolution
+        self.pooler_scale = pooler_scale
+        self.sampling_ratio = sampling_ratio
+
+    def forward(self, x):
+        if isinstance(x, list):  # remove auxiliary branch
+            x = x[1]
+        det = x[0]
+        proto = x[1]
+        det = det.permute(0, 2, 1)
+
+        bboxes_x = det[..., 0:1]
+        bboxes_y = det[..., 1:2]
+        bboxes_w = det[..., 2:3]
+        bboxes_h = det[..., 3:4]
+        bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim=-1)
+        bboxes = bboxes.unsqueeze(2)  # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
+        scores = det[..., 4:4 + self.n_classes]
+
+        batch_size, nm, proto_h, proto_w = proto.shape
+        total_object = batch_size * self.max_obj
+        masks = det[..., 4 + self.n_classes:4 + self.n_classes + nm]  # per-box mask coefficients
+        num_det, det_boxes, det_scores, det_classes, det_indices = TRT_YOLO_NMS.apply(
+            bboxes, scores, self.background_class, self.box_coding,
+            self.iou_threshold, self.max_obj,
+            self.plugin_version, self.score_activation,
+            self.score_threshold, self.class_agnostic)
+
+        # Gather the mask coefficients of the kept detections
+        batch_indices = torch.ones_like(det_indices) * torch.arange(batch_size, device=self.device, dtype=torch.int32).unsqueeze(1)
+        batch_indices = batch_indices.view(total_object).to(torch.long)
+        det_indices = det_indices.view(total_object).to(torch.long)
+        det_masks = masks[batch_indices, det_indices]
+
+        # Crop the prototype masks to each kept box via the TensorRT ROIAlign plugin
+        pooled_proto = TRT_ROIAlign.apply(proto,
+                                          det_boxes.view(total_object, 4),
+                                          batch_indices,
+                                          1,
+                                          1,
+                                          self.mask_resolution,
+                                          self.mask_resolution,
+                                          self.sampling_ratio,
+                                          self.pooler_scale)
+        pooled_proto = pooled_proto.view(total_object, nm, self.mask_resolution * self.mask_resolution)
+
+        # Linear combination of prototypes with coefficients, then sigmoid
+        det_masks = (
+            torch.matmul(det_masks.unsqueeze(dim=1), pooled_proto)
+            .sigmoid()
+            .view(batch_size, self.max_obj, self.mask_resolution * self.mask_resolution)
+        )
+
+        return num_det, det_boxes, det_scores, det_classes, det_masks
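+
+# Note (a sketch; names are illustrative): at inference time the flattened
+# det_masks output can be reshaped to one binary mask per kept detection, since
+# sigmoid has already been applied:
+#   masks = det_masks.reshape(batch, max_obj, mask_resolution, mask_resolution) > 0.5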