Skip to content

Commit

Permalink
Merge pull request #46 from augmentedstartups/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
augmentedstartups authored Jun 13, 2023
2 parents 9fc7c46 + 7db57f4 commit b5fa8d9
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 109 deletions.
102 changes: 98 additions & 4 deletions asone/asone.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ def __init__(self,
weights: str = None,
use_cuda: bool = True,
recognizer: int = None,
languages: list = ['en']
languages: list = ['en'],
num_classes=80
) -> None:

self.use_cuda = use_cuda

# get detector object
self.detector = self.get_detector(detector, weights, recognizer)
self.detector = self.get_detector(detector, weights, recognizer, num_classes)
self.recognizer = self.get_recognizer(recognizer, languages=languages)

if tracker == -1:
Expand All @@ -33,9 +34,9 @@ def __init__(self,

self.tracker = self.get_tracker(tracker)

def get_detector(self, detector: int, weights: str, recognizer):
def get_detector(self, detector: int, weights: str, recognizer, num_classes):
detector = Detector(detector, weights=weights,
use_cuda=self.use_cuda, recognizer=recognizer).get_detector()
use_cuda=self.use_cuda, recognizer=recognizer, num_classes=num_classes).get_detector()
return detector

def get_recognizer(self, recognizer: int, languages):
Expand Down Expand Up @@ -85,6 +86,99 @@ def track_video(self,
# yield bbox_details, frame_details to main script
yield bbox_details, frame_details

def detect_video(self,
                 video_path,
                 **kwargs
                 ):
    """Run the detector on every frame of a video and yield per-frame results.

    This is a generator: frames are read, processed and (optionally)
    displayed/saved lazily as the caller iterates.

    Args:
        video_path: Path of the input video. Its basename is used as the
            output filename when results are saved.
        **kwargs: Run options, merged into the run configuration by
            ``self._update_args``. The resulting configuration must provide
            ``fps``, ``output_dir``, ``filename``, ``save_result``,
            ``display`` and ``class_names``. ``conf_thres`` and
            ``iou_thres`` are honoured when passed explicitly; otherwise
            0.25 and 0.45 are used.

    Yields:
        A pair ``((bbox_xyxy, scores, class_ids), (image, frame_no, fps))``.
        ``bbox_xyxy``, ``scores`` and ``class_ids`` are slices of the
        detector output, or ``None`` for a frame on which the detector
        returned ``None``. ``image`` is the annotated frame when
        ``display`` is true, otherwise an untouched copy of the input
        frame. ``frame_no`` is the zero-based frame index and ``fps`` the
        measured processing rate for that frame.
    """
    # Read the thresholds straight from the caller's arguments so that an
    # explicit value is respected while the previous hard-coded values
    # remain the defaults.
    conf_thres = kwargs.get('conf_thres')
    if conf_thres is None:
        conf_thres = 0.25
    iou_thres = kwargs.get('iou_thres')
    if iou_thres is None:
        iou_thres = 0.45

    kwargs['filename'] = os.path.basename(video_path)
    config = self._update_args(kwargs)

    # FPS used for the output file; kept separate from the measured
    # per-frame processing rate computed inside the loop.
    writer_fps = config.pop('fps')
    output_dir = config.pop('output_dir')
    filename = config.pop('filename')
    save_result = config.pop('save_result')
    display = config.pop('display')
    class_names = config.pop('class_names')

    cap = cv2.VideoCapture(video_path)
    video_writer = None
    try:
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)

        if writer_fps is None:
            writer_fps = cap.get(cv2.CAP_PROP_FPS)

        if save_result:
            os.makedirs(output_dir, exist_ok=True)
            save_path = os.path.join(output_dir, filename)
            logger.info(f"video save path is {save_path}")

            video_writer = cv2.VideoWriter(
                save_path,
                cv2.VideoWriter_fourcc(*"mp4v"),
                writer_fps,
                (int(width), int(height)),
            )

        tic = time.time()
        # Start the FPS clock now so the first frame reports a real rate.
        prev_time = tic
        frame_no = 0
        while True:
            start_time = time.time()

            ret, img = cap.read()
            if not ret:
                break
            # Keep an unannotated copy; drawing below modifies ``img``.
            frame = img.copy()

            dets, _ = self.detector.detect(
                img, conf_thres=conf_thres, iou_thres=iou_thres)
            curr_time = time.time()
            interval = curr_time - prev_time
            fps = 1 / interval if interval > 0 else 0.0
            prev_time = curr_time

            # Reset every frame so a frame without detections never
            # re-yields the previous frame's boxes (or fails on frame 0).
            bbox_xyxy = scores = class_ids = None
            if dets is not None:
                bbox_xyxy = dets[:, :4]
                scores = dets[:, 4]
                class_ids = dets[:, 5]
                img = utils.draw_boxes(
                    img, bbox_xyxy, class_ids=class_ids, class_names=class_names)

            cv2.line(img, (20, 25), (127, 25), [85, 45, 255], 30)
            cv2.putText(img, f'FPS: {int(fps)}', (11, 35), 0, 1, [
                        225, 255, 255], thickness=2, lineType=cv2.LINE_AA)

            elapsed_time = time.time() - start_time

            logger.info(
                'frame {}/{} ({:.2f} ms)'.format(frame_no, int(frame_count),
                                                 elapsed_time * 1000))

            if display:
                cv2.imshow('Window', img)

            if save_result:
                video_writer.write(img)

            # Only poll the keyboard when a window exists to receive keys.
            if display and cv2.waitKey(25) & 0xFF == ord('q'):
                break

            yield (bbox_xyxy, scores, class_ids), (img if display else frame, frame_no, fps)
            frame_no += 1

        tac = time.time()
        print(f'Total Time Taken: {tac - tic:.2f}')
    finally:
        # Runs on normal exhaustion, on 'q', and when the caller abandons
        # the generator, so the output file is always finalised.
        cap.release()
        if video_writer is not None:
            video_writer.release()
        if display:
            cv2.destroyAllWindows()

