Skip to content

Commit a44399d

Browse files
committed
refactoring detectors
1 parent e547d4d commit a44399d

24 files changed

+165
-227
lines changed

giskard_vision/core/dataloaders/base.py

-10
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,9 @@
55
import numpy as np
66

77
from giskard_vision.core.dataloaders.meta import MetaData
8-
from giskard_vision.core.detectors.base import IssueGroup
98

109
from ..types import TypesBase
1110

12-
EthicalIssueMeta = IssueGroup(
13-
"Ethical",
14-
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
15-
)
16-
PerformanceIssueMeta = IssueGroup(
17-
"Performance",
18-
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
19-
)
20-
2111

2212
class DataIteratorBase(ABC):
2313
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.

giskard_vision/core/dataloaders/meta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, List, Optional
22

3-
from giskard_vision.core.detectors.base import IssueGroup
3+
from giskard_vision.core.issues import IssueGroup
44

55

66
class MetaData:

giskard_vision/core/detectors/base.py

+39-12
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
from dataclasses import dataclass
33
from typing import Any, List, Optional, Sequence, Tuple
44

5+
from giskard_vision.core.issues import IssueGroup
56
from giskard_vision.utils.errors import GiskardImportError
67

7-
8-
@dataclass(frozen=True)
9-
class IssueGroup:
10-
name: str
11-
description: str
8+
from .specs import DetectorSpecsBase
129

1310

1411
@dataclass
@@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
5148
}
5249

5350

54-
class DetectorVisionBase:
51+
class DetectorVisionBase(DetectorSpecsBase):
5552
"""
5653
Abstract class for Vision Detectors
5754
@@ -67,12 +64,6 @@ class DetectorVisionBase:
6764
evaluation results for the scan.
6865
"""
6966

