Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactoring detectors #53

Merged
merged 8 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions giskard_vision/core/dataloaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,10 @@
get_image_channel_number,
get_image_size,
)
from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import AttributesIssueMeta

from ..types import TypesBase

EthicalIssueMeta = IssueGroup(
"Ethical",
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
"Performance",
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)
AttributesIssueMeta = IssueGroup(
"Attributes",
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
)


class DataIteratorBase(ABC):
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.
Expand Down
3 changes: 2 additions & 1 deletion giskard_vision/core/dataloaders/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from PIL.Image import Image as PILImage

from giskard_vision.core.dataloaders.base import AttributesIssueMeta, DataIteratorBase
from giskard_vision.core.dataloaders.base import DataIteratorBase
from giskard_vision.core.dataloaders.meta import MetaData, get_pil_image_depth
from giskard_vision.core.issues import AttributesIssueMeta
from giskard_vision.utils.errors import GiskardError, GiskardImportError


Expand Down
2 changes: 1 addition & 1 deletion giskard_vision/core/dataloaders/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from PIL.Image import Image as PILImage

from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import IssueGroup


class MetaData:
Expand Down
51 changes: 39 additions & 12 deletions giskard_vision/core/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

from giskard_vision.core.issues import IssueGroup
from giskard_vision.utils.errors import GiskardImportError


@dataclass(frozen=True)
class IssueGroup:
name: str
description: str
from .specs import DetectorSpecsBase


@dataclass
Expand Down Expand Up @@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
}


class DetectorVisionBase:
class DetectorVisionBase(DetectorSpecsBase):
"""
Abstract class for Vision Detectors
Expand All @@ -67,12 +64,6 @@ class DetectorVisionBase:
evaluation results for the scan.
"""

issue_group: IssueGroup
warning_messages: dict
issue_level_threshold: float = 0.2
deviation_threshold: float = 0.05
num_images: int = 0

