Skip to content

Commit

Permalink
refactoring detectors
Browse files Browse the repository at this point in the history
  • Loading branch information
rabah-khalek committed Aug 13, 2024
1 parent e547d4d commit a44399d
Show file tree
Hide file tree
Showing 24 changed files with 165 additions and 227 deletions.
10 changes: 0 additions & 10 deletions giskard_vision/core/dataloaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,9 @@
import numpy as np

from giskard_vision.core.dataloaders.meta import MetaData
from giskard_vision.core.detectors.base import IssueGroup

from ..types import TypesBase

EthicalIssueMeta = IssueGroup(
"Ethical",
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
"Performance",
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)


class DataIteratorBase(ABC):
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.
Expand Down
2 changes: 1 addition & 1 deletion giskard_vision/core/dataloaders/meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional

from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import IssueGroup


class MetaData:
Expand Down
51 changes: 39 additions & 12 deletions giskard_vision/core/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

from giskard_vision.core.issues import IssueGroup
from giskard_vision.utils.errors import GiskardImportError


@dataclass(frozen=True)
class IssueGroup:
name: str
description: str
from .specs import DetectorSpecsBase


@dataclass
Expand Down Expand Up @@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
}


class DetectorVisionBase:
class DetectorVisionBase(DetectorSpecsBase):
"""
Abstract class for Vision Detectors
Expand All @@ -67,12 +64,6 @@ class DetectorVisionBase:
evaluation results for the scan.
"""

issue_group: IssueGroup
warning_messages: dict
issue_level_threshold: float = 0.2
deviation_threshold: float = 0.05
num_images: int = 0

