2424class Roc (object ):
2525
2626 """ Roc Class
27-
28- The Roc class is based on Tor Wager's Matlab roc_plot.m function and allows a user to easily run different types of
29- receiver operator characteristic curves. For example, one might be interested in single interval or forced choice.
27+
28+ The Roc class is based on Tor Wager's Matlab roc_plot.m function and
29+ allows a user to easily run different types of receiver operator
30+ characteristic curves. For example, one might be interested in single
31+ interval or forced choice.
3032
3133 Args:
3234 input_values: nibabel data instance
3335 binary_outcome: vector of training labels
34- threshold_type: ['optimal_overall', 'optimal_balanced','minimum_sdt_bias']
35- **kwargs: Additional keyword arguments to pass to the prediction algorithm
36+ threshold_type: ['optimal_overall', 'optimal_balanced',
37+ 'minimum_sdt_bias']
38+ **kwargs: Additional keyword arguments to pass to the prediction
39+ algorithm
3640
3741 """
3842
39- def __init__ (self , input_values = None , binary_outcome = None ,
43+ def __init__ (self , input_values = None , binary_outcome = None ,
4044 threshold_type = 'optimal_overall' , forced_choice = None , ** kwargs ):
4145 if len (input_values ) != len (binary_outcome ):
42- raise ValueError ("Data Problem: input_value and binary_outcome are different lengths." )
46+ raise ValueError ("Data Problem: input_value and binary_outcome"
47+ "are different lengths." )
4348
4449 if not any (binary_outcome ):
4550 raise ValueError ("Data Problem: binary_outcome may not be boolean" )
4651
4752 thr_type = ['optimal_overall' , 'optimal_balanced' ,'minimum_sdt_bias' ]
4853 if threshold_type not in thr_type :
49- raise ValueError ("threshold_type must be ['optimal_overall', 'optimal_balanced','minimum_sdt_bias']" )
54+ raise ValueError ("threshold_type must be ['optimal_overall', "
55+ "'optimal_balanced','minimum_sdt_bias']" )
5056
5157 self .input_values = deepcopy (input_values )
5258 self .binary_outcome = deepcopy (binary_outcome )
@@ -55,25 +61,32 @@ def __init__(self, input_values=None, binary_outcome=None,
5561
5662 if isinstance (self .binary_outcome ,pd .DataFrame ):
5763 self .binary_outcome = np .array (self .binary_outcome ).flatten ()
58- else :
64+ else :
5965 self .binary_outcome = deepcopy (binary_outcome )
6066
61- def calculate (self , input_values = None , binary_outcome = None , criterion_values = None ,
62- threshold_type = 'optimal_overall' , forced_choice = None , balanced_acc = False ):
63-
64- """ Calculate Receiver Operating Characteristic plot (ROC) for single-interval
65- classification.
67+ def calculate (self , input_values = None , binary_outcome = None ,
68+ criterion_values = None , threshold_type = 'optimal_overall' ,
69+ forced_choice = None , balanced_acc = False ):
70+
71+ """ Calculate Receiver Operating Characteristic plot (ROC) for
72+ single-interval classification.
6673
6774 Args:
6875 input_values: nibabel data instance
6976 binary_outcome: vector of training labels
70- criterion_values: (optional) criterion values for calculating fpr & tpr
71- threshold_type: ['optimal_overall', 'optimal_balanced','minimum_sdt_bias']
72- forced_choice: index indicating position for each unique subject (default=None)
73- balanced_acc: balanced accuracy for single-interval classification (bool)
74- **kwargs: Additional keyword arguments to pass to the prediction algorithm
77+ criterion_values: (optional) criterion values for calculating fpr
78+ & tpr
79+ threshold_type: ['optimal_overall', 'optimal_balanced',
80+ 'minimum_sdt_bias']
81+ forced_choice: index indicating position for each unique subject
82+ (default=None)
83+ balanced_acc: balanced accuracy for single-interval classification
84+ (bool)
85+ **kwargs: Additional keyword arguments to pass to the prediction
86+ algorithm
7587
7688 """
89+
7790 if input_values is not None :
7891 self .input_values = deepcopy (input_values )
7992
@@ -84,14 +97,17 @@ def calculate(self, input_values=None, binary_outcome=None, criterion_values=Non
8497 if criterion_values is not None :
8598 self .criterion_values = deepcopy (criterion_values )
8699 else :
87- self .criterion_values = np .linspace (min (self .input_values ), max (self .input_values ), num = 50 * len (self .binary_outcome ))
100+ self .criterion_values = np .linspace (min (self .input_values ),
101+ max (self .input_values ),
102+ num = 50 * len (self .binary_outcome ))
88103
89104 if forced_choice is not None :
90105 self .forced_choice = deepcopy (forced_choice )
91106
92107 if self .forced_choice is not None :
93108 sub_idx = np .unique (self .forced_choice )
94- assert len (sub_idx ) == len (self .binary_outcome )/ 2 , "Make sure that subject ids are correct for 'forced_choice'."
109+ assert len (sub_idx ) == len (self .binary_outcome )/ 2 , ("Make sure "
110+ "that subject ids are correct for 'forced_choice'." )
95111 assert len (set (sub_idx ).union (set (np .array (self .forced_choice )[self .binary_outcome ]))) == len (sub_idx ), "Issue with forced_choice subject labels."
96112 assert len (set (sub_idx ).union (set (np .array (self .forced_choice )[~ self .binary_outcome ]))) == len (sub_idx ), "Issue with forced_choice subject labels."
97113 for sub in sub_idx :
@@ -109,17 +125,7 @@ def calculate(self, input_values=None, binary_outcome=None, criterion_values=Non
109125 self .fpr [i ] = np .sum (wh [~ self .binary_outcome ])/ np .sum (~ self .binary_outcome )
110126 self .n_true = np .sum (self .binary_outcome )
111127 self .n_false = np .sum (~ self .binary_outcome )
112-
113- # Calculate Area Under the Curve
114-
115- # fix for AUC = 1 if no overlap - code not working (tpr_unique and fpr_unique can be different lengths)
116- # fpr_unique = np.unique(self.fpr)
117- # tpr_unique = np.unique(self.tpr)
118- # if any((fpr_unique == 0) & (tpr_unique == 1)):
119- # self.auc = 1 # Fix for AUC = 1 if no overlap;
120- # else:
121- # self.auc = auc(self.fpr, self.tpr) # Use sklearn auc otherwise
122- self .auc = auc (self .fpr , self .tpr ) # Use sklearn auc
128+ self .auc = auc (self .fpr , self .tpr )
123129
124130 # Get criterion threshold
125131 if self .forced_choice is None :
@@ -169,16 +175,17 @@ def plot(self, plot_method = 'gaussian'):
169175 """ Create ROC Plot
170176
171177 Create a specific kind of ROC curve plot, based on input values
172- along a continuous distribution and a binary outcome variable (logical).
178+ along a continuous distribution and a binary outcome variable (logical)
173179
174180 Args:
175181 plot_method: type of plot ['gaussian','observed']
176182 binary_outcome: vector of training labels
177- **kwargs: Additional keyword arguments to pass to the prediction algorithm
183+ **kwargs: Additional keyword arguments to pass to the prediction
184+ algorithm
178185
179186 Returns:
180187 fig
181-
188+
182189 """
183190
184191 self .calculate () # Calculate ROC parameters
@@ -188,7 +195,7 @@ def plot(self, plot_method = 'gaussian'):
188195 sub_idx = np .unique (self .forced_choice )
189196 diff_scores = []
190197 for sub in sub_idx :
191- diff_scores .append (self .input_values [(self .forced_choice == sub ) & (self .binary_outcome == True )][0 ] - self .input_values [(self .forced_choice == sub ) & (self .binary_outcome == False )][0 ])
198+ diff_scores .append (self .input_values [(self .forced_choice == sub ) & (self .binary_outcome == True )][0 ] - self .input_values [(self .forced_choice == sub ) & (self .binary_outcome == False )][0 ])
192199 diff_scores = np .array (diff_scores )
193200 mn_diff = np .mean (diff_scores )
194201 d = mn_diff / np .std (diff_scores )
@@ -221,9 +228,7 @@ def plot(self, plot_method = 'gaussian'):
221228 return fig
222229
223230 def summary (self ):
224- """ Display a formatted summary of ROC analysis.
225-
226- """
231+ """ Display a formatted summary of ROC analysis. """
227232
228233 print ("------------------------" )
229234 print (".:ROC Analysis Summary:." )
@@ -236,6 +241,3 @@ def summary(self):
236241 print ("{:20s}" .format ("AUC:" ) + "{:.2f}" .format (self .auc ))
237242 print ("{:20s}" .format ("PPV:" ) + "{:.2f}" .format (self .ppv ))
238243 print ("------------------------" )
239-
240-
241-
0 commit comments