Skip to content

Commit

Permalink
Merge pull request #53 from Giskard-AI/refactoring-detectors
Browse files Browse the repository at this point in the history
refactoring detectors
  • Loading branch information
rabah-khalek authored Aug 13, 2024
2 parents 6601ecb + 6ba1994 commit ffbb425
Show file tree
Hide file tree
Showing 29 changed files with 178 additions and 230 deletions.
15 changes: 1 addition & 14 deletions giskard_vision/core/dataloaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,10 @@
get_image_channel_number,
get_image_size,
)
from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import AttributesIssueMeta

from ..types import TypesBase

EthicalIssueMeta = IssueGroup(
"Ethical",
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
"Performance",
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)
AttributesIssueMeta = IssueGroup(
"Attributes",
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
)


class DataIteratorBase(ABC):
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.
Expand Down
3 changes: 2 additions & 1 deletion giskard_vision/core/dataloaders/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from PIL.Image import Image as PILImage

from giskard_vision.core.dataloaders.base import AttributesIssueMeta, DataIteratorBase
from giskard_vision.core.dataloaders.base import DataIteratorBase
from giskard_vision.core.dataloaders.meta import MetaData, get_pil_image_depth
from giskard_vision.core.issues import AttributesIssueMeta
from giskard_vision.utils.errors import GiskardError, GiskardImportError


Expand Down
2 changes: 1 addition & 1 deletion giskard_vision/core/dataloaders/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from PIL.Image import Image as PILImage

from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import IssueGroup


class MetaData:
Expand Down
51 changes: 39 additions & 12 deletions giskard_vision/core/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

from giskard_vision.core.issues import IssueGroup
from giskard_vision.utils.errors import GiskardImportError


@dataclass(frozen=True)
class IssueGroup:
name: str
description: str
from .specs import DetectorSpecsBase


@dataclass
Expand Down Expand Up @@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
}


class DetectorVisionBase:
class DetectorVisionBase(DetectorSpecsBase):
"""
Abstract class for Vision Detectors
Expand All @@ -67,12 +64,6 @@ class DetectorVisionBase:
evaluation results for the scan.
"""

issue_group: IssueGroup
warning_messages: dict
issue_level_threshold: float = 0.2
deviation_threshold: float = 0.05
num_images: int = 0