def run(
self,
model: Any,
Expand Down Expand Up @@ -139,6 +130,42 @@ def get_issues(

return issues

def get_scan_result(
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = metric_value - metric_reference_value
if self.metric_type == "relative":
relative_delta /= metric_reference_value

issue_level = IssueLevel.MINOR
if self.metric_direction == "better_lower":
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
elif self.metric_direction == "better_higher":
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
issue_level = IssueLevel.MAJOR
elif relative_delta < -self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM

return ScanResult(
name=name,
metric_name=metric_name,
metric_value=metric_value,
metric_reference_value=metric_reference_value,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
issue_group=issue_group,
)

@abstractmethod
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
"""Returns a list of ScanResult
Expand Down
43 changes: 2 additions & 41 deletions giskard_vision/core/detectors/metadata_scan_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
import numpy as np
import pandas as pd

from giskard_vision.core.detectors.base import (
DetectorVisionBase,
IssueGroup,
ScanResult,
)
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import IssueGroup
from giskard_vision.core.tests.base import MetricBase
from giskard_vision.utils.errors import GiskardImportError

Expand Down Expand Up @@ -263,39 +260,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
pass

return pd.DataFrame(df)

def get_scan_result(
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = metric_value - metric_reference_value
if self.metric_type == "relative":
relative_delta /= metric_reference_value

issue_level = IssueLevel.MINOR
if self.metric_direction == "better_lower":
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
elif self.metric_direction == "better_higher":
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
issue_level = IssueLevel.MAJOR
elif relative_delta < -self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM

return ScanResult(
name=name,
metric_name=metric_name,
metric_value=metric_value,
metric_reference_value=metric_reference_value,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
issue_group=issue_group,
)
9 changes: 0 additions & 9 deletions giskard_vision/core/detectors/metrics.py

This file was deleted.

86 changes: 40 additions & 46 deletions giskard_vision/core/detectors/perturbation.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
import os
from abc import abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Sequence
from typing import Any, Sequence, Tuple

import cv2

from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
from giskard_vision.core.detectors.base import (
DetectorVisionBase,
IssueGroup,
ScanResult,
)
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import Robustness
from giskard_vision.core.tests.base import TestDiffBase
from giskard_vision.utils.errors import GiskardImportError

from .metrics import detector_metrics

Robustness = IssueGroup(
"Robustness",
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
)


class PerturbationBaseDetector(DetectorVisionBase):
Expand All @@ -40,6 +30,28 @@ class PerturbationBaseDetector(DetectorVisionBase):

issue_group = Robustness

def run(
self,
model: Any,
dataset: Any,
features: Any | None = None,
issue_levels: Tuple[Any] = None,
embed: bool = True,
num_images: int = 0,
) -> Sequence[Any]:
module = import_module(f"giskard_vision.{model.model_type}.detectors.specs")
DetectorSpecs = getattr(module, "DetectorSpecs")

if DetectorSpecs:
# Only set attributes that are not part of Python's special attributes (those starting with __)
for attr_name, attr_value in vars(DetectorSpecs).items():
if not attr_name.startswith("__") and hasattr(self, attr_name):
setattr(self, attr_name, attr_value)
else:
raise ValueError(f"No detector specifications found for model type: {model.model_type}")

return super().run(model, dataset, features, issue_levels, embed, num_images)

@abstractmethod
def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...

Expand All @@ -48,7 +60,7 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:

results = []
for dl in dataloaders:
test_result = TestDiffBase(metric=detector_metrics[model.model_type], threshold=1).run(
test_result = TestDiffBase(metric=self.metric, threshold=1).run(
model=model,
dataloader=dl,
dataloader_ref=dataset,
Expand All @@ -63,40 +75,22 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:

if isinstance(dl, FilteredDataLoader):
filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png")
cv2.imwrite(
filename_example_dataloader_ref, cv2.resize(dataset[index_worst][0][0], (0, 0), fx=0.3, fy=0.3)
)
cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
filename_examples.append(filename_example_dataloader_ref)

filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png")
cv2.imwrite(filename_example_dataloader, cv2.resize(dl[index_worst][0][0], (0, 0), fx=0.3, fy=0.3))
cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
filename_examples.append(filename_example_dataloader)
results.append(self.get_scan_result(test_result, filename_examples, dl.name, len(dl)))
results.append(
self.get_scan_result(
test_result.metric_value_test,
test_result.metric_value_test,
test_result.metric_name,
filename_examples,
dl.name,
len(dl),
issue_group=self.issue_group,
)
)

return results

def get_scan_result(self, test_result, filename_examples, name, size_data) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = (test_result.metric_value_test - test_result.metric_value_ref) / test_result.metric_value_ref

if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
else:
issue_level = IssueLevel.MINOR

return ScanResult(
name=name,
metric_name=test_result.metric_name,
metric_value=test_result.metric_value_test,
metric_reference_value=test_result.metric_value_ref,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
)
13 changes: 13 additions & 0 deletions giskard_vision/core/detectors/specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from giskard_vision.core.issues import IssueGroup
from giskard_vision.image_classification.tests.performance import MetricBase


class DetectorSpecsBase:
issue_group: IssueGroup
warning_messages: dict
metric: MetricBase = None
metric_type: str = None
metric_direction: str = None
deviation_threshold: float = 0.10
issue_level_threshold: float = 0.05
num_images: int = 0
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader

from ...core.detectors.decorator import maybe_detector
from .perturbation import PerturbationBaseDetector, Robustness
from .perturbation import PerturbationBaseDetector


@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
Expand All @@ -10,8 +10,6 @@ class TransformationColorDetectorLandmark(PerturbationBaseDetector):
Detector that evaluates models performance depending on images in grayscale
"""

issue_group = Robustness

def get_dataloaders(self, dataset):
dl = ColoredDataLoader(dataset)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .perturbation import PerturbationBaseDetector


@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection", "noise"])
class TransformationNoiseDetectorLandmark(PerturbationBaseDetector):
"""
Detector that evaluates models performance on noisy images
Expand Down
21 changes: 21 additions & 0 deletions giskard_vision/core/issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class IssueGroup:
name: str
description: str


EthicalIssueMeta = IssueGroup(
"Ethical",
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
"Performance",
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)
Robustness = IssueGroup(
"Robustness",
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
)
2 changes: 1 addition & 1 deletion giskard_vision/image_classification/dataloaders/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import numpy as np

from giskard_vision.core.dataloaders.base import EthicalIssueMeta, PerformanceIssueMeta
from giskard_vision.core.dataloaders.hf import HFDataLoader
from giskard_vision.core.dataloaders.meta import MetaData
from giskard_vision.core.dataloaders.tfds import DataLoaderTensorFlowDatasets
from giskard_vision.core.dataloaders.utils import flatten_dict
from giskard_vision.core.issues import EthicalIssueMeta, PerformanceIssueMeta
from giskard_vision.image_classification.types import Types


Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from giskard_vision.core.detectors.metadata_scan_detector import MetaDataScanDetector
from giskard_vision.image_classification.tests.performance import Accuracy

from ...core.detectors.decorator import maybe_detector
from .specs import DetectorSpecs


@maybe_detector("metadata_classification", tags=["vision", "image_classification", "metadata"])
class MetaDataScanDetectorClassification(MetaDataScanDetector):
metric = Accuracy
type_task = "classification"
metric_type = "absolute"
metric_direction = "better_higher"
deviation_threshold = 0.10
issue_level_threshold = 0.05
class MetaDataScanDetectorClassification(DetectorSpecs, MetaDataScanDetector):
pass
Loading

0 comments on commit a44399d

Please sign in to comment.