Skip to content

Commit ffbb425

Browse files
authored
Merge pull request #53 from Giskard-AI/refactoring-detectors
refactoring detectors
2 parents 6601ecb + 6ba1994 commit ffbb425

29 files changed

+178
-230
lines changed

giskard_vision/core/dataloaders/base.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,10 @@
1212
get_image_channel_number,
1313
get_image_size,
1414
)
15-
from giskard_vision.core.detectors.base import IssueGroup
15+
from giskard_vision.core.issues import AttributesIssueMeta
1616

1717
from ..types import TypesBase
1818

19-
EthicalIssueMeta = IssueGroup(
20-
"Ethical",
21-
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
22-
)
23-
PerformanceIssueMeta = IssueGroup(
24-
"Performance",
25-
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
26-
)
27-
AttributesIssueMeta = IssueGroup(
28-
"Attributes",
29-
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
30-
)
31-
3219

3320
class DataIteratorBase(ABC):
3421
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.

giskard_vision/core/dataloaders/hf.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
from PIL.Image import Image as PILImage
99

10-
from giskard_vision.core.dataloaders.base import AttributesIssueMeta, DataIteratorBase
10+
from giskard_vision.core.dataloaders.base import DataIteratorBase
1111
from giskard_vision.core.dataloaders.meta import MetaData, get_pil_image_depth
12+
from giskard_vision.core.issues import AttributesIssueMeta
1213
from giskard_vision.utils.errors import GiskardError, GiskardImportError
1314

1415

giskard_vision/core/dataloaders/meta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from PIL.Image import Image as PILImage
66

7-
from giskard_vision.core.detectors.base import IssueGroup
7+
from giskard_vision.core.issues import IssueGroup
88

99

1010
class MetaData:

giskard_vision/core/detectors/base.py

+39-12
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
from dataclasses import dataclass
33
from typing import Any, List, Optional, Sequence, Tuple
44

5+
from giskard_vision.core.issues import IssueGroup
56
from giskard_vision.utils.errors import GiskardImportError
67

7-
8-
@dataclass(frozen=True)
9-
class IssueGroup:
10-
name: str
11-
description: str
8+
from .specs import DetectorSpecsBase
129

1310

1411
@dataclass
@@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
5148
}
5249

5350

54-
class DetectorVisionBase:
51+
class DetectorVisionBase(DetectorSpecsBase):
5552
"""
5653
Abstract class for Vision Detectors
5754
@@ -67,12 +64,6 @@ class DetectorVisionBase:
6764
evaluation results for the scan.
6865
"""
6966