def run(
self,
model: Any,
Expand Down Expand Up @@ -139,6 +130,42 @@ def get_issues(

return issues

def get_scan_result(
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = metric_value - metric_reference_value
if self.metric_type == "relative":
relative_delta /= metric_reference_value

issue_level = IssueLevel.MINOR
if self.metric_direction == "better_lower":
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
elif self.metric_direction == "better_higher":
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
issue_level = IssueLevel.MAJOR
elif relative_delta < -self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM

return ScanResult(
name=name,
metric_name=metric_name,
metric_value=metric_value,
metric_reference_value=metric_reference_value,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
issue_group=issue_group,
)

@abstractmethod
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
"""Returns a list of ScanResult
Expand Down
38 changes: 1 addition & 37 deletions giskard_vision/core/detectors/metadata_scan_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import numpy as np
import pandas as pd

from giskard_vision.core.dataloaders.base import PerformanceIssueMeta
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import PerformanceIssueMeta
from giskard_vision.core.tests.base import MetricBase
from giskard_vision.utils.errors import GiskardImportError

Expand Down Expand Up @@ -258,39 +258,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
pass

return pd.DataFrame(df)

def get_scan_result(
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = metric_value - metric_reference_value
if self.metric_type == "relative":
relative_delta /= metric_reference_value

issue_level = IssueLevel.MINOR
if self.metric_direction == "better_lower":
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
elif self.metric_direction == "better_higher":
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
issue_level = IssueLevel.MAJOR
elif relative_delta < -self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM

return ScanResult(
name=name,
metric_name=metric_name,
metric_value=metric_value,
metric_reference_value=metric_reference_value,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
issue_group=issue_group,
)
9 changes: 0 additions & 9 deletions giskard_vision/core/detectors/metrics.py

This file was deleted.

75 changes: 30 additions & 45 deletions giskard_vision/core/detectors/perturbation.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
import os
from abc import abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Sequence

import cv2

from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
from giskard_vision.core.detectors.base import (
DetectorVisionBase,
IssueGroup,
ScanResult,
)
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import Robustness
from giskard_vision.core.tests.base import TestDiffBase
from giskard_vision.utils.errors import GiskardImportError

from .metrics import detector_metrics

Robustness = IssueGroup(
"Robustness",
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
)


class PerturbationBaseDetector(DetectorVisionBase):
Expand All @@ -40,15 +30,28 @@ class PerturbationBaseDetector(DetectorVisionBase):

issue_group = Robustness

def set_specs_from_model_type(self, model_type):
module = import_module(f"giskard_vision.{model_type}.detectors.specs")
DetectorSpecs = getattr(module, "DetectorSpecs")

if DetectorSpecs:
# Only set attributes that are not part of Python's special attributes (those starting with __)
for attr_name, attr_value in vars(DetectorSpecs).items():
if not attr_name.startswith("__") and hasattr(self, attr_name):
setattr(self, attr_name, attr_value)
else:
raise ValueError(f"No detector specifications found for model type: {model_type}")

@abstractmethod
def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...

def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
self.set_specs_from_model_type(model.model_type)
dataloaders = self.get_dataloaders(dataset)

results = []
for dl in dataloaders:
test_result = TestDiffBase(metric=detector_metrics[model.model_type], threshold=1).run(
test_result = TestDiffBase(metric=self.metric, threshold=1).run(
model=model,
dataloader=dl,
dataloader_ref=dataset,
Expand All @@ -63,40 +66,22 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:

if isinstance(dl, FilteredDataLoader):
filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png")
cv2.imwrite(
filename_example_dataloader_ref, cv2.resize(dataset[index_worst][0][0], (0, 0), fx=0.3, fy=0.3)
)
cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
filename_examples.append(filename_example_dataloader_ref)

filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png")
cv2.imwrite(filename_example_dataloader, cv2.resize(dl[index_worst][0][0], (0, 0), fx=0.3, fy=0.3))
cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
filename_examples.append(filename_example_dataloader)
results.append(self.get_scan_result(test_result, filename_examples, dl.name, len(dl)))
results.append(
self.get_scan_result(
test_result.metric_value_test,
test_result.metric_value_ref,
test_result.metric_name,
filename_examples,
dl.name,
len(dl),
issue_group=self.issue_group,
)
)

return results

def get_scan_result(self, test_result, filename_examples, name, size_data) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = (test_result.metric_value_test - test_result.metric_value_ref) / test_result.metric_value_ref

if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
else:
issue_level = IssueLevel.MINOR

return ScanResult(
name=name,
metric_name=test_result.metric_name,
metric_value=test_result.metric_value_test,
metric_reference_value=test_result.metric_value_ref,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
)
13 changes: 13 additions & 0 deletions giskard_vision/core/detectors/specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from giskard_vision.core.issues import IssueGroup
from giskard_vision.image_classification.tests.performance import MetricBase


class DetectorSpecsBase:
issue_group: IssueGroup
warning_messages: dict
metric: MetricBase = None
metric_type: str = None
metric_direction: str = None
deviation_threshold: float = 0.10
issue_level_threshold: float = 0.05
num_images: int = 0
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@maybe_detector("blurring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationBlurringDetectorLandmark(PerturbationBaseDetector):
class TransformationBlurringDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance on blurred images
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader

from ...core.detectors.decorator import maybe_detector
from .perturbation import PerturbationBaseDetector, Robustness
from .perturbation import PerturbationBaseDetector


@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationColorDetectorLandmark(PerturbationBaseDetector):
class TransformationColorDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance depending on images in grayscale
"""

issue_group = Robustness

def get_dataloaders(self, dataset):
dl = ColoredDataLoader(dataset)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from .perturbation import PerturbationBaseDetector


@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationNoiseDetectorLandmark(PerturbationBaseDetector):
@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection", "noise"])
class TransformationNoiseDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance on noisy images
"""
Expand Down
25 changes: 25 additions & 0 deletions giskard_vision/core/issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class IssueGroup:
name: str
description: str


EthicalIssueMeta = IssueGroup(
"Ethical",
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
"Performance",
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)
AttributesIssueMeta = IssueGroup(
"Attributes",
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
)
Robustness = IssueGroup(
"Robustness",
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
)
2 changes: 1 addition & 1 deletion giskard_vision/image_classification/dataloaders/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import numpy as np
from PIL.Image import Image as PILImage

from giskard_vision.core.dataloaders.base import EthicalIssueMeta, PerformanceIssueMeta
from giskard_vision.core.dataloaders.hf import HFDataLoader
from giskard_vision.core.dataloaders.meta import MetaData
from giskard_vision.core.dataloaders.tfds import DataLoaderTensorFlowDatasets
from giskard_vision.core.dataloaders.utils import flatten_dict
from giskard_vision.core.issues import EthicalIssueMeta, PerformanceIssueMeta
from giskard_vision.image_classification.types import Types


Expand Down
Loading

0 comments on commit ffbb425

Please sign in to comment.