From c359c9c1586440f60ec1c0d902f6eaf5758078aa Mon Sep 17 00:00:00 2001 From: Rabah Khalek Date: Tue, 13 Aug 2024 13:04:40 +0200 Subject: [PATCH] refactored spec setting --- giskard_vision/core/detectors/perturbation.py | 19 +++++-------------- .../detectors/test_detectors.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/giskard_vision/core/detectors/perturbation.py b/giskard_vision/core/detectors/perturbation.py index e6a33811..cb5c0c26 100644 --- a/giskard_vision/core/detectors/perturbation.py +++ b/giskard_vision/core/detectors/perturbation.py @@ -2,7 +2,7 @@ from abc import abstractmethod from importlib import import_module from pathlib import Path -from typing import Any, Sequence, Tuple +from typing import Any, Sequence import cv2 @@ -30,16 +30,8 @@ class PerturbationBaseDetector(DetectorVisionBase): issue_group = Robustness - def run( - self, - model: Any, - dataset: Any, - features: Any | None = None, - issue_levels: Tuple[Any] = None, - embed: bool = True, - num_images: int = 0, - ) -> Sequence[Any]: - module = import_module(f"giskard_vision.{model.model_type}.detectors.specs") + def set_specs_from_model_type(self, model_type): + module = import_module(f"giskard_vision.{model_type}.detectors.specs") DetectorSpecs = getattr(module, "DetectorSpecs") if DetectorSpecs: @@ -48,14 +40,13 @@ def run( if not attr_name.startswith("__") and hasattr(self, attr_name): setattr(self, attr_name, attr_value) else: - raise ValueError(f"No detector specifications found for model type: {model.model_type}") - - return super().run(model, dataset, features, issue_levels, embed, num_images) + raise ValueError(f"No detector specifications found for model type: {model_type}") @abstractmethod def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ... def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]: + self.set_specs_from_model_type(model.model_type) dataloaders = self.get_dataloaders(dataset) results = [] diff --git a/tests/landmark_detection/detectors/test_detectors.py b/tests/landmark_detection/detectors/test_detectors.py index f812e740..7170b7a0 100644 --- a/tests/landmark_detection/detectors/test_detectors.py +++ b/tests/landmark_detection/detectors/test_detectors.py @@ -1,11 +1,18 @@ from giskard.scanner.issues import Issue, IssueLevel from pytest import mark +from giskard_vision.core.detectors.transformation_blurring_detector import ( + TransformationBlurringDetectorLandmark, +) +from giskard_vision.core.detectors.transformation_color_detector import ( + TransformationColorDetectorLandmark, +) +from giskard_vision.core.detectors.transformation_noise_detector import ( + TransformationNoiseDetectorLandmark, +) from giskard_vision.landmark_detection.detectors import ( CroppingDetectorLandmark, MetaDataScanDetectorLandmark, - TransformationBlurringDetectorLandmark, - TransformationColorDetectorLandmark, TransformationResizeDetectorLandmark, ) from giskard_vision.landmark_detection.detectors.base import ScanResult @@ -17,6 +24,7 @@ CroppingDetectorLandmark, TransformationBlurringDetectorLandmark, TransformationColorDetectorLandmark, + TransformationNoiseDetectorLandmark, TransformationResizeDetectorLandmark, ], )