-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
137 lines (104 loc) · 4.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import _init_paths
import os
import argparse
import tqdm
import cv2
import torch
from annotator import lane_ann, car_ann
from pre_processor import lane_prx2, car_prx2
from car_detector.config import CarConfig
from car_detector.model import CarDetector
from lane_detector.config import LaneConfig
from lane_detector.model import LaneDetector
def parse_args(argv=None):
    """Parse the command-line options of the car/lane joint detection script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour the script had before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``mode``, ``fps``,
        ``data_path``, ``lane_cfg_path``, ``car_cfg_path`` and ``out_dir``.
    """
    parser = argparse.ArgumentParser(description="Car Lane Joint Detection")

    # What kind of input `--data_path` points to.
    parser.add_argument("-m", "--mode", choices=["image", "video"], default="image")
    parser.add_argument("--fps", type=int, default=20,
                        help='registered frames-per-second for videos')
    parser.add_argument("-dp", "--data_path", default="data/images",
                        help="path to an image directory or a explicit path to a video")

    # Model configuration files.
    parser.add_argument("-lcf", "--lane_cfg_path", default="cfgs/lane.yml",
                        help="Path to lane-model-config file")
    parser.add_argument("-ccf", "--car_cfg_path", default="cfgs/car.yml",
                        help="Path to car-model-config file")

    # Where annotated results are written.
    parser.add_argument("-odr", "--out_dir", default="output", help="Saving directory")

    return parser.parse_args(argv)
def _detect_and_annotate(frame, lane, lane_cfg, car, car_cfg):
    """Run lane and car detection on one frame and return the annotated frame.

    Args:
        frame: Image array as returned by ``cv2.imread`` / ``VideoCapture.read``.
        lane: Lane detector exposing ``detect``.
        lane_cfg: Lane configuration; ``['model']['parameters']`` must provide
            ``img_h`` and ``img_w``.
        car: Car detector exposing ``detect`` and a ``coco`` attribute.
        car_cfg: Car configuration; must provide ``im_size``.

    Returns:
        The value returned by ``car_ann`` for this frame, i.e. the frame with
        lane and car annotations applied.
    """
    # Pre-process the same source frame separately for each model.
    lane_params = lane_cfg['model']['parameters']
    lane_input = lane_prx2(frame, lane_params['img_h'], lane_params['img_w'])
    car_input = car_prx2(frame, car_cfg['im_size'])

    # Only the first element of each detector output is used
    # (presumably the single-image batch entry -- confirm against the detectors).
    lane_pred = lane.detect(lane_input)[0]
    car_pred = car.detect(car_input)[0]

    # Lanes are drawn first; the resulting lines are handed to the car annotator.
    annotated, lines = lane_ann(frame, lane_pred)
    return car_ann(annotated, car_pred, car_cfg, car.coco, lines)


def main():
    """Run joint car/lane detection on an image directory or a single video.

    Reads the options from ``parse_args()``, loads both detectors, then:

    * ``image`` mode: annotates every readable image in ``--data_path`` and
      writes it to ``<out_dir>/sav_<original name>``.
    * ``video`` mode: annotates every frame of the video at ``--data_path``
      and writes ``<out_dir>/sav_<video name>.avi`` (MJPG) at ``--fps``.

    Raises:
        IOError: In ``video`` mode, if the video cannot be opened.
    """
    args = parse_args()

    # Create the output dir (including missing parents; fine if it exists).
    os.makedirs(args.out_dir, exist_ok=True)

    # Prefer the GPU, but fall back to the CPU on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load a pre-trained lane detection model
    print('Loading lane detection model and its configuration: ', args.lane_cfg_path)
    lane_cfg = LaneConfig(args.lane_cfg_path)
    lane = LaneDetector(lane_cfg, device)

    # Load a pre-trained car detection model
    print('Loading car detection model and its configuration: ', args.car_cfg_path)
    car_cfg = CarConfig(args.car_cfg_path)
    car = CarDetector(car_cfg, device)

    if args.mode == "image":
        # Load a list of images
        print('Loading images: ', args.data_path)
        images = os.listdir(args.data_path)

        # Run car, lane detections
        print('Running detection on images ...')
        for item in tqdm.tqdm(images):
            im = cv2.imread(os.path.join(args.data_path, item))
            if im is None:
                # cv2.imread yields None for entries it cannot decode
                # (sub-directories, non-image files); skip instead of crashing.
                continue
            ann_im = _detect_and_annotate(im, lane, lane_cfg, car, car_cfg)
            cv2.imwrite(os.path.join(args.out_dir, 'sav_' + item), ann_im)

    if args.mode == "video":
        # Load a video and get its size
        print('Loading video: ', args.data_path)
        cap = cv2.VideoCapture(args.data_path)
        if not cap.isOpened():
            cap.release()
            raise IOError('Could not open video: ' + str(args.data_path))
        size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

        # Create output-video writer inside the requested output directory.
        print('Creating a video writer')
        video_name = os.path.basename(args.data_path).split('.')[0]
        out = cv2.VideoWriter(os.path.join(args.out_dir, 'sav_' + video_name + '.avi'),
                              cv2.VideoWriter_fourcc(*'MJPG'),
                              args.fps, size)

        print('Running detection on video ...')
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    # No more frames (or a read failure): stop processing.
                    break
                ann_frame = _detect_and_annotate(frame, lane, lane_cfg, car, car_cfg)
                out.write(ann_frame)
                # To watch the result live, show `ann_frame` with cv2.imshow
                # here and break on a key press via cv2.waitKey(1).
        finally:
            # Release everything even if a frame failed, so the output
            # file is finalized and the capture handle is not leaked.
            cap.release()
            out.release()
            cv2.destroyAllWindows()

    print('Saving results to: ', args.out_dir)
    print('Finish !')
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()