diff --git a/yolov3_tf2/models.py b/yolov3_tf2/models.py index e08ebfd9..f7cc7b3c 100644 --- a/yolov3_tf2/models.py +++ b/yolov3_tf2/models.py @@ -303,9 +303,12 @@ def yolo_loss(y_true, y_pred): obj_loss = binary_crossentropy(true_obj, pred_obj) obj_loss = obj_mask * obj_loss + \ (1 - obj_mask) * ignore_mask * obj_loss - # TODO: use binary_crossentropy instead - class_loss = obj_mask * sparse_categorical_crossentropy( - true_class_idx, pred_class) + # sparse_categorical_crossentropy will always output 0 when number of classes is 1, + # so convert true_class into one hot label and use binary_crossentropy. + true_class_one_hot = tf.one_hot( + tf.cast(true_class_idx[..., 0], tf.int32), classes) + class_loss = obj_mask * binary_crossentropy( + true_class_one_hot, pred_class) # 6. sum over (batch, gridx, gridy, anchors) => (batch, 1) xy_loss = tf.reduce_sum(xy_loss, axis=(1, 2, 3))