1
1
import os
2
2
from abc import abstractmethod
3
+ from importlib import import_module
3
4
from pathlib import Path
4
- from typing import Any , Sequence
5
+ from typing import Any , Sequence , Tuple
5
6
6
7
import cv2
7
8
8
9
from giskard_vision .core .dataloaders .wrappers import FilteredDataLoader
9
- from giskard_vision .core .detectors .base import (
10
- DetectorVisionBase ,
11
- IssueGroup ,
12
- ScanResult ,
13
- )
10
+ from giskard_vision .core .detectors .base import DetectorVisionBase , ScanResult
11
+ from giskard_vision .core .issues import Robustness
14
12
from giskard_vision .core .tests .base import TestDiffBase
15
- from giskard_vision .utils .errors import GiskardImportError
16
-
17
- from .metrics import detector_metrics
18
-
19
- Robustness = IssueGroup (
20
- "Robustness" ,
21
- description = "Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations." ,
22
- )
23
13
24
14
25
15
class PerturbationBaseDetector (DetectorVisionBase ):
@@ -40,6 +30,28 @@ class PerturbationBaseDetector(DetectorVisionBase):
40
30
41
31
issue_group = Robustness
42
32
33
+ def run (
34
+ self ,
35
+ model : Any ,
36
+ dataset : Any ,
37
+ features : Any | None = None ,
38
+ issue_levels : Tuple [Any ] = None ,
39
+ embed : bool = True ,
40
+ num_images : int = 0 ,
41
+ ) -> Sequence [Any ]:
42
+ module = import_module (f"giskard_vision.{ model .model_type } .detectors.specs" )
43
+ DetectorSpecs = getattr (module , "DetectorSpecs" )
44
+
45
+ if DetectorSpecs :
46
+ # Only set attributes that are not part of Python's special attributes (those starting with __)
47
+ for attr_name , attr_value in vars (DetectorSpecs ).items ():
48
+ if not attr_name .startswith ("__" ) and hasattr (self , attr_name ):
49
+ setattr (self , attr_name , attr_value )
50
+ else :
51
+ raise ValueError (f"No detector specifications found for model type: { model .model_type } " )
52
+
53
+ return super ().run (model , dataset , features , issue_levels , embed , num_images )
54
+
43
55
@abstractmethod
44
56
def get_dataloaders (self , dataset : Any ) -> Sequence [Any ]: ...
45
57
@@ -48,7 +60,7 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
48
60
49
61
results = []
50
62
for dl in dataloaders :
51
- test_result = TestDiffBase (metric = detector_metrics [ model . model_type ] , threshold = 1 ).run (
63
+ test_result = TestDiffBase (metric = self . metric , threshold = 1 ).run (
52
64
model = model ,
53
65
dataloader = dl ,
54
66
dataloader_ref = dataset ,
@@ -63,40 +75,22 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
63
75
64
76
if isinstance (dl , FilteredDataLoader ):
65
77
filename_example_dataloader_ref = str (Path () / "examples_images" / f"{ dataset .name } _{ index_worst } .png" )
66
- cv2 .imwrite (
67
- filename_example_dataloader_ref , cv2 .resize (dataset [index_worst ][0 ][0 ], (0 , 0 ), fx = 0.3 , fy = 0.3 )
68
- )
78
+ cv2 .imwrite (filename_example_dataloader_ref , dataset [index_worst ][0 ][0 ])
69
79
filename_examples .append (filename_example_dataloader_ref )
70
80
71
81
filename_example_dataloader = str (Path () / "examples_images" / f"{ dl .name } _{ index_worst } .png" )
72
- cv2 .imwrite (filename_example_dataloader , cv2 . resize ( dl [index_worst ][0 ][0 ], ( 0 , 0 ), fx = 0.3 , fy = 0.3 ) )
82
+ cv2 .imwrite (filename_example_dataloader , dl [index_worst ][0 ][0 ])
73
83
filename_examples .append (filename_example_dataloader )
74
- results .append (self .get_scan_result (test_result , filename_examples , dl .name , len (dl )))
84
+ results .append (
85
+ self .get_scan_result (
86
+ test_result .metric_value_test ,
87
+ test_result .metric_value_test ,
88
+ test_result .metric_name ,
89
+ filename_examples ,
90
+ dl .name ,
91
+ len (dl ),
92
+ issue_group = self .issue_group ,
93
+ )
94
+ )
75
95
76
96
return results
77
-
78
- def get_scan_result (self , test_result , filename_examples , name , size_data ) -> ScanResult :
79
- try :
80
- from giskard .scanner .issues import IssueLevel
81
- except (ImportError , ModuleNotFoundError ) as e :
82
- raise GiskardImportError (["giskard" ]) from e
83
-
84
- relative_delta = (test_result .metric_value_test - test_result .metric_value_ref ) / test_result .metric_value_ref
85
-
86
- if relative_delta > self .issue_level_threshold + self .deviation_threshold :
87
- issue_level = IssueLevel .MAJOR
88
- elif relative_delta > self .issue_level_threshold :
89
- issue_level = IssueLevel .MEDIUM
90
- else :
91
- issue_level = IssueLevel .MINOR
92
-
93
- return ScanResult (
94
- name = name ,
95
- metric_name = test_result .metric_name ,
96
- metric_value = test_result .metric_value_test ,
97
- metric_reference_value = test_result .metric_value_ref ,
98
- issue_level = issue_level ,
99
- slice_size = size_data ,
100
- filename_examples = filename_examples ,
101
- relative_delta = relative_delta ,
102
- )
0 commit comments