def detect(self, source, **kwargs)->np.ndarray:
""" Function to perform detection on an img
Expand Down
110 changes: 39 additions & 71 deletions asone/demo_detector.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,65 @@
import sys
import argparse
import asone
from asone import ASOne
from .utils import draw_boxes
import cv2
import argparse
import time
import os
import sys
import torch


def main(args):
filter_classes = args.filter_classes
video_path = args.video

os.makedirs(args.output_path, exist_ok=True)

if filter_classes:
filter_classes = filter_classes.split(',')

filter_classes = ['person']
# Check if cuda available
if args.use_cuda and torch.cuda.is_available():
args.use_cuda = True
else:
args.use_cuda = False

if sys.platform.startswith('darwin'):
detector = asone.YOLOV7_MLMODEL
else:
detector = asone.YOLOV7_PYTORCH

detector = ASOne(detector, weights=args.weights, use_cuda=args.use_cuda)

cap = cv2.VideoCapture(video_path)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
FPS = cap.get(cv2.CAP_PROP_FPS)

if args.save:
video_writer = cv2.VideoWriter(
os.path.basename(video_path),
cv2.VideoWriter_fourcc(*"mp4v"),
FPS,
(int(width), int(height)),

detect = ASOne(
detector=detector,
weights=args.weights,
use_cuda=args.use_cuda
)
# Get tracking function
track = detect.detect_video(args.video_path,
output_dir=args.output_dir,
conf_thres=args.conf_thres,
iou_thres=args.iou_thres,
display=args.display,
filter_classes=filter_classes,
class_names=None) # class_names=['License Plate'] for custom weights

frame_no = 1
tic = time.time()

prevTime = 0

while True:
start_time = time.time()

ret, img = cap.read()
if not ret:
break
frame = img.copy()
# Loop over track_fn to retrieve outputs of each frame
for bbox_details, frame_details in track:
bbox_xyxy, scores, class_ids = bbox_details
frame, frame_num, fps = frame_details
print(frame_num)

dets, img_info = detector.detect(img, conf_thres=0.25, iou_thres=0.45)
currTime = time.time()
fps = 1 / (currTime - prevTime)
prevTime = currTime

if dets is not None:
bbox_xyxy = dets[:, :4]
scores = dets[:, 4]
class_ids = dets[:, 5]
img = draw_boxes(img, bbox_xyxy, class_ids=class_ids)

cv2.line(img, (20, 25), (127, 25), [85, 45, 255], 30)
cv2.putText(img, f'FPS: {int(fps)}', (11, 35), 0, 1, [
225, 255, 255], thickness=2, lineType=cv2.LINE_AA)


frame_no+=1
if args.display:
cv2.imshow('Window', img)

if args.save:
video_writer.write(img)

if cv2.waitKey(25) & 0xFF == ord('q'):
break

if __name__=='__main__':

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("video", help="Path of video")
parser.add_argument('--cpu', default=True, action='store_false', dest='use_cuda', help='If provided the model will run on cpu otherwise it will run on gpu')
parser.add_argument('--filter_classes', default=None, help='Class names seperated by comma (,). e.g. person,car ')

parser.add_argument('video_path', help='Path to input video')
parser.add_argument('--cpu', default=True, action='store_false', dest='use_cuda',
help='run on cpu if not provided the program will run on gpu.')
parser.add_argument('--no_save', default=True, action='store_false',
dest='save_result', help='whether or not save results')
parser.add_argument('--no_display', default=True, action='store_false',
dest='display', help='whether or not display results on screen')
parser.add_argument('--output_dir', default='data/results', help='Path to output directory')
parser.add_argument('--draw_trails', action='store_true', default=False,
help='if provided object motion trails will be drawn.')
parser.add_argument('--filter_classes', default=None, help='Filter class name')
parser.add_argument('-w', '--weights', default=None, help='Path of trained weights')
parser.add_argument('-o', '--output_path', default='data/results', help='path of output file')
parser.add_argument('--no_display', action='store_false', default=True, dest='display', help='if provided video will not be displayed')
parser.add_argument('--no_save', action='store_false', default=True, dest='save', help='if provided video will not be saved')
parser.add_argument('-ct', '--conf_thres', default=0.25, type=float, help='confidence score threshold')
parser.add_argument('-it', '--iou_thres', default=0.45, type=float, help='iou score threshold')

args = parser.parse_args()

main(args)
16 changes: 10 additions & 6 deletions asone/detectors/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ def __init__(self,
model_flag: int,
weights: str = None,
use_cuda: bool = True,
recognizer:int = None):
recognizer:int = None,
num_classes=80):

self.model = self._select_detector(model_flag, weights, use_cuda, recognizer)
def _select_detector(self, model_flag, weights, cuda, recognizer):
self.model = self._select_detector(model_flag, weights, use_cuda, recognizer, num_classes)
def _select_detector(self, model_flag, weights, cuda, recognizer, num_classes):
# Get required weight using model_flag
mlmodel = False
if weights and weights.split('.')[-1] == 'onnx':
Expand Down Expand Up @@ -101,9 +102,12 @@ def _select_detector(self, model_flag, weights, cuda, recognizer):
use_cuda=cuda)
elif model_flag in range(160, 163):
# Get exp file and corresponding model for coreml only
_detector = YOLOnasDetector(weights=weight,
use_onnx=onnx,
use_cuda=cuda)
_detector = YOLOnasDetector(
model_flag,
weights=weight,
use_onnx=onnx,
use_cuda=cuda,
num_classes=num_classes)

return _detector

Expand Down
Loading

0 comments on commit b5fa8d9

Please sign in to comment.