diff --git a/yolov3_tf2/models.py b/yolov3_tf2/models.py index c3d699f8..6bc0008f 100644 --- a/yolov3_tf2/models.py +++ b/yolov3_tf2/models.py @@ -197,7 +197,11 @@ def yolo_nms(outputs, anchors, masks, classes): confidence = tf.concat(c, axis=1) class_probs = tf.concat(t, axis=1) - scores = confidence * class_probs + # If we only have one class, do not multiply by class_prob (always 0.5) + if classes == 1: + scores = confidence + else: + scores = confidence * class_probs dscores = tf.squeeze(scores, axis=0) scores = tf.reduce_max(dscores,[1])