Skip to content

Commit

Permalink
refactored spec setting
Browse files Browse the repository at this point in the history
  • Loading branch information
rabah-khalek committed Aug 13, 2024
1 parent fe26272 commit c359c9c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
19 changes: 5 additions & 14 deletions giskard_vision/core/detectors/perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Sequence, Tuple
from typing import Any, Sequence

import cv2

Expand Down Expand Up @@ -30,16 +30,8 @@ class PerturbationBaseDetector(DetectorVisionBase):

issue_group = Robustness

def run(
self,
model: Any,
dataset: Any,
features: Any | None = None,
issue_levels: Tuple[Any] = None,
embed: bool = True,
num_images: int = 0,
) -> Sequence[Any]:
module = import_module(f"giskard_vision.{model.model_type}.detectors.specs")
def set_specs_from_model_type(self, model_type):
module = import_module(f"giskard_vision.{model_type}.detectors.specs")
DetectorSpecs = getattr(module, "DetectorSpecs")

if DetectorSpecs:
Expand All @@ -48,14 +40,13 @@ def run(
if not attr_name.startswith("__") and hasattr(self, attr_name):
setattr(self, attr_name, attr_value)
else:
raise ValueError(f"No detector specifications found for model type: {model.model_type}")

return super().run(model, dataset, features, issue_levels, embed, num_images)
raise ValueError(f"No detector specifications found for model type: {model_type}")

@abstractmethod
def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...

def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
self.set_specs_from_model_type(model.model_type)
dataloaders = self.get_dataloaders(dataset)

results = []
Expand Down
12 changes: 10 additions & 2 deletions tests/landmark_detection/detectors/test_detectors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from giskard.scanner.issues import Issue, IssueLevel
from pytest import mark

from giskard_vision.core.detectors.transformation_blurring_detector import (
TransformationBlurringDetectorLandmark,
)
from giskard_vision.core.detectors.transformation_color_detector import (
TransformationColorDetectorLandmark,
)
from giskard_vision.core.detectors.transformation_noise_detector import (
TransformationNoiseDetectorLandmark,
)
from giskard_vision.landmark_detection.detectors import (
CroppingDetectorLandmark,
MetaDataScanDetectorLandmark,
TransformationBlurringDetectorLandmark,
TransformationColorDetectorLandmark,
TransformationResizeDetectorLandmark,
)
from giskard_vision.landmark_detection.detectors.base import ScanResult
Expand All @@ -17,6 +24,7 @@
CroppingDetectorLandmark,
TransformationBlurringDetectorLandmark,
TransformationColorDetectorLandmark,
TransformationNoiseDetectorLandmark,
TransformationResizeDetectorLandmark,
],
)
Expand Down

0 comments on commit c359c9c

Please sign in to comment.