Skip to content

Commit

Permalink
Merge pull request #781 from roboflow/multi_label_top_class_confidence
Browse files Browse the repository at this point in the history
Return single top class confidence for mulit-label predictions
  • Loading branch information
PawelPeczek-Roboflow authored Nov 6, 2024
2 parents 4e922b3 + ba89e28 commit 8c4fdd4
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SequenceAggregationMode(Enum):
class ClassificationProperty(Enum):
TOP_CLASS = "top_class"
TOP_CLASS_CONFIDENCE = "top_class_confidence"
TOP_CLASS_CONFIDENCE_SINGLE = "top_class_confidence_single"
ALL_CLASSES = "all_classes"
ALL_CONFIDENCES = "all_confidences"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def extract_top_class_confidence(prediction: dict) -> Union[float, List[float]]:
]


def extract_top_class_confidence_single(prediction: dict) -> Union[float, List[float]]:
if "confidence" in prediction:
return prediction["confidence"]
predicted_classes = prediction.get("predicted_classes", [])
predicted_confidences = [
prediction["predictions"][class_name]["confidence"]
for class_name in predicted_classes
]
if not predicted_confidences:
return 0.0
return max(predicted_confidences)


def extract_all_class_names(prediction: dict) -> List[str]:
predictions = prediction["predictions"]
if isinstance(predictions, list):
Expand All @@ -52,6 +65,7 @@ def extract_all_classes_confidence(prediction: dict) -> List[float]:
CLASSIFICATION_PROPERTY_EXTRACTORS = {
ClassificationProperty.TOP_CLASS: extract_top_class,
ClassificationProperty.TOP_CLASS_CONFIDENCE: extract_top_class_confidence,
ClassificationProperty.TOP_CLASS_CONFIDENCE_SINGLE: extract_top_class_confidence_single,
ClassificationProperty.ALL_CLASSES: extract_all_class_names,
ClassificationProperty.ALL_CONFIDENCES: extract_all_classes_confidence,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,58 @@ def test_classification_result_extraction_of_top_class_confidence_for_multi_labe
assert result == [0.4]


def test_classification_result_extraction_of_top_class_confidence_single_for_multi_label_classification_result_when_class_detected() -> (
None
):
# given
operations = [
{
"type": "ClassificationPropertyExtract",
"property_name": "top_class_confidence_single",
}
]
data = MultiLabelClassificationInferenceResponse(
image=InferenceResponseImage(width=128, height=256),
predictions={
"cat": MultiLabelClassificationPrediction(class_id=0, confidence=0.6),
"dog": MultiLabelClassificationPrediction(class_id=1, confidence=0.4),
},
predicted_classes=["cat", "dog"],
).dict(by_alias=True, exclude_none=True)

# when
result = execute_operations(value=data, operations=operations)

# then
assert result == 0.6


def test_classification_result_extraction_of_top_class_confidence_single_for_multi_label_classification_result_when_no_classes_detected() -> (
None
):
# given
operations = [
{
"type": "ClassificationPropertyExtract",
"property_name": "top_class_confidence_single",
}
]
data = MultiLabelClassificationInferenceResponse(
image=InferenceResponseImage(width=128, height=256),
predictions={
"cat": MultiLabelClassificationPrediction(class_id=0, confidence=0.6),
"dog": MultiLabelClassificationPrediction(class_id=1, confidence=0.4),
},
predicted_classes=[],
).dict(by_alias=True, exclude_none=True)

# when
result = execute_operations(value=data, operations=operations)

# then
assert result == 0.0


def test_classification_result_extraction_of_all_classes_for_multi_class_classification_result() -> (
None
):
Expand Down

0 comments on commit 8c4fdd4

Please sign in to comment.