|
| 1 | +import os |
| 2 | +from abc import abstractmethod |
| 3 | +from importlib import import_module |
| 4 | +from pathlib import Path |
| 5 | +from typing import Any, Sequence |
| 6 | + |
| 7 | +import cv2 |
| 8 | + |
| 9 | +from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader |
| 10 | +from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult |
| 11 | +from giskard_vision.core.issues import Robustness |
| 12 | +from giskard_vision.core.tests.base import TestDiffBase |
| 13 | + |
| 14 | + |
| 15 | +class PerturbationBaseDetector(DetectorVisionBase): |
| 16 | + """ |
| 17 | + Abstract class for Landmark Detection Detectors |
| 18 | +
|
| 19 | + Methods: |
| 20 | + get_dataloaders(dataset: Any) -> Sequence[Any]: |
| 21 | + Abstract method that returns a list of dataloaders corresponding to |
| 22 | + slices or transformations |
| 23 | +
|
| 24 | + get_results(model: Any, dataset: Any) -> Sequence[ScanResult]: |
| 25 | + Returns a list of ScanResult containing the evaluation results |
| 26 | +
|
| 27 | + get_scan_result(self, test_result) -> ScanResult: |
| 28 | + Convert TestResult to ScanResult |
| 29 | + """ |
| 30 | + |
| 31 | + issue_group = Robustness |
| 32 | + |
| 33 | + def set_specs_from_model_type(self, model_type): |
| 34 | + module = import_module(f"giskard_vision.{model_type}.detectors.specs") |
| 35 | + DetectorSpecs = getattr(module, "DetectorSpecs") |
| 36 | + |
| 37 | + if DetectorSpecs: |
| 38 | + # Only set attributes that are not part of Python's special attributes (those starting with __) |
| 39 | + for attr_name, attr_value in vars(DetectorSpecs).items(): |
| 40 | + if not attr_name.startswith("__") and hasattr(self, attr_name): |
| 41 | + setattr(self, attr_name, attr_value) |
| 42 | + else: |
| 43 | + raise ValueError(f"No detector specifications found for model type: {model_type}") |
| 44 | + |
| 45 | + @abstractmethod |
| 46 | + def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ... |
| 47 | + |
| 48 | + def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]: |
| 49 | + self.set_specs_from_model_type(model.model_type) |
| 50 | + dataloaders = self.get_dataloaders(dataset) |
| 51 | + |
| 52 | + results = [] |
| 53 | + for dl in dataloaders: |
| 54 | + test_result = TestDiffBase(metric=self.metric, threshold=1).run( |
| 55 | + model=model, |
| 56 | + dataloader=dl, |
| 57 | + dataloader_ref=dataset, |
| 58 | + ) |
| 59 | + |
| 60 | + # Save example images from dataloader and dataset |
| 61 | + current_path = str(Path()) |
| 62 | + os.makedirs(f"{current_path}/examples_images", exist_ok=True) |
| 63 | + filename_examples = [] |
| 64 | + |
| 65 | + index_worst = 0 if test_result.indexes_examples is None else test_result.indexes_examples[0] |
| 66 | + |
| 67 | + if isinstance(dl, FilteredDataLoader): |
| 68 | + filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png") |
| 69 | + cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0]) |
| 70 | + filename_examples.append(filename_example_dataloader_ref) |
| 71 | + |
| 72 | + filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png") |
| 73 | + cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0]) |
| 74 | + filename_examples.append(filename_example_dataloader) |
| 75 | + results.append( |
| 76 | + self.get_scan_result( |
| 77 | + test_result.metric_value_test, |
| 78 | + test_result.metric_value_ref, |
| 79 | + test_result.metric_name, |
| 80 | + filename_examples, |
| 81 | + dl.name, |
| 82 | + len(dl), |
| 83 | + issue_group=self.issue_group, |
| 84 | + ) |
| 85 | + ) |
| 86 | + |
| 87 | + return results |
0 commit comments