diff --git a/convertOnnxToTensorRT.py b/convertOnnxToTensorRT.py
index 708f464..3c313ae 100644
--- a/convertOnnxToTensorRT.py
+++ b/convertOnnxToTensorRT.py
@@ -9,8 +9,8 @@
 # import pycuda.autoinit
 # import numpy as np
 # import cv2 as cv2
+# from ObjectDetector.utils import Scaler
-# from ObjectDetector.yoloDetector import YoloDetector
 
 """
 takes in onnx model
 converts to tensorrt
@@ -19,32 +19,44 @@
 parser = argparse.ArgumentParser(description='https://github.com/jason-li-831202/Vehicle-CV-ADAS')
 parser.add_argument('--input_onnx_model', '-i', default="./ObjectDetector/models/yolov8m-coco_fp16.onnx", type=str, help='Onnx model path.')
 parser.add_argument('--output_trt_model', '-o', default="./ObjectDetector/models/yolov8m-coco_fp16.trt", type=str, help='Tensorrt model path.')
+# parser.add_argument("--calib_image_dir", default=None, type=Path, help="Calibration image directory required for int8 conversion; if None, dynamic quantization is used.")
 parser.add_argument('--verbose', action='store_true', default=False, help='TensorRT: verbose log')
 
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]
 
-# class Calibrator(trt.IInt8EntropyCalibrator):
-# 	def __init__(self, quantification=1, batch_size=1, height=640, width=640, calibration_images="", cache_file=""):
-# 		trt.IInt8EntropyCalibrator.__init__(self)
-# 		self.index = 0
-# 		self.length = quantification
+# class Calibrator(trt.IInt8MinMaxCalibrator):
+# 	def __init__(self, batch_size=1, height=640, width=640, calibration_images="", cache_file=""):
+# 		trt.IInt8MinMaxCalibrator.__init__(self)
+# 		self.batch_idx = 0
 # 		self.batch_size = batch_size
 # 		self.cache_file = cache_file
 # 		self.height = height
 # 		self.width = width
+
 # 		self.img_list = [ str(name) for name in Path(calibration_images).iterdir()]
-# 		self.calibration_data = np.zeros((self.batch_size, 3, self.height, self.width), dtype=np.float32)
-# 		self.d_input = drv.mem_alloc(self.calibration_data.nbytes)
+# 		self.max_batch_idx = len(self.img_list) // self.batch_size
+# 		self.data_size = trt.volume([self.batch_size, 3, self.height, self.width]) * trt.float32.itemsize
+# 		self.batch_allocation = drv.mem_alloc(self.data_size)
+# 		self.scaler = Scaler((self.height, self.width), keep_ratio=True)
+# 		print('Found {} images for calibration.'.format(len(self.img_list)))
 
+# 	def preprocess(self, img):
+# 		image = self.scaler.process_image(img)
+# 		# TODO: scale factor is 1/255.0 for yolov5/6/7/8, 1.0 for yolox.
+# 		image = cv2.dnn.blobFromImage(image, 1, (image.shape[1], image.shape[0]), swapRB=True, crop=False).astype(np.float32)
+# 		return image
 
 # 	def next_batch(self):
-# 		if self.index < self.length:
-# 			for i in range(self.batch_size):
-# 				img = cv2.imread(self.img_list[i + self.index*self.batch_size])
+# 		if self.batch_idx < self.max_batch_idx:
+# 			batch_files = self.img_list[self.batch_idx * self.batch_size: (self.batch_idx + 1) * self.batch_size]
+# 			batch_imgs = np.zeros((self.batch_size, 3, self.height, self.width), dtype=np.float32)
+# 			for i, f in enumerate(batch_files):
+# 				img = cv2.imread(f)  # (h, w, c)
 # 				img = self.preprocess(img)
-# 				self.calibration_data[i] = img
-# 			self.index += 1
-# 			return np.ascontiguousarray(self.calibration_data, dtype=np.float32)
+# 				batch_imgs[i] = img
+# 			self.batch_idx += 1
+# 			return batch_imgs
 # 		else:
 # 			return np.array([])
 
@@ -52,29 +64,39 @@
 # 		return self.length
 
 # 	def get_batch_size(self):
+# 		"""
+# 		Overrides from trt.IInt8MinMaxCalibrator.
+# 		Get the batch size to use for calibration.
+# 		:return: Batch size.
+# 		"""
 # 		return self.batch_size
 
 # 	def get_batch(self, name):
 # 		batch = self.next_batch()
+# 		print("Calibrating batch {} / {}".format(self.batch_idx, self.max_batch_idx))
 # 		if not batch.size:
 # 			return None
-# 		drv.memcpy_htod(self.d_input, batch)
-# 		return [int(self.d_input)]
+# 		drv.memcpy_htod(self.batch_allocation, batch)
+# 		return [int(self.batch_allocation)]
 
 # 	def read_calibration_cache(self):
-# 		# If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
+# 		"""
+# 		Overrides from trt.IInt8MinMaxCalibrator.
+# 		Read the calibration cache file stored on disk, if it exists.
+# 		:return: The contents of the cache file, if any.
+# 		"""
 # 		if Path(self.cache_file).exists():
 # 			with open(self.cache_file, "rb") as f:
 # 				return f.read()
 
 # 	def write_calibration_cache(self, cache):
+# 		"""
+# 		Overrides from trt.IInt8MinMaxCalibrator.
+# 		Store the calibration cache to a file on disk.
+# 		:param cache: The contents of the calibration cache to store.
+# 		"""
 # 		with open(self.cache_file, "wb") as f:
 # 			f.write(cache)
-
-# 	def preprocess(self, img):
-# 		image, newh, neww, ratioh, ratiow, padh, padw = YoloDetector.resize_image_format(img, (self.height, self.width), True)
-# 		image = cv2.dnn.blobFromImage(image, 1/255.0, (image.shape[1], image.shape[0]), swapRB=True, crop=False).astype(np.float32)
-# 		return image
 
 class EngineBuilder:
 	"""
@@ -125,14 +147,17 @@ def create_network(self, onnx_model_path : str):
 		for out in outputs:
 			print(self.colorstr('bright_magenta', f'	Output "{out.name}" with shape {out.shape} and dtype {out.dtype}'))
 
-	def create_engine(self, trt_model_path : str):
+	def create_engine(self, trt_model_path : str, calib_image_path: Optional[str] = None):
 		start = time.time()
 		inp = [self.network.get_input(i) for i in range(self.network.num_inputs)][0]
 		print(f'	Note: building FP{16 if (self.builder.platform_has_fast_fp16 and inp.dtype==trt.DataType.HALF) else 32} engine as {Path(trt_model_path).resolve()}')
 		if self.builder.platform_has_fast_fp16 and inp.dtype==trt.DataType.HALF:
 			self.config.set_flag(trt.BuilderFlag.FP16)
-		# self.config.set_flag(trt.BuilderFlag.INT8)
-		# self.config.int8_calibrator = Calibrator(1, 1, inp.shape[2], inp.shape[3], "./demo/val2017")
+		# if calib_image_path is not None and Path(calib_image_path).is_dir():
+		# 	# Also enable fp16, as some layers may be even more efficient in fp16 than int8
+		# 	self.config.set_flag(trt.BuilderFlag.FP16)
+		# 	self.config.set_flag(trt.BuilderFlag.INT8)
+		# 	self.config.int8_calibrator = Calibrator(1, inp.shape[2], inp.shape[3], calib_image_path)
 
 		print(self.colorstr('magenta', "*"*40))
 		print(self.colorstr('👉 Building the TensorRT engine. This would take a while...'))