Skip to content

Commit b9ab6b8

Browse files
authored
Merge pull request #51 from Giskard-AI/perturbation-detectors
working on perturbation detectors
2 parents 677cf94 + ffbb425 commit b9ab6b8

33 files changed

+403
-179
lines changed

giskard_vision/core/dataloaders/base.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,10 @@
1212
get_image_channel_number,
1313
get_image_size,
1414
)
15-
from giskard_vision.core.detectors.base import IssueGroup
15+
from giskard_vision.core.issues import AttributesIssueMeta
1616

1717
from ..types import TypesBase
1818

19-
EthicalIssueMeta = IssueGroup(
20-
"Ethical",
21-
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
22-
)
23-
PerformanceIssueMeta = IssueGroup(
24-
"Performance",
25-
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
26-
)
27-
AttributesIssueMeta = IssueGroup(
28-
"Attributes",
29-
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
30-
)
31-
3219

3320
class DataIteratorBase(ABC):
3421
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.

giskard_vision/core/dataloaders/hf.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
from PIL.Image import Image as PILImage
99

10-
from giskard_vision.core.dataloaders.base import AttributesIssueMeta, DataIteratorBase
10+
from giskard_vision.core.dataloaders.base import DataIteratorBase
1111
from giskard_vision.core.dataloaders.meta import MetaData, get_pil_image_depth
12+
from giskard_vision.core.issues import AttributesIssueMeta
1213
from giskard_vision.utils.errors import GiskardError, GiskardImportError
1314

1415

giskard_vision/core/dataloaders/meta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from PIL.Image import Image as PILImage
66

7-
from giskard_vision.core.detectors.base import IssueGroup
7+
from giskard_vision.core.issues import IssueGroup
88

99

1010
class MetaData:

giskard_vision/core/dataloaders/wrappers.py

+72
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,78 @@ def get_image(self, idx: int) -> np.ndarray:
228228
return cv2.GaussianBlur(image, self._kernel_size, *self._sigma)
229229

230230

231+
class NoisyDataLoader(DataLoaderWrapper):
    """Wrapper class for a DataIteratorBase, providing images with additive Gaussian noise.

    Args:
        dataloader (DataIteratorBase): The data loader to be wrapped.
        sigma (float): Standard deviation of the Gaussian noise. It is multiplied
            by 255 before the noise is drawn (see ``get_image``).

    Returns:
        NoisyDataLoader: Noisy data loader instance.
    """

    def __init__(
        self,
        dataloader: DataIteratorBase,
        sigma: float = 0.1,
    ) -> None:
        """
        Initializes the NoisyDataLoader.

        Args:
            dataloader (DataIteratorBase): The data loader to be wrapped.
            sigma (float): Standard deviation of the Gaussian noise.
        """
        super().__init__(dataloader)
        self._sigma = sigma

    @property
    def name(self):
        """
        Gets the name of the noisy data loader.

        Returns:
            str: The name of the noisy data loader.
        """
        return "noisy"

    def get_image(self, idx: int) -> np.ndarray:
        """
        Gets the image at the given index with Gaussian noise added.

        Args:
            idx (int): Index of the data.

        Returns:
            np.ndarray: Noisy image data.
        """
        # The configured sigma is scaled by 255 to obtain the noise standard deviation.
        scaled_std_dev = self._sigma * 255
        clean_image = super().get_image(idx)
        return self.add_gaussian_noise(clean_image, scaled_std_dev)

    def add_gaussian_noise(self, image, std_dev):
        """
        Add zero-mean Gaussian noise to the image.

        Args:
            image (np.ndarray): Image
            std_dev (float): Standard deviation of the Gaussian noise.

        Returns:
            np.ndarray: Noisy image, clipped to [0, 255] and cast to uint8.
        """
        # Draw one noise sample per array element, in float32 to match the image cast below.
        gaussian = np.random.normal(0, std_dev, image.shape).astype(np.float32)

        # Sum in float32 so that values can temporarily leave the 0-255 range.
        summed = cv2.add(image.astype(np.float32), gaussian)

        # Bring the values back into the valid uint8 range before the final cast.
        return np.clip(summed, 0, 255).astype(np.uint8)