def run(
self,
model: Any,
Expand Down Expand Up @@ -139,6 +130,42 @@ def get_issues(

return issues

def get_scan_result(
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = metric_value - metric_reference_value
if self.metric_type == "relative":
relative_delta /= metric_reference_value

issue_level = IssueLevel.MINOR
if self.metric_direction == "better_lower":
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
elif self.metric_direction == "better_higher":
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
issue_level = IssueLevel.MAJOR
elif relative_delta < -self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM

return ScanResult(
name=name,
metric_name=metric_name,
metric_value=metric_value,
metric_reference_value=metric_reference_value,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
issue_group=issue_group,
)

@abstractmethod
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
"""Returns a list of ScanResult
Expand Down
38 changes: 1 addition & 37 deletions giskard_vision/core/detectors/metadata_scan_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import numpy as np
import pandas as pd

from giskard_vision.core.dataloaders.base import PerformanceIssueMeta
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import PerformanceIssueMeta
from giskard_vision.core.tests.base import MetricBase
from giskard_vision.utils.errors import GiskardImportError

Expand Down Expand Up @@ -258,39 +258,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
pass

return pd.DataFrame(df)

def get_scan_result(
self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = metric_value - metric_reference_value
if self.metric_type == "relative":
relative_delta /= metric_reference_value

issue_level = IssueLevel.MINOR
if self.metric_direction == "better_lower":
if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
elif self.metric_direction == "better_higher":
if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
issue_level = IssueLevel.MAJOR
elif relative_delta < -self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM

return ScanResult(
name=name,
metric_name=metric_name,
metric_value=metric_value,
metric_reference_value=metric_reference_value,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
issue_group=issue_group,
)
9 changes: 0 additions & 9 deletions giskard_vision/core/detectors/metrics.py

This file was deleted.

75 changes: 30 additions & 45 deletions giskard_vision/core/detectors/perturbation.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
import os
from abc import abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Sequence

import cv2

from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
from giskard_vision.core.detectors.base import (
DetectorVisionBase,
IssueGroup,
ScanResult,
)
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import Robustness
from giskard_vision.core.tests.base import TestDiffBase
from giskard_vision.utils.errors import GiskardImportError

from .metrics import detector_metrics

Robustness = IssueGroup(
"Robustness",
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
)


class PerturbationBaseDetector(DetectorVisionBase):
Expand All @@ -40,15 +30,28 @@ class PerturbationBaseDetector(DetectorVisionBase):

issue_group = Robustness

def set_specs_from_model_type(self, model_type):
module = import_module(f"giskard_vision.{model_type}.detectors.specs")
DetectorSpecs = getattr(module, "DetectorSpecs")

if DetectorSpecs:
# Only set attributes that are not part of Python's special attributes (those starting with __)
for attr_name, attr_value in vars(DetectorSpecs).items():
if not attr_name.startswith("__") and hasattr(self, attr_name):
setattr(self, attr_name, attr_value)
else:
raise ValueError(f"No detector specifications found for model type: {model_type}")

@abstractmethod
def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...

def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
self.set_specs_from_model_type(model.model_type)
dataloaders = self.get_dataloaders(dataset)

results = []
for dl in dataloaders:
test_result = TestDiffBase(metric=detector_metrics[model.model_type], threshold=1).run(
test_result = TestDiffBase(metric=self.metric, threshold=1).run(
model=model,
dataloader=dl,
dataloader_ref=dataset,
Expand All @@ -63,40 +66,22 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:

if isinstance(dl, FilteredDataLoader):
filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png")
cv2.imwrite(
filename_example_dataloader_ref, cv2.resize(dataset[index_worst][0][0], (0, 0), fx=0.3, fy=0.3)
)
cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
filename_examples.append(filename_example_dataloader_ref)

filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png")
cv2.imwrite(filename_example_dataloader, cv2.resize(dl[index_worst][0][0], (0, 0), fx=0.3, fy=0.3))
cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
filename_examples.append(filename_example_dataloader)
results.append(self.get_scan_result(test_result, filename_examples, dl.name, len(dl)))
results.append(
self.get_scan_result(
test_result.metric_value_test,
test_result.metric_value_ref,
test_result.metric_name,
filename_examples,
dl.name,
len(dl),
issue_group=self.issue_group,
)
)

return results

def get_scan_result(self, test_result, filename_examples, name, size_data) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
raise GiskardImportError(["giskard"]) from e

relative_delta = (test_result.metric_value_test - test_result.metric_value_ref) / test_result.metric_value_ref

if relative_delta > self.issue_level_threshold + self.deviation_threshold:
issue_level = IssueLevel.MAJOR
elif relative_delta > self.issue_level_threshold:
issue_level = IssueLevel.MEDIUM
else:
issue_level = IssueLevel.MINOR

return ScanResult(
name=name,
metric_name=test_result.metric_name,
metric_value=test_result.metric_value_test,
metric_reference_value=test_result.metric_value_ref,
issue_level=issue_level,
slice_size=size_data,
filename_examples=filename_examples,
relative_delta=relative_delta,
)
13 changes: 13 additions & 0 deletions giskard_vision/core/detectors/specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from giskard_vision.core.issues import IssueGroup
from giskard_vision.image_classification.tests.performance import MetricBase


class DetectorSpecsBase:
issue_group: IssueGroup
warning_messages: dict
metric: MetricBase = None
metric_type: str = None
metric_direction: str = None
deviation_threshold: float = 0.10
issue_level_threshold: float = 0.05
num_images: int = 0
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@maybe_detector("blurring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationBlurringDetectorLandmark(PerturbationBaseDetector):
class TransformationBlurringDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance on blurred images
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader

from ...core.detectors.decorator import maybe_detector
from .perturbation import PerturbationBaseDetector, Robustness
from .perturbation import PerturbationBaseDetector


@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationColorDetectorLandmark(PerturbationBaseDetector):
class TransformationColorDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance depending on images in grayscale
"""

issue_group = Robustness

def get_dataloaders(self, dataset):
dl = ColoredDataLoader(dataset)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from .perturbation import PerturbationBaseDetector


@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationNoiseDetectorLandmark(PerturbationBaseDetector):
@maybe_detector("noise", tags=["vision", "robustness", "image_classification", "landmark", "object_detection", "noise"])
class TransformationNoiseDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance on noisy images
"""
Expand Down
25 changes: 25 additions & 0 deletions giskard_vision/core/issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class IssueGroup:
name: str
description: str


EthicalIssueMeta = IssueGroup(
"Ethical",
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
"Performance",
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)
AttributesIssueMeta = IssueGroup(
"Attributes",
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
)
Robustness = IssueGroup(
"Robustness",
description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.",
)
2 changes: 1 addition & 1 deletion giskard_vision/image_classification/dataloaders/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import numpy as np
from PIL.Image import Image as PILImage

from giskard_vision.core.dataloaders.base import EthicalIssueMeta, PerformanceIssueMeta
from giskard_vision.core.dataloaders.hf import HFDataLoader
from giskard_vision.core.dataloaders.meta import MetaData
from giskard_vision.core.dataloaders.tfds import DataLoaderTensorFlowDatasets
from giskard_vision.core.dataloaders.utils import flatten_dict
from giskard_vision.core.issues import EthicalIssueMeta, PerformanceIssueMeta
from giskard_vision.image_classification.types import Types


Expand Down
Loading
Loading