1
1
import os
2
2
from abc import abstractmethod
3
+ from importlib import import_module
3
4
from pathlib import Path
4
5
from typing import Any , Sequence
5
6
6
7
import cv2
7
8
8
9
from giskard_vision .core .dataloaders .wrappers import FilteredDataLoader
9
- from giskard_vision .core .detectors .base import (
10
- DetectorVisionBase ,
11
- IssueGroup ,
12
- ScanResult ,
13
- )
10
+ from giskard_vision .core .detectors .base import DetectorVisionBase , ScanResult
11
+ from giskard_vision .core .issues import Robustness
14
12
from giskard_vision .core .tests .base import TestDiffBase
15
- from giskard_vision .utils .errors import GiskardImportError
16
-
17
- from .metrics import detector_metrics
18
-
19
- Robustness = IssueGroup (
20
- "Robustness" ,
21
- description = "Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations." ,
22
- )
23
13
24
14
25
15
class PerturbationBaseDetector (DetectorVisionBase ):
@@ -40,15 +30,28 @@ class PerturbationBaseDetector(DetectorVisionBase):
40
30
41
31
issue_group = Robustness
42
32
33
+ def set_specs_from_model_type (self , model_type ):
34
+ module = import_module (f"giskard_vision.{ model_type } .detectors.specs" )
35
+ DetectorSpecs = getattr (module , "DetectorSpecs" )
36
+
37
+ if DetectorSpecs :
38
+ # Only set attributes that are not part of Python's special attributes (those starting with __)
39
+ for attr_name , attr_value in vars (DetectorSpecs ).items ():
40
+ if not attr_name .startswith ("__" ) and hasattr (self , attr_name ):
41
+ setattr (self , attr_name , attr_value )
42
+ else :
43
+ raise ValueError (f"No detector specifications found for model type: { model_type } " )
44
+
43
45
@abstractmethod
44
46
def get_dataloaders (self , dataset : Any ) -> Sequence [Any ]: ...
45
47
46
48
def get_results (self , model : Any , dataset : Any ) -> Sequence [ScanResult ]:
49
+ self .set_specs_from_model_type (model .model_type )
47
50
dataloaders = self .get_dataloaders (dataset )
48
51
49
52
results = []
50
53
for dl in dataloaders :
51
- test_result = TestDiffBase (metric = detector_metrics [ model . model_type ] , threshold = 1 ).run (
54
+ test_result = TestDiffBase (metric = self . metric , threshold = 1 ).run (
52
55
model = model ,
53
56
dataloader = dl ,
54
57
dataloader_ref = dataset ,
@@ -63,40 +66,22 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
63
66
64
67
if isinstance (dl , FilteredDataLoader ):
65
68
filename_example_dataloader_ref = str (Path () / "examples_images" / f"{ dataset .name } _{ index_worst } .png" )
66
- cv2 .imwrite (
67
- filename_example_dataloader_ref , cv2 .resize (dataset [index_worst ][0 ][0 ], (0 , 0 ), fx = 0.3 , fy = 0.3 )
68
- )
69
+ cv2 .imwrite (filename_example_dataloader_ref , dataset [index_worst ][0 ][0 ])
69
70
filename_examples .append (filename_example_dataloader_ref )
70
71
71
72
filename_example_dataloader = str (Path () / "examples_images" / f"{ dl .name } _{ index_worst } .png" )
72
- cv2 .imwrite (filename_example_dataloader , cv2 . resize ( dl [index_worst ][0 ][0 ], ( 0 , 0 ), fx = 0.3 , fy = 0.3 ) )
73
+ cv2 .imwrite (filename_example_dataloader , dl [index_worst ][0 ][0 ])
73
74
filename_examples .append (filename_example_dataloader )
74
- results .append (self .get_scan_result (test_result , filename_examples , dl .name , len (dl )))
75
+ results .append (
76
+ self .get_scan_result (
77
+ test_result .metric_value_test ,
78
+ test_result .metric_value_ref ,
79
+ test_result .metric_name ,
80
+ filename_examples ,
81
+ dl .name ,
82
+ len (dl ),
83
+ issue_group = self .issue_group ,
84
+ )
85
+ )
75
86
76
87
return results
77
-
78
- def get_scan_result (self , test_result , filename_examples , name , size_data ) -> ScanResult :
79
- try :
80
- from giskard .scanner .issues import IssueLevel
81
- except (ImportError , ModuleNotFoundError ) as e :
82
- raise GiskardImportError (["giskard" ]) from e
83
-
84
- relative_delta = (test_result .metric_value_test - test_result .metric_value_ref ) / test_result .metric_value_ref
85
-
86
- if relative_delta > self .issue_level_threshold + self .deviation_threshold :
87
- issue_level = IssueLevel .MAJOR
88
- elif relative_delta > self .issue_level_threshold :
89
- issue_level = IssueLevel .MEDIUM
90
- else :
91
- issue_level = IssueLevel .MINOR
92
-
93
- return ScanResult (
94
- name = name ,
95
- metric_name = test_result .metric_name ,
96
- metric_value = test_result .metric_value_test ,
97
- metric_reference_value = test_result .metric_value_ref ,
98
- issue_level = issue_level ,
99
- slice_size = size_data ,
100
- filename_examples = filename_examples ,
101
- relative_delta = relative_delta ,
102
- )
0 commit comments