302+
231303
class ColoredDataLoader(DataLoaderWrapper):
232304
"""Wrapper class for a DataIteratorBase, providing color-altered images using OpenCV color conversion.
233305

giskard_vision/core/detectors/base.py

+39-12
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
from dataclasses import dataclass
33
from typing import Any, List, Optional, Sequence, Tuple
44

5+
from giskard_vision.core.issues import IssueGroup
56
from giskard_vision.utils.errors import GiskardImportError
67

7-
8-
@dataclass(frozen=True)
9-
class IssueGroup:
10-
name: str
11-
description: str
8+
from .specs import DetectorSpecsBase
129

1310

1411
@dataclass
@@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
5148
}
5249

5350

54-
class DetectorVisionBase:
51+
class DetectorVisionBase(DetectorSpecsBase):
5552
"""
5653
Abstract class for Vision Detectors
5754
@@ -67,12 +64,6 @@ class DetectorVisionBase:
6764
evaluation results for the scan.
6865
"""
6966

70-
issue_group: IssueGroup
71-
warning_messages: dict
72-
issue_level_threshold: float = 0.2
73-
deviation_threshold: float = 0.05
74-
num_images: int = 0
75-
7667
def run(
7768
self,
7869
model: Any,
@@ -139,6 +130,42 @@ def get_issues(
139130

140131
return issues
141132

133+
def get_scan_result(
    self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
    """Build a ScanResult by comparing a metric value against its reference value.

    The delta is ``metric_value - metric_reference_value``; when ``self.metric_type``
    is ``"relative"`` it is additionally divided by ``metric_reference_value``.
    The issue level defaults to MINOR and is raised to MEDIUM when the delta moves
    in the "worse" direction (given by ``self.metric_direction``) by more than
    ``self.issue_level_threshold``, and to MAJOR when it does so by more than
    ``self.issue_level_threshold + self.deviation_threshold``.

    Args:
        metric_value: Metric value measured on the evaluated data.
        metric_reference_value: Metric value measured on the reference data.
        metric_name: Name of the metric, forwarded to the result.
        filename_examples: Example file names, forwarded to the result.
        name: Name stored on the result.
        size_data: Stored on the result as ``slice_size``.
        issue_group: Issue group stored on the result.

    Returns:
        ScanResult: The populated scan result.

    Raises:
        GiskardImportError: If the ``giskard`` package cannot be imported.
    """
    try:
        from giskard.scanner.issues import IssueLevel
    except (ImportError, ModuleNotFoundError) as e:
        raise GiskardImportError(["giskard"]) from e

    relative_delta = metric_value - metric_reference_value
    if self.metric_type == "relative":
        relative_delta /= metric_reference_value

    # Express the delta as a "degradation": positive means the metric got worse.
    direction = self.metric_direction
    if direction == "better_lower":
        degradation = relative_delta
    elif direction == "better_higher":
        degradation = -relative_delta
    else:
        # Unknown direction: no comparison is made and the level stays MINOR.
        degradation = None

    issue_level = IssueLevel.MINOR
    if degradation is not None:
        if degradation > self.issue_level_threshold + self.deviation_threshold:
            issue_level = IssueLevel.MAJOR
        elif degradation > self.issue_level_threshold:
            issue_level = IssueLevel.MEDIUM

    result_fields = dict(
        name=name,
        metric_name=metric_name,
        metric_value=metric_value,
        metric_reference_value=metric_reference_value,
        issue_level=issue_level,
        slice_size=size_data,
        filename_examples=filename_examples,
        relative_delta=relative_delta,
        issue_group=issue_group,
    )
    return ScanResult(**result_fields)
168+
142169
@abstractmethod
143170
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
144171
"""Returns a list of ScanResult

giskard_vision/core/detectors/metadata_scan_detector.py

+1-37
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import numpy as np
55
import pandas as pd
66

7-
from giskard_vision.core.dataloaders.base import PerformanceIssueMeta
87
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
8+
from giskard_vision.core.issues import PerformanceIssueMeta
99
from giskard_vision.core.tests.base import MetricBase
1010
from giskard_vision.utils.errors import GiskardImportError
1111

@@ -258,39 +258,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
258258
pass
259259

