Skip to content

Commit

Permalink
Add pre-commit and add TODO
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Aug 17, 2023
1 parent d1cb809 commit 035d103
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 349 deletions.
83 changes: 44 additions & 39 deletions coco_2_labelImg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,83 +10,85 @@
from tqdm import tqdm


class COCO2labelImg():
class COCO2labelImg:
def __init__(self, data_dir: str = None):
# coco dir
self.data_dir = Path(data_dir)
self.verify_exists(self.data_dir)

anno_dir = self.data_dir / 'annotations'
anno_dir = self.data_dir / "annotations"
self.verify_exists(anno_dir)

self.train_json = anno_dir / 'instances_train2017.json'
self.val_json = anno_dir / 'instances_val2017.json'
self.train_json = anno_dir / "instances_train2017.json"
self.val_json = anno_dir / "instances_val2017.json"
self.verify_exists(self.train_json)
self.verify_exists(self.val_json)

self.train2017_dir = self.data_dir / 'train2017'
self.val2017_dir = self.data_dir / 'val2017'
self.train2017_dir = self.data_dir / "train2017"
self.val2017_dir = self.data_dir / "val2017"
self.verify_exists(self.train2017_dir)
self.verify_exists(self.val2017_dir)

# save dir
self.save_dir = self.data_dir.parent / 'COCO_labelImg_format'
self.save_dir = self.data_dir.parent / "COCO_labelImg_format"
self.mkdir(self.save_dir)

self.save_train_dir = self.save_dir / 'train'
self.save_train_dir = self.save_dir / "train"
self.mkdir(self.save_train_dir)

self.save_val_dir = self.save_dir / 'val'
self.save_val_dir = self.save_dir / "val"
self.mkdir(self.save_val_dir)

def __call__(self, ):
def __call__(
self,
):
train_list = [self.train_json, self.save_train_dir, self.train2017_dir]
self.convert(train_list)

val_list = [self.val_json, self.save_val_dir, self.val2017_dir]
self.convert(val_list)

print(f'Successfully convert, detail in {self.save_dir}')
print(f"Successfully convert, detail in {self.save_dir}")

def convert(self, info_list: list):
json_path, save_dir, img_dir = info_list

data = self.read_json(str(json_path))
self.gen_classes_txt(save_dir, data.get('categories'))
self.gen_classes_txt(save_dir, data.get("categories"))

id_img_dict = {v['id']: v for v in data.get('images')}
all_annotaions = data.get('annotations')
id_img_dict = {v["id"]: v for v in data.get("images")}
all_annotaions = data.get("annotations")
for one_anno in tqdm(all_annotaions):
image_info = id_img_dict.get(one_anno['image_id'])
img_name = image_info.get('file_name')
img_height = image_info.get('height')
img_width = image_info.get('width')
image_info = id_img_dict.get(one_anno["image_id"])
img_name = image_info.get("file_name")
img_height = image_info.get("height")
img_width = image_info.get("width")

seg_info = one_anno.get('segmentation')
seg_info = one_anno.get("segmentation")
if seg_info:
bbox = self.get_bbox(seg_info)
xywh = self.xyxy_to_xywh(bbox, img_width, img_height)
category_id = int(one_anno.get('category_id')) - 1
xywh_str = ' '.join([str(v) for v in xywh])
label_str = f'{category_id} {xywh_str}'
category_id = int(one_anno.get("category_id")) - 1
xywh_str = " ".join([str(v) for v in xywh])
label_str = f"{category_id} {xywh_str}"

# 写入标注的txt文件
txt_full_path = save_dir / f'{Path(img_name).stem}.txt'
self.write_txt(txt_full_path, label_str, mode='a')
txt_full_path = save_dir / f"{Path(img_name).stem}.txt"
self.write_txt(txt_full_path, label_str, mode="a")

# 复制图像到转换后目录
img_full_path = img_dir / img_name
shutil.copy2(img_full_path, save_dir)

