From a7472f1eb4a40e64f2985793fe9e2e54740d3094 Mon Sep 17 00:00:00 2001 From: SWHL Date: Wed, 8 May 2024 21:55:07 +0800 Subject: [PATCH] Fixed issue #11 --- label_convert/vis_coco.py | 138 +++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 55 deletions(-) diff --git a/label_convert/vis_coco.py b/label_convert/vis_coco.py index f748b37..f7a9ecc 100644 --- a/label_convert/vis_coco.py +++ b/label_convert/vis_coco.py @@ -3,9 +3,9 @@ # @Contact: liekkaskono@163.com import argparse import json -import platform import random from pathlib import Path +from typing import List, Tuple import cv2 import numpy as np @@ -17,30 +17,25 @@ def __init__( ): self.font_size = 0.7 - def __call__(self, img_id: int, json_path, img_path): - with open(json_path, "r", encoding="utf-8") as annos: - anno_dict = json.load(annos) - + def __call__(self, img_id: int, json_path: str, img_path: str): + anno_dict = self.read_json(json_path) anno_imgs = anno_dict.get("images", None) if anno_imgs is None: raise ValueError(f"The images of {json_path} cannot be empty.") - print("The anno_dict num_key is:", len(anno_dict)) - print("The anno_dict key is:", anno_dict.keys()) - print("The anno_dict num_images is:", len(anno_imgs)) + print(f"The anno_dict num_key is: {len(anno_dict)}") + print(f"The anno_dict key is: {anno_dict.keys()}") + print(f"The anno_dict num_images is: {len(anno_imgs)}") categories = anno_dict["categories"] categories_dict = {c["id"]: c["name"] for c in categories} class_nums = len(categories_dict.keys()) - color = [ - (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) - for _ in range(class_nums) - ] + color = self.get_class_color(class_nums) img_info = anno_dict["images"][img_id - 1] - img_name = img_info.get("file_name") + img_name = img_info.get("file_name") img_full_path = Path(img_path) / img_name image = cv2.imread(str(img_full_path)) @@ -57,50 +52,83 @@ def __call__(self, img_id: int, json_path, img_path): class_name = categories_dict[class_id] class_color = color[class_id - 1] - # plot sgmentations segs = anno.get("segmentation", None) if segs is not None: - segs = np.array(segs).reshape(-1, 2) - cv2.polylines(image, np.int32([segs]), 2, class_color) - - # plot rectangle - x, y, w, h = [round(v) for v in anno["bbox"]] - cv2.rectangle( - image, (int(x), int(y)), (int(x + w), int(y + h)), class_color, 2 - ) - - txt_size = cv2.getTextSize( - class_name, cv2.FONT_HERSHEY_SIMPLEX, self.font_size, 1 - )[0] - cv2.rectangle( - image, - (x, y + 1), - (x + txt_size[0] + 5, y - int(1.5 * txt_size[1])), - class_color, - -1, - ) - cv2.putText( - image, - class_name, - (x + 5, y - 5), - cv2.FONT_HERSHEY_SIMPLEX, - self.font_size, - (255, 255, 255), - 1, - ) - - print("The unm_bbox of the display image is:", num_bbox) - - cur_os = platform.system() - if cur_os == "Windows": - cv2.namedWindow(img_name, 0) - cv2.resizeWindow(img_name, 1000, 1000) - cv2.imshow(img_name, image) - cv2.waitKey(0) - else: - save_path = f"vis_{Path(img_name).stem}.jpg" - cv2.imwrite(save_path, image) - print(f"The {save_path} has been saved the current director.") + self.plot_segmentations(image, segs, class_color) + self.plot_text(image, segs[0][:2], class_color, class_name) + + bbox = anno.get("bbox", None) + if bbox is None: + continue + + self.plot_rectangle(image, bbox, class_color) + self.plot_text(image, bbox, class_color, class_name) + + print(f"The unm_bbox of the display image is: {num_bbox}") + save_path = f"vis_{Path(img_name).stem}.jpg" + cv2.imwrite(save_path, image) + print(f"The {save_path} has been saved the current director.") + + @staticmethod + def read_json(json_path): + with open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + return data + + @staticmethod + def get_class_color(class_nums: int) -> List[Tuple[int]]: + def random_color(): + return random.randint(0, 255) + + color = [ + (random_color(), random_color(), random_color()) for _ in range(class_nums) + ] + return color + + @staticmethod + def plot_segmentations( + image: np.ndarray, segs: List[List[float]], class_color: Tuple[int] + ): + segs = np.array(segs).reshape(-1, 2) + cv2.polylines(image, np.int32([segs]), 2, class_color) + + @staticmethod + def plot_rectangle( + image: np.ndarray, + bbox: List[float], + class_color: Tuple[int], + thickness: int = 1, + ): + x, y, w, h = [round(v) for v in bbox] + start_point = (int(x), int(y)) + end_point = (int(x + w), int(y + h)) + cv2.rectangle(image, start_point, end_point, class_color, thickness) + + def plot_text( + self, + image: np.ndarray, + bbox: Tuple[float], + class_color: str, + class_name: str, + ): + txt_size = cv2.getTextSize( + class_name, cv2.FONT_HERSHEY_SIMPLEX, self.font_size, 1 + )[0] + + x, y = [round(v) for v in bbox[:2]] + start_point = (x, y + 1) + end_point = (x + txt_size[0] + 5, y - int(1.5 * txt_size[1])) + cv2.rectangle(image, start_point, end_point, class_color, -1) + + cv2.putText( + image, + class_name, + (x + 5, y - 5), + cv2.FONT_HERSHEY_SIMPLEX, + self.font_size, + (255, 255, 255), + 1, + ) def main():