260260
return pd.DataFrame(df)
261-
262-
def get_scan_result(
263-
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
264-
) -> ScanResult:
265-
try:
266-
from giskard.scanner.issues import IssueLevel
267-
except (ImportError, ModuleNotFoundError) as e:
268-
raise GiskardImportError(["giskard"]) from e
269-
270-
relative_delta = metric_value - metric_reference_value
271-
if self.metric_type == "relative":
272-
relative_delta /= metric_reference_value
273-
274-
issue_level = IssueLevel.MINOR
275-
if self.metric_direction == "better_lower":
276-
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
277-
issue_level = IssueLevel.MAJOR
278-
elif relative_delta > self.issue_level_threshold:
279-
issue_level = IssueLevel.MEDIUM
280-
elif self.metric_direction == "better_higher":
281-
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
282-
issue_level = IssueLevel.MAJOR
283-
elif relative_delta < -self.issue_level_threshold:
284-
issue_level = IssueLevel.MEDIUM
285-
286-
return ScanResult(
287-
name=name,
288-
metric_name=metric_name,
289-
metric_value=metric_value,
290-
metric_reference_value=metric_reference_value,
291-
issue_level=issue_level,
292-
slice_size=size_data,
293-
filename_examples=filename_examples,
294-
relative_delta=relative_delta,
295-
issue_group=issue_group,
296-
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import os
2+
from abc import abstractmethod
3+
from importlib import import_module
4+
from pathlib import Path
5+
from typing import Any, Sequence
6+
7+
import cv2
8+
9+
from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
10+
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
11+
from giskard_vision.core.issues import Robustness
12+
from giskard_vision.core.tests.base import TestDiffBase
13+
14+
15+
class PerturbationBaseDetector(DetectorVisionBase):
    """
    Abstract class for detectors that evaluate a model on perturbed dataloaders

    Each dataloader returned by `get_dataloaders` is compared against the
    original dataset with a `TestDiffBase` run, and the outcome is converted
    into a `ScanResult` in the `Robustness` issue group.

    Methods:
        set_specs_from_model_type(model_type) -> None:
            Copies the attributes of the `DetectorSpecs` class defined for the
            given model type onto this detector

        get_dataloaders(dataset: Any) -> Sequence[Any]:
            Abstract method that returns a list of dataloaders corresponding to
            slices or transformations

        get_results(model: Any, dataset: Any) -> Sequence[ScanResult]:
            Returns a list of ScanResult containing the evaluation results
    """

    issue_group = Robustness

    def set_specs_from_model_type(self, model_type):
        """
        Load the detector specifications of a model type and apply them to this detector

        Args:
            model_type: Name of the sub-package of `giskard_vision` whose
                `detectors.specs` module is imported.

        Raises:
            ModuleNotFoundError: If the specs module cannot be imported.
            ValueError: If the specs module does not define `DetectorSpecs`.
        """
        module = import_module(f"giskard_vision.{model_type}.detectors.specs")
        # A default is required here: without it a missing class raises
        # AttributeError and the ValueError below could never be reached.
        DetectorSpecs = getattr(module, "DetectorSpecs", None)

        if not DetectorSpecs:
            raise ValueError(f"No detector specifications found for model type: {model_type}")

        # Only set attributes that are not part of Python's special attributes (those starting with __)
        # and that already exist on this detector.
        for attr_name, attr_value in vars(DetectorSpecs).items():
            if not attr_name.startswith("__") and hasattr(self, attr_name):
                setattr(self, attr_name, attr_value)

    @abstractmethod
    def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...

    @staticmethod
    def _save_example(dataloader: Any, index: int, directory: Path) -> str:
        """Write the example at `index` of `dataloader` as a PNG in `directory` and return its path."""
        filename = str(directory / f"{dataloader.name}_{index}.png")
        # presumably dataloader[index][0][0] is the image array of the entry; verify against the dataloader API
        cv2.imwrite(filename, dataloader[index][0][0])
        return filename

    def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
        """
        Evaluate the model on every perturbed dataloader and collect the scan results

        Args:
            model (Any): Model to evaluate; its `model_type` selects the detector specs.
            dataset (Any): Reference dataset, also passed to `get_dataloaders`.

        Returns:
            Sequence[ScanResult]: One result per dataloader returned by `get_dataloaders`.
        """
        self.set_specs_from_model_type(model.model_type)
        dataloaders = self.get_dataloaders(dataset)

        # Example images are written under ./examples_images; create it once, not per dataloader.
        examples_dir = Path() / "examples_images"
        os.makedirs(examples_dir, exist_ok=True)

        results = []
        for dl in dataloaders:
            test_result = TestDiffBase(metric=self.metric, threshold=1).run(
                model=model,
                dataloader=dl,
                dataloader_ref=dataset,
            )

            # Use the first reported example index; fall back to 0 when none is reported
            # (None or an empty sequence).
            indexes = test_result.indexes_examples
            index_worst = 0 if indexes is None or len(indexes) == 0 else indexes[0]

            # Save example images from dataloader and dataset
            filename_examples = []
            if isinstance(dl, FilteredDataLoader):
                filename_examples.append(self._save_example(dataset, index_worst, examples_dir))
            filename_examples.append(self._save_example(dl, index_worst, examples_dir))

            results.append(
                self.get_scan_result(
                    test_result.metric_value_test,
                    test_result.metric_value_ref,
                    test_result.metric_name,
                    filename_examples,
                    dl.name,
                    len(dl),
                    issue_group=self.issue_group,
                )
            )

        return results
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from giskard_vision.core.issues import IssueGroup
2+
from giskard_vision.image_classification.tests.performance import MetricBase
3+
4+
5+
# Default detector settings declared as class attributes. DetectorVisionBase
# inherits from this class, and PerturbationBaseDetector.set_specs_from_model_type
# overwrites these values from a per-model-type DetectorSpecs class.
class DetectorSpecsBase:
6+
# Issue group attached to the scan results (annotation only, no default here).
issue_group: IssueGroup
7+
# NOTE(review): presumably maps identifiers to warning texts - confirm against users of this attribute.
warning_messages: dict
8+
# Metric handed to TestDiffBase by PerturbationBaseDetector.get_results; None until specs are applied.
metric: MetricBase = None
9+
# get_scan_result divides the delta by the reference value when this equals "relative".
metric_type: str = None
10+
# get_scan_result compares against "better_lower" and "better_higher".
metric_direction: str = None
11+
# Added to issue_level_threshold to obtain the MAJOR threshold in get_scan_result.
deviation_threshold: float = 0.10
12+
# A delta beyond this value (in the "worse" direction) is reported as MEDIUM in get_scan_result.
issue_level_threshold: float = 0.05
13+
# NOTE(review): usage is not visible in this diff - confirm what this counts.
num_images: int = 0

giskard_vision/landmark_detection/detectors/transformation_blurring_detector.py giskard_vision/core/detectors/transformation_blurring_detector.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
from giskard_vision.core.dataloaders.wrappers import BlurredDataLoader
22

33
from ...core.detectors.decorator import maybe_detector
4-
from .base import LandmarkDetectionBaseDetector, Robustness
4+
from .perturbation import PerturbationBaseDetector
55

66

7-
@maybe_detector("blurring_landmark", tags=["vision", "face", "landmark", "transformed", "blurred"])
8-
class TransformationBlurringDetectorLandmark(LandmarkDetectionBaseDetector):
7+
@maybe_detector("blurring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
8+
class TransformationBlurringDetector(PerturbationBaseDetector):
99
"""
1010
Detector that evaluates models performance on blurred images
1111
"""
1212

13-
issue_group = Robustness
14-
1513
def __init__(self, kernel_size=(11, 11), sigma=(3, 3)):
1614
self.kernel_size = kernel_size
1715
self.sigma = sigma

giskard_vision/landmark_detection/detectors/transformation_color_detector.py giskard_vision/core/detectors/transformation_color_detector.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader
22

33
from ...core.detectors.decorator import maybe_detector
4-
from .base import LandmarkDetectionBaseDetector, Robustness
4+
from .perturbation import PerturbationBaseDetector
55

66

7-
@maybe_detector("color_landmark", tags=["vision", "face", "landmark", "filtered", "colored"])
8-
class TransformationColorDetectorLandmark(LandmarkDetectionBaseDetector):
7+
@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
8+
class TransformationColorDetector(PerturbationBaseDetector):
99
"""
1010
Detector that evaluates models performance depending on images in grayscale
1111
"""
1212

13-
issue_group = Robustness
14-
1513
def get_dataloaders(self, dataset):
1614
dl = ColoredDataLoader(dataset)
1715

0 commit comments

Comments
 (0)