70-
issue_group: IssueGroup
71-
warning_messages: dict
72-
issue_level_threshold: float = 0.2
73-
deviation_threshold: float = 0.05
74-
num_images: int = 0
75-
7667
def run(
7768
self,
7869
model: Any,
@@ -139,6 +130,42 @@ def get_issues(
139130

140131
return issues
141132

133+
def get_scan_result(
134+
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
135+
) -> ScanResult:
136+
try:
137+
from giskard.scanner.issues import IssueLevel
138+
except (ImportError, ModuleNotFoundError) as e:
139+
raise GiskardImportError(["giskard"]) from e
140+
141+
relative_delta = metric_value - metric_reference_value
142+
if self.metric_type == "relative":
143+
relative_delta /= metric_reference_value
144+
145+
issue_level = IssueLevel.MINOR
146+
if self.metric_direction == "better_lower":
147+
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
148+
issue_level = IssueLevel.MAJOR
149+
elif relative_delta > self.issue_level_threshold:
150+
issue_level = IssueLevel.MEDIUM
151+
elif self.metric_direction == "better_higher":
152+
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
153+
issue_level = IssueLevel.MAJOR
154+
elif relative_delta < -self.issue_level_threshold:
155+
issue_level = IssueLevel.MEDIUM
156+
157+
return ScanResult(
158+
name=name,
159+
metric_name=metric_name,
160+
metric_value=metric_value,
161+
metric_reference_value=metric_reference_value,
162+
issue_level=issue_level,
163+
slice_size=size_data,
164+
filename_examples=filename_examples,
165+
relative_delta=relative_delta,
166+
issue_group=issue_group,
167+
)
168+
142169
@abstractmethod
143170
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
144171
"""Returns a list of ScanResult

giskard_vision/core/detectors/metadata_scan_detector.py

+1-37
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import numpy as np
55
import pandas as pd
66

7-
from giskard_vision.core.dataloaders.base import PerformanceIssueMeta
87
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
8+
from giskard_vision.core.issues import PerformanceIssueMeta
99
from giskard_vision.core.tests.base import MetricBase
1010
from giskard_vision.utils.errors import GiskardImportError
1111

@@ -258,39 +258,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
258258
pass
259259

260260
return pd.DataFrame(df)
261-
262-
def get_scan_result(
263-
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
264-
) -> ScanResult:
265-
try:
266-
from giskard.scanner.issues import IssueLevel
267-
except (ImportError, ModuleNotFoundError) as e:
268-
raise GiskardImportError(["giskard"]) from e
269-
270-
relative_delta = metric_value - metric_reference_value
271-
if self.metric_type == "relative":
272-
relative_delta /= metric_reference_value
273-
274-
issue_level = IssueLevel.MINOR
275-
if self.metric_direction == "better_lower":
276-
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
277-
issue_level = IssueLevel.MAJOR
278-
elif relative_delta > self.issue_level_threshold:
279-
issue_level = IssueLevel.MEDIUM
280-
elif self.metric_direction == "better_higher":
281-
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
282-
issue_level = IssueLevel.MAJOR
283-
elif relative_delta < -self.issue_level_threshold:
284-
issue_level = IssueLevel.MEDIUM
285-
286-
return ScanResult(
287-
name=name,
288-
metric_name=metric_name,
289-
metric_value=metric_value,
290-
metric_reference_value=metric_reference_value,
291-
issue_level=issue_level,
292-
slice_size=size_data,
293-
filename_examples=filename_examples,
294-
relative_delta=relative_delta,
295-
issue_group=issue_group,
296-
)

giskard_vision/core/detectors/metrics.py

-9
This file was deleted.
+30-45
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
11
import os
22
from abc import abstractmethod
3+
from importlib import import_module
34
from pathlib import Path
45
from typing import Any, Sequence
56

67
import cv2
78

89
from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
9-
from giskard_vision.core.detectors.base import (
10-
DetectorVisionBase,
11-
IssueGroup,
12-
ScanResult,
13-
)
10+
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
11+
from giskard_vision.core.issues import Robustness
1412
from giskard_vision.core.tests.base import TestDiffBase
15-
from giskard_vision.utils.errors import GiskardImportError
16-
17-
from .metrics import detector_metrics
18-
19-
Robustness = IssueGroup(
20-
"Robustness",
21-
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
22-
)
2313

2414

2515
class PerturbationBaseDetector(DetectorVisionBase):
@@ -40,15 +30,28 @@ class PerturbationBaseDetector(DetectorVisionBase):
4030

4131
issue_group = Robustness
4232

33+
def set_specs_from_model_type(self, model_type):
34+
module = import_module(f"giskard_vision.{model_type}.detectors.specs")
35+
DetectorSpecs = getattr(module, "DetectorSpecs")
36+
37+
if DetectorSpecs:
38+
# Only set attributes that are not part of Python's special attributes (those starting with __)
39+
for attr_name, attr_value in vars(DetectorSpecs).items():
40+
if not attr_name.startswith("__") and hasattr(self, attr_name):
41+
setattr(self, attr_name, attr_value)
42+
else:
43+
raise ValueError(f"No detector specifications found for model type: {model_type}")
44+
4345
@abstractmethod
4446
def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...
4547

4648
def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
49+
self.set_specs_from_model_type(model.model_type)
4750
dataloaders = self.get_dataloaders(dataset)
4851

4952
results = []
5053
for dl in dataloaders:
51-
test_result = TestDiffBase(metric=detector_metrics[model.model_type], threshold=1).run(
54+
test_result = TestDiffBase(metric=self.metric, threshold=1).run(
5255
model=model,
5356
dataloader=dl,
5457
dataloader_ref=dataset,
@@ -63,40 +66,22 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
6366

6467
if isinstance(dl, FilteredDataLoader):
6568
filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png")
66-
cv2.imwrite(
67-
filename_example_dataloader_ref, cv2.resize(dataset[index_worst][0][0], (0, 0), fx=0.3, fy=0.3)
68-
)
69+
cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
6970
filename_examples.append(filename_example_dataloader_ref)
7071

7172
filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png")
72-
cv2.imwrite(filename_example_dataloader, cv2.resize(dl[index_worst][0][0], (0, 0), fx=0.3, fy=0.3))
73+
cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
7374
filename_examples.append(filename_example_dataloader)
74-
results.append(self.get_scan_result(test_result, filename_examples, dl.name, len(dl)))
75+
results.append(
76+
self.get_scan_result(
77+
test_result.metric_value_test,
78+
test_result.metric_value_ref,
79+
test_result.metric_name,
80+
filename_examples,
81+
dl.name,
82+
len(dl),
83+
issue_group=self.issue_group,
84+
)
85+
)
7586

7687
return results
77-
78-
def get_scan_result(self, test_result, filename_examples, name, size_data) -> ScanResult:
79-
try:
80-
from giskard.scanner.issues import IssueLevel
81-
except (ImportError, ModuleNotFoundError) as e:
82-
raise GiskardImportError(["giskard"]) from e
83-
84-
relative_delta = (test_result.metric_value_test - test_result.metric_value_ref) / test_result.metric_value_ref
85-
86-
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
87-
issue_level = IssueLevel.MAJOR
88-
elif relative_delta > self.issue_level_threshold:
89-
issue_level = IssueLevel.MEDIUM
90-
else:
91-
issue_level = IssueLevel.MINOR
92-
93-
return ScanResult(
94-
name=name,
95-
metric_name=test_result.metric_name,
96-
metric_value=test_result.metric_value_test,
97-
metric_reference_value=test_result.metric_value_ref,
98-
issue_level=issue_level,
99-
slice_size=size_data,
100-
filename_examples=filename_examples,
101-
relative_delta=relative_delta,
102-
)
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from giskard_vision.core.issues import IssueGroup
2+
from giskard_vision.image_classification.tests.performance import MetricBase
3+
4+
5+
class DetectorSpecsBase:
6+
issue_group: IssueGroup
7+
warning_messages: dict
8+
metric: MetricBase = None
9+
metric_type: str = None
10+
metric_direction: str = None
11+
deviation_threshold: float = 0.10
12+
issue_level_threshold: float = 0.05
13+
num_images: int = 0

giskard_vision/core/detectors/transformation_blurring_detector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
@maybe_detector("blurring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
8-
class TransformationBlurringDetectorLandmark(PerturbationBaseDetector):
8+
class TransformationBlurringDetector(PerturbationBaseDetector):
99
"""
1010
Detector that evaluates models performance on blurred images
1111
"""

giskard_vision/core/detectors/transformation_color_detector.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader
22

33
from ...core.detectors.decorator import maybe_detector
4-
from .perturbation import PerturbationBaseDetector, Robustness
4+
from .perturbation import PerturbationBaseDetector
55

66

77
@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
8-
class TransformationColorDetectorLandmark(PerturbationBaseDetector):
8+
class TransformationColorDetector(PerturbationBaseDetector):
99
"""
1010
Detector that evaluates models performance depending on images in grayscale
1111
"""
1212

13-
issue_group = Robustness
14-
1513
def get_dataloaders(self, dataset):
1614
dl = ColoredDataLoader(dataset)
1715

giskard_vision/core/detectors/transformation_noise_detector.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from .perturbation import PerturbationBaseDetector
55

66

7-
@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
8-
class TransformationNoiseDetectorLandmark(PerturbationBaseDetector):
7+
@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection", "noise"])
8+
class TransformationNoiseDetector(PerturbationBaseDetector):
99
"""
1010
Detector that evaluates models performance on noisy images
1111
"""

giskard_vision/core/issues.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass(frozen=True)
5+
class IssueGroup:
6+
name: str
7+
description: str
8+
9+
10+
EthicalIssueMeta = IssueGroup(
11+
"Ethical",
12+
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
13+
)
14+
PerformanceIssueMeta = IssueGroup(
15+
"Performance",
16+
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
17+
)
18+
AttributesIssueMeta = IssueGroup(
19+
"Attributes",
20+
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
21+
)
22+
Robustness = IssueGroup(
23+
"Robustness",
24+
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
25+
)

giskard_vision/image_classification/dataloaders/loaders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import numpy as np
44
from PIL.Image import Image as PILImage
55

6-
from giskard_vision.core.dataloaders.base import EthicalIssueMeta, PerformanceIssueMeta
76
from giskard_vision.core.dataloaders.hf import HFDataLoader
87
from giskard_vision.core.dataloaders.meta import MetaData
98
from giskard_vision.core.dataloaders.tfds import DataLoaderTensorFlowDatasets
109
from giskard_vision.core.dataloaders.utils import flatten_dict
10+
from giskard_vision.core.issues import EthicalIssueMeta, PerformanceIssueMeta
1111
from giskard_vision.image_classification.types import Types
1212

1313

0 commit comments

Comments
 (0)