@staticmethod
def read_json(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
return data

def gen_classes_txt(self, save_dir, categories_dict):
class_info = [value['name'] for value in categories_dict]
self.write_txt(save_dir / 'classes.txt', class_info)
class_info = [value["name"] for value in categories_dict]
self.write_txt(save_dir / "classes.txt", class_info)

def get_bbox(self, seg_info):
seg_info = np.array(seg_info[0]).reshape(4, 2)
Expand All @@ -96,20 +98,20 @@ def get_bbox(self, seg_info):
return bbox

@staticmethod
def write_txt(save_path: str, content: list, mode='w'):
def write_txt(save_path: str, content: list, mode="w"):
if not isinstance(save_path, str):
save_path = str(save_path)

if isinstance(content, str):
content = [content]
with open(save_path, mode, encoding='utf-8') as f:
with open(save_path, mode, encoding="utf-8") as f:
for value in content:
f.write(f'{value}\n')
f.write(f"{value}\n")

@staticmethod
def xyxy_to_xywh(xyxy: list,
img_width: int,
img_height: int) -> tuple([float, float, float, float]):
def xyxy_to_xywh(
xyxy: list, img_width: int, img_height: int
) -> tuple([float, float, float, float]):
"""
xyxy: (list), [x1, y1, x2, y2]
"""
Expand All @@ -127,18 +129,21 @@ def xyxy_to_xywh(xyxy: list,
def verify_exists(file_path):
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f'The {file_path} is not exists!!!')
raise FileNotFoundError(f"The {file_path} is not exists!!!")

@staticmethod
def mkdir(dir_path):
Path(dir_path).mkdir(parents=True, exist_ok=True)


if __name__ == '__main__':
parser = argparse.ArgumentParser('Datasets convert from COCO to labelImg')
parser.add_argument('--data_dir', type=str,
default='dataset/YOLOV5_COCO_format',
help='Dataset root path')
if __name__ == "__main__":
parser = argparse.ArgumentParser("Datasets convert from COCO to labelImg")
parser.add_argument(
"--data_dir",
type=str,
default="dataset/YOLOV5_COCO_format",
help="Dataset root path",
)
args = parser.parse_args()

converter = COCO2labelImg(args.data_dir)
Expand Down
80 changes: 46 additions & 34 deletions coco_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,68 +10,80 @@


def visualization_bbox(num_image, json_path, img_path):
with open(json_path, 'r', encoding='utf-8') as annos:
with open(json_path, "r", encoding="utf-8") as annos:
annotation_json = json.load(annos)
print('The annotation_json num_key is:', len(annotation_json))
print('The annotation_json key is:', annotation_json.keys())
print('The annotation_json num_images is:', len(annotation_json['images']))
print("The annotation_json num_key is:", len(annotation_json))
print("The annotation_json key is:", annotation_json.keys())
print("The annotation_json num_images is:", len(annotation_json["images"]))

categories = annotation_json['categories']
categories_dict = {c['id']: c['name'] for c in categories}
categories = annotation_json["categories"]
categories_dict = {c["id"]: c["name"] for c in categories}
class_nums = len(categories_dict.keys())
color = [(random.randint(0, 255), random.randint(0, 255),
random.randint(0, 255)) for _ in range(class_nums)]
color = [
(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
for _ in range(class_nums)
]

image_name = annotation_json['images'][num_image - 1]['file_name']
img_id = annotation_json['images'][num_image - 1]['id']
image_name = annotation_json["images"][num_image - 1]["file_name"]
img_id = annotation_json["images"][num_image - 1]["id"]
image_path = os.path.join(img_path, str(image_name).zfill(5))
image = cv2.imread(image_path, 1)

annotations = annotation_json['annotations']
annotations = annotation_json["annotations"]
num_bbox = 0
for anno in annotations:
if anno['image_id'] == img_id:
if anno["image_id"] == img_id:
num_bbox = num_bbox + 1

class_id = anno['category_id']
class_id = anno["category_id"]
class_name = categories_dict[class_id]
class_color = color[class_id-1]
class_color = color[class_id - 1]

x, y, w, h = list(map(int, anno['bbox']))
cv2.rectangle(image, (int(x), int(y)),
(int(x + w), int(y + h)),
class_color, 2)
x, y, w, h = list(map(int, anno["bbox"]))
cv2.rectangle(
image, (int(x), int(y)), (int(x + w), int(y + h)), class_color, 2
)

font_size = 0.7
txt_size = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX,
font_size, 1)[0]
cv2.rectangle(image, (x, y + 1),
(x + txt_size[0] + 10, y - int(2 * txt_size[1])),
class_color, -1)
cv2.putText(image, class_name, (x + 5, y - 5),
cv2.FONT_HERSHEY_SIMPLEX,
font_size, (255, 255, 255), 1)
txt_size = cv2.getTextSize(
class_name, cv2.FONT_HERSHEY_SIMPLEX, font_size, 1
)[0]
cv2.rectangle(
image,
(x, y + 1),
(x + txt_size[0] + 10, y - int(2 * txt_size[1])),
class_color,
-1,
)
cv2.putText(
image,
class_name,
(x + 5, y - 5),
cv2.FONT_HERSHEY_SIMPLEX,
font_size,
(255, 255, 255),
1,
)

print('The unm_bbox of the display image is:', num_bbox)
print("The unm_bbox of the display image is:", num_bbox)

cur_os = platform.system()
if cur_os == 'Windows':
if cur_os == "Windows":
cv2.namedWindow(image_name, 0)
cv2.resizeWindow(image_name, 1000, 1000)
cv2.imshow(image_name, image)
cv2.waitKey(0)
else:
save_path = f'visul_{num_image}.jpg'
save_path = f"visul_{num_image}.jpg"
cv2.imwrite(save_path, image)
print(f'The {save_path} has been saved the current director.')
print(f"The {save_path} has been saved the current director.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--vis_num', type=int, default=1,
help="visual which one")
parser.add_argument('--json_path', type=str, required=True)
parser.add_argument('--img_dir', type=str, required=True)
parser.add_argument("--vis_num", type=int, default=1, help="visual which one")
parser.add_argument("--json_path", type=str, required=True)
parser.add_argument("--img_dir", type=str, required=True)
args = parser.parse_args()

visualization_bbox(args.vis_num, args.json_path, args.img_dir)
Loading

0 comments on commit 035d103

Please sign in to comment.