From 624aaf4959925fb13457df210446a0622218ca9a Mon Sep 17 00:00:00 2001 From: Manuel Kraus <58955879+makra89@users.noreply.github.com> Date: Thu, 2 Sep 2021 09:18:09 +0200 Subject: [PATCH] Correct score for one-class detection --- yolov3_tf2/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/yolov3_tf2/models.py b/yolov3_tf2/models.py index c3d699f8..6bc0008f 100644 --- a/yolov3_tf2/models.py +++ b/yolov3_tf2/models.py @@ -197,7 +197,11 @@ def yolo_nms(outputs, anchors, masks, classes): confidence = tf.concat(c, axis=1) class_probs = tf.concat(t, axis=1) - scores = confidence * class_probs + # If we only have one class, do not multiply by class_prob (always 0.5) + if classes == 1: + scores = confidence + else: + scores = confidence * class_probs dscores = tf.squeeze(scores, axis=0) scores = tf.reduce_max(dscores,[1])