70-
issue_group: IssueGroup
71-
warning_messages: dict
72-
issue_level_threshold: float = 0.2
73-
deviation_threshold: float = 0.05
74-
num_images: int = 0
75-
7667
def run(
7768
self,
7869
model: Any,
@@ -139,6 +130,42 @@ def get_issues(
139130

140131
return issues
141132

133+
def get_scan_result(
    self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
    """Grade one evaluated slice against its reference and wrap it in a ScanResult.

    The delta is ``metric_value - metric_reference_value``; when
    ``self.metric_type`` is ``"relative"`` it is additionally divided by the
    reference value. The delta is then compared with
    ``self.issue_level_threshold`` (MEDIUM) and
    ``self.issue_level_threshold + self.deviation_threshold`` (MAJOR), in the
    direction given by ``self.metric_direction``. Any other direction value
    leaves the level at MINOR.

    Args:
        metric_value: Metric obtained on the evaluated data.
        metric_reference_value: Metric obtained on the reference data.
        metric_name: Name of the metric, stored as-is in the result.
        filename_examples: Example image paths, stored as-is in the result.
        name: Name of the result.
        size_data: Stored as ``slice_size`` in the result.
        issue_group: Issue group stored in the result.

    Returns:
        ScanResult: The populated result, including the computed
        ``relative_delta`` and ``issue_level``.

    Raises:
        GiskardImportError: If ``giskard`` cannot be imported.
    """
    try:
        from giskard.scanner.issues import IssueLevel
    except (ImportError, ModuleNotFoundError) as e:
        raise GiskardImportError(["giskard"]) from e

    relative_delta = metric_value - metric_reference_value
    if self.metric_type == "relative":
        # NOTE(review): a reference value of 0 is not guarded against here.
        relative_delta /= metric_reference_value

    # Pick the comparison matching the metric direction: a degradation is a
    # positive delta for "better_lower" and a negative one for "better_higher".
    direction = self.metric_direction
    if direction == "better_lower":

        def is_worse_by(margin):
            return relative_delta > margin

    elif direction == "better_higher":

        def is_worse_by(margin):
            return relative_delta < -margin

    else:
        is_worse_by = None

    issue_level = IssueLevel.MINOR
    if is_worse_by is not None:
        if is_worse_by(self.issue_level_threshold + self.deviation_threshold):
            issue_level = IssueLevel.MAJOR
        elif is_worse_by(self.issue_level_threshold):
            issue_level = IssueLevel.MEDIUM

    return ScanResult(
        name=name,
        metric_name=metric_name,
        metric_value=metric_value,
        metric_reference_value=metric_reference_value,
        issue_level=issue_level,
        slice_size=size_data,
        filename_examples=filename_examples,
        relative_delta=relative_delta,
        issue_group=issue_group,
    )
168+
142169
@abstractmethod
143170
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
144171
"""Returns a list of ScanResult

giskard_vision/core/detectors/metadata_scan_detector.py

+2-41
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
import numpy as np
55
import pandas as pd
66

7-
from giskard_vision.core.detectors.base import (
8-
DetectorVisionBase,
9-
IssueGroup,
10-
ScanResult,
11-
)
7+
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
8+
from giskard_vision.core.issues import IssueGroup
129
from giskard_vision.core.tests.base import MetricBase
1310
from giskard_vision.utils.errors import GiskardImportError
1411

@@ -263,39 +260,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
263260
pass
264261

265262
return pd.DataFrame(df)
266-
267-
def get_scan_result(
268-
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
269-
) -> ScanResult:
270-
try:
271-
from giskard.scanner.issues import IssueLevel
272-
except (ImportError, ModuleNotFoundError) as e:
273-
raise GiskardImportError(["giskard"]) from e
274-
275-
relative_delta = metric_value - metric_reference_value
276-
if self.metric_type == "relative":
277-
relative_delta /= metric_reference_value
278-
279-
issue_level = IssueLevel.MINOR
280-
if self.metric_direction == "better_lower":
281-
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
282-
issue_level = IssueLevel.MAJOR
283-
elif relative_delta > self.issue_level_threshold:
284-
issue_level = IssueLevel.MEDIUM
285-
elif self.metric_direction == "better_higher":
286-
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
287-
issue_level = IssueLevel.MAJOR
288-
elif relative_delta < -self.issue_level_threshold:
289-
issue_level = IssueLevel.MEDIUM
290-
291-
return ScanResult(
292-
name=name,
293-
metric_name=metric_name,
294-
metric_value=metric_value,
295-
metric_reference_value=metric_reference_value,
296-
issue_level=issue_level,
297-
slice_size=size_data,
298-
filename_examples=filename_examples,
299-
relative_delta=relative_delta,
300-
issue_group=issue_group,
301-
)

giskard_vision/core/detectors/metrics.py

-9
This file was deleted.
giskard_vision/core/detectors/perturbation.py

+40-46
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
11
import os
22
from abc import abstractmethod
3+
from importlib import import_module
34
from pathlib import Path
4-
from typing import Any, Sequence
5+
from typing import Any, Sequence, Tuple
56

67
import cv2
78

89
from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
9-
from giskard_vision.core.detectors.base import (
10-
DetectorVisionBase,
11-
IssueGroup,
12-
ScanResult,
13-
)
10+
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
11+
from giskard_vision.core.issues import Robustness
1412
from giskard_vision.core.tests.base import TestDiffBase
15-
from giskard_vision.utils.errors import GiskardImportError
16-
17-
from .metrics import detector_metrics
18-
19-
Robustness = IssueGroup(
20-
"Robustness",
21-
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
22-
)
2313

2414

2515
class PerturbationBaseDetector(DetectorVisionBase):
@@ -40,6 +30,28 @@ class PerturbationBaseDetector(DetectorVisionBase):
4030

4131
issue_group = Robustness
4232

33+
def run(
34+
self,
35+
model: Any,
36+
dataset: Any,
37+
features: Any | None = None,
38+
issue_levels: Tuple[Any] = None,
39+
embed: bool = True,
40+
num_images: int = 0,
41+
) -> Sequence[Any]:
42+
module = import_module(f"giskard_vision.{model.model_type}.detectors.specs")
43+
DetectorSpecs = getattr(module, "DetectorSpecs")
44+
45+
if DetectorSpecs:
46+
# Only set attributes that are not part of Python's special attributes (those starting with __)
47+
for attr_name, attr_value in vars(DetectorSpecs).items():
48+
if not attr_name.startswith("__") and hasattr(self, attr_name):
49+
setattr(self, attr_name, attr_value)
50+
else:
51+
raise ValueError(f"No detector specifications found for model type: {model.model_type}")
52+
53+
return super().run(model, dataset, features, issue_levels, embed, num_images)
54+
4355
@abstractmethod
4456
def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...
4557

@@ -48,7 +60,7 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
4860

4961
results = []
5062
for dl in dataloaders:
51-
test_result = TestDiffBase(metric=detector_metrics[model.model_type], threshold=1).run(
63+
test_result = TestDiffBase(metric=self.metric, threshold=1).run(
5264
model=model,
5365
dataloader=dl,
5466
dataloader_ref=dataset,
@@ -63,40 +75,22 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
6375

6476
if isinstance(dl, FilteredDataLoader):
6577
filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png")
66-
cv2.imwrite(
67-
filename_example_dataloader_ref, cv2.resize(dataset[index_worst][0][0], (0, 0), fx=0.3, fy=0.3)
68-
)
78+
cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
6979
filename_examples.append(filename_example_dataloader_ref)
7080

7181
filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png")
72-
cv2.imwrite(filename_example_dataloader, cv2.resize(dl[index_worst][0][0], (0, 0), fx=0.3, fy=0.3))
82+
cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
7383
filename_examples.append(filename_example_dataloader)
74-
results.append(self.get_scan_result(test_result, filename_examples, dl.name, len(dl)))
84+
results.append(
    self.get_scan_result(
        test_result.metric_value_test,
        # The reference must be the score on the unperturbed dataset; passing
        # the test value twice pins relative_delta to 0 and grades every
        # perturbation as MINOR.
        test_result.metric_value_ref,
        test_result.metric_name,
        filename_examples,
        dl.name,
        len(dl),
        issue_group=self.issue_group,
    )
)
7595

7696
return results
77-
78-
def get_scan_result(self, test_result, filename_examples, name, size_data) -> ScanResult:
79-
try:
80-
from giskard.scanner.issues import IssueLevel
81-
except (ImportError, ModuleNotFoundError) as e:
82-
raise GiskardImportError(["giskard"]) from e
83-
84-
relative_delta = (test_result.metric_value_test - test_result.metric_value_ref) / test_result.metric_value_ref
85-
86-
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
87-
issue_level = IssueLevel.MAJOR
88-
elif relative_delta > self.issue_level_threshold:
89-
issue_level = IssueLevel.MEDIUM
90-
else:
91-
issue_level = IssueLevel.MINOR
92-
93-
return ScanResult(
94-
name=name,
95-
metric_name=test_result.metric_name,
96-
metric_value=test_result.metric_value_test,
97-
metric_reference_value=test_result.metric_value_ref,
98-
issue_level=issue_level,
99-
slice_size=size_data,
100-
filename_examples=filename_examples,
101-
relative_delta=relative_delta,
102-
)
giskard_vision/core/detectors/specs.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from giskard_vision.core.issues import IssueGroup
2+
from giskard_vision.image_classification.tests.performance import MetricBase
3+
4+
5+
class DetectorSpecsBase:
    """Class-level configuration shared by vision detectors.

    Subclasses override these attributes to describe how a detector measures
    and grades a model.
    """

    # Issue group attached to the scan results of the detector.
    issue_group: IssueGroup
    # NOTE(review): presumably messages keyed by issue kind; not used in the visible code.
    warning_messages: dict

    # Metric handed to the test that compares a dataloader with its reference.
    metric: MetricBase = None
    # "relative" divides the metric delta by the reference value; any other
    # value leaves the delta as a plain difference.
    metric_type: str = None
    # "better_lower" or "better_higher": which sign of the delta is a degradation.
    metric_direction: str = None

    # Extra margin added to issue_level_threshold before a result is graded MAJOR.
    deviation_threshold: float = 0.10
    # Delta beyond which a result is graded MEDIUM.
    issue_level_threshold: float = 0.05

    # NOTE(review): looks like the number of example images to use; confirm against run().
    num_images: int = 0

giskard_vision/core/detectors/transformation_color_detector.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader
22

33
from ...core.detectors.decorator import maybe_detector
4-
from .perturbation import PerturbationBaseDetector, Robustness
4+
from .perturbation import PerturbationBaseDetector
55

66

77
@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
@@ -10,8 +10,6 @@ class TransformationColorDetectorLandmark(PerturbationBaseDetector):
1010
Detector that evaluates models performance depending on images in grayscale
1111
"""
1212

13-
issue_group = Robustness
14-
1513
def get_dataloaders(self, dataset):
1614
dl = ColoredDataLoader(dataset)
1715

giskard_vision/core/detectors/transformation_noise_detector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .perturbation import PerturbationBaseDetector
55

66

7-
@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
7+
@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection", "noise"])
88
class TransformationNoiseDetectorLandmark(PerturbationBaseDetector):
99
"""
1010
Detector that evaluates models performance on noisy images

giskard_vision/core/issues.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass(frozen=True)
class IssueGroup:
    """Immutable label for a family of scan issues: a short name plus a description."""

    name: str
    description: str


# Issue groups for findings obtained by filtering the data on metadata.
EthicalIssueMeta = IssueGroup(
    name="Ethical",
    description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
    name="Performance",
    description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)

# Issue group for findings obtained by transforming the input images.
Robustness = IssueGroup(
    name="Robustness",
    description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
)

giskard_vision/image_classification/dataloaders/loaders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import numpy as np
44

5-
from giskard_vision.core.dataloaders.base import EthicalIssueMeta, PerformanceIssueMeta
65
from giskard_vision.core.dataloaders.hf import HFDataLoader
76
from giskard_vision.core.dataloaders.meta import MetaData
87
from giskard_vision.core.dataloaders.tfds import DataLoaderTensorFlowDatasets
98
from giskard_vision.core.dataloaders.utils import flatten_dict
9+
from giskard_vision.core.issues import EthicalIssueMeta, PerformanceIssueMeta
1010
from giskard_vision.image_classification.types import Types
1111

1212

Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
from giskard_vision.core.detectors.metadata_scan_detector import MetaDataScanDetector
2-
from giskard_vision.image_classification.tests.performance import Accuracy
32

43
from ...core.detectors.decorator import maybe_detector
4+
from .specs import DetectorSpecs
55

66

77
@maybe_detector("metadata_classification", tags=["vision", "image_classification", "metadata"])
8-
class MetaDataScanDetectorClassification(MetaDataScanDetector):
9-
metric = Accuracy
10-
type_task = "classification"
11-
metric_type = "absolute"
12-
metric_direction = "better_higher"
13-
deviation_threshold = 0.10
14-
issue_level_threshold = 0.05
8+
class MetaDataScanDetectorClassification(DetectorSpecs, MetaDataScanDetector):
9+
pass

0 commit comments

Comments
 (0)