Skip to content

Commit 774b1c5

Browse files
committed
added new Prediction class
1 parent ef813d4 commit 774b1c5

File tree

11 files changed

+5447
-693
lines changed

11 files changed

+5447
-693
lines changed

.gitignore

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*.pyc
44
doc/_build/
55
*.log
6+
nltools.egg-info
67

78
# Logs and databases #
89
######################
@@ -12,7 +13,8 @@ doc/_build/
1213

1314
# iPython Notebook Caches #
1415
###########################
15-
scripts/ilearn_cache/
16+
scripts/nilearn_cache/
17+
.ipynb_checkpoints
1618

1719
# OS generated files #
1820
######################
@@ -21,5 +23,5 @@ scripts/ilearn_cache/
2123
._*
2224
.Spotlight-V100
2325
.Trashes
24-
ehthumbs.db
26+
thumbs.db
2527
Thumbs.db

__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__all__ = ["nltools"]

build/lib/nltools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__all__ = ['analysis']

build/lib/nltools/analysis.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
'''
2+
NeuroLearn Analysis Tools
3+
=========================
4+
These tools provide the ability to quickly run
5+
machine-learning analyses on imaging data
6+
Authors: Luke Chang
7+
License: MIT
8+
'''
9+
10+
# ToDo
11+
# 1) add roc functionality for classification
12+
# 2) add thresholding functionality
13+
# 3) add bootstrapping functionality
14+
# 4) add tests
15+
16+
import os
17+
import nibabel as nib
18+
import sklearn
19+
from nilearn.input_data import NiftiMasker
20+
import pandas as pd
21+
import numpy as np
22+
from nilearn.plotting import *
23+
import seaborn as sns
24+
25+
# Paths
26+
# Directory of bundled resource files, resolved as "<this module's directory>/../resources".
# The Predict class below loads the MNI152 brain mask and anatomical template from it.
resource_dir = os.path.join(os.path.dirname(__file__),os.path.pardir,'resources')
27+
28+
29+
class Predict(object):
    """Fit a scikit-learn style estimator to masked imaging data.

    The class masks a nibabel image into a (samples x voxels) matrix, fits an
    estimator on all samples to obtain a weight map, optionally cross-validates
    the estimator, and writes the weight map, a stats table and plots to
    ``output_dir``.

    Attributes set by this class:
        data: masked data matrix produced by ``NiftiMasker.fit_transform``.
        Y: training labels as a numpy array.
        nifti_masker: the fitted masker, reused to map weights back to an image.
        predicter: the estimator instance created by ``set_algorithm``.
        prediction_type: 'classification', 'prediction', or None for a custom
            estimator whose type cannot be inferred from a name.
        cv: the cross-validation iterable created by ``set_cv`` (or None).
        yfit / xval_dist_from_hyperplane: cross-validated outputs (or None).
        stats_output: pandas DataFrame built by the last ``predict`` call.
    """

    # Algorithm names accepted by set_algorithm, grouped by analysis type.
    _CLASSIFIER_NAMES = ('svm', 'logistic', 'ridgeClassifier', 'randomforestClassifier')
    _REGRESSOR_NAMES = ('svr', 'linear', 'lasso', 'ridge', 'randomforest')

    # `basestring` only exists on Python 2; fall back to `str` elsewhere so the
    # string checks below work on either interpreter.
    try:
        _STRING_TYPES = basestring  # noqa: F821
    except NameError:
        _STRING_TYPES = str

    def __init__(self, data, Y, subject_id=None, algorithm=None, cv=None, mask=None,
                 output_dir='.', cv_kwargs=None, **kwargs):
        """ Initialize Predict.
        Args:
            data: nibabel data instance
            Y: vector of training labels
            subject_id: vector of labels corresponding to each subject
            algorithm: Algorithm to use for prediction. Must be one of 'svm', 'svr',
                'linear', 'logistic', 'lasso', 'ridge', 'ridgeClassifier','randomforest',
                or 'randomforestClassifier' (or an uninitialized estimator class).
                May be None and supplied later via set_algorithm() or predict().
            cv: Type of cross_validation to use. Either a string or an (uninitialized)
                scikit-learn cv object. If string, must be one of 'kfold' or 'loso'.
            mask: binary nibabel mask
            output_dir: Directory to use for writing all outputs
            cv_kwargs: Optional dict of keyword arguments for the cv object. When
                omitted, **kwargs is forwarded to the cv object as before.
            **kwargs: Additional keyword arguments to pass to the prediction algorithm
        Raises:
            ValueError: if mask or data is not a nibabel Nifti1Image, or if the
                number of masked samples does not match len(Y).
        """
        self.output_dir = output_dir

        if mask is None:
            mask = nib.load(os.path.join(resource_dir, 'MNI152_T1_2mm_brain_mask_dil.nii.gz'))
        elif not isinstance(mask, nib.nifti1.Nifti1Image):
            raise ValueError("mask is not a nibabel instance")
        self.mask = mask

        if not isinstance(data, nib.nifti1.Nifti1Image):
            raise ValueError("data is not a nibabel instance")

        # Use self.mask (not the raw argument) so the default mask is applied,
        # and keep the masker: it is needed to turn weights back into an image.
        self.nifti_masker = NiftiMasker(mask_img=self.mask)
        self.data = self.nifti_masker.fit_transform(data)

        # Could check if running classification or prediction for Y
        # Stored as an array because cross-validation indexes Y with index arrays.
        self.Y = np.asarray(Y)
        if self.data.shape[0] != len(self.Y):
            raise ValueError("Y does not match the correct size of data")

        # Always define these so later methods never hit a missing attribute.
        self.subject_id = subject_id
        self.yfit = None
        self.xval_dist_from_hyperplane = None
        self.stats_output = None

        self.algorithm = None
        self.prediction_type = None
        self.predicter = None
        if algorithm is not None:
            self.set_algorithm(algorithm, **kwargs)

        self.cv_type = None
        self.cv = None
        if cv is not None:
            # NOTE(review): estimator and cv constructors rarely accept the same
            # keywords; pass cv_kwargs to keep them apart.
            self.set_cv(cv, **(kwargs if cv_kwargs is None else cv_kwargs))

    def predict(self, algorithm=None, save_images=True, save_output=True,
                save_plot=True, **kwargs):
        """ Run prediction
        Args:
            algorithm: Algorithm to use for prediction. Must be one of 'svm', 'svr',
                'linear', 'logistic', 'lasso', 'ridge', 'ridgeClassifier','randomforest',
                or 'randomforestClassifier'
            save_images: Boolean indicating whether or not to save images to file.
            save_output: Boolean indicating whether or not to save prediction output to file.
            save_plot: Boolean indicating whether or not to create plots.
            **kwargs: Additional keyword arguments to pass to the prediction algorithm
        Raises:
            ValueError: if no algorithm has been set.
        """
        import copy

        if algorithm is not None:
            self.set_algorithm(algorithm, **kwargs)

        predicter = getattr(self, 'predicter', None)
        if predicter is None:
            raise ValueError("No prediction algorithm has been set. Pass `algorithm` "
                             "or call set_algorithm() first.")

        Y = np.asarray(self.Y)
        prediction_type = getattr(self, 'prediction_type', None)
        cv = getattr(self, 'cv', None)

        # Cross-validation refits on every fold; use a separate copy so the
        # overall fit kept in self.predicter is not overwritten.
        predicter_cv = copy.deepcopy(predicter) if cv is not None else None

        # Overall Fit for weight map
        predicter.fit(self.data, Y)

        if save_images:
            self._save_image(predicter)

        self.yfit = None
        self.xval_dist_from_hyperplane = None
        if cv is not None:
            is_svm = self.algorithm == 'svm'
            # Class labels keep Y's dtype; continuous predictions are stored as floats.
            yfit = np.zeros(len(Y), dtype=Y.dtype if prediction_type == 'classification' else float)
            dist = np.zeros(len(Y)) if is_svm else None
            for train, test in cv:
                predicter_cv.fit(self.data[train], Y[train])
                yfit[test] = predicter_cv.predict(self.data[test])
                if is_svm:
                    # assumes a binary problem (one distance per sample) -- TODO confirm
                    dist[test] = np.ravel(predicter_cv.decision_function(self.data[test]))
            self.yfit = yfit
            self.xval_dist_from_hyperplane = dist

        self.stats_output = None
        if save_output or save_plot:
            subject_id = getattr(self, 'subject_id', None)
            columns = {
                # Fall back to the row number when no subject ids were given.
                'SubID': np.arange(len(Y)) if subject_id is None else subject_id,
                'Y': Y,
            }
            # Cross-validated columns only exist when cross-validation ran.
            if self.yfit is not None:
                columns['yfit'] = self.yfit
            if self.xval_dist_from_hyperplane is not None:
                columns['xval_dist_from_hyperplane'] = self.xval_dist_from_hyperplane
            self.stats_output = pd.DataFrame(columns)

        if save_output:
            self._save_stats_output(self.stats_output)

        if self.yfit is not None:
            if prediction_type == 'classification':
                # Despite the attribute name, this is the proportion of correct predictions.
                self.mcr = np.mean(self.yfit == Y)
                print('overall CV accuracy: %.2f' % self.mcr)
            elif prediction_type == 'prediction':
                self.rmse = np.sqrt(np.mean((self.yfit - Y) ** 2))
                self.r = np.corrcoef(Y, self.yfit)[0, 1]
                print('overall Root Mean Squared Error: %.2f' % self.rmse)
                print('overall Correlation: %.2f' % self.r)

        if save_plot:
            self._save_plot(predicter)

    def set_algorithm(self, algorithm, **kwargs):
        """ Set the algorithm to use in subsequent prediction analyses.
        Args:
            algorithm: The prediction algorithm to use. Either a string or an (uninitialized)
                scikit-learn prediction object. If string, must be one of 'svm', 'svr',
                'linear', 'logistic', 'lasso', 'ridge', 'ridgeClassifier','randomforest',
                or 'randomforestClassifier'
            kwargs: Additional keyword arguments to pass onto the scikit-learn prediction
                object.
        Raises:
            ValueError: if algorithm is a string that is not a known name.
        """
        if isinstance(algorithm, self._STRING_TYPES):
            if algorithm in self._CLASSIFIER_NAMES:
                prediction_type = 'classification'
            elif algorithm in self._REGRESSOR_NAMES:
                prediction_type = 'prediction'
            else:
                raise ValueError("Invalid prediction algorithm name. Valid options are " +
                                 "'svm', 'svr', 'linear', 'logistic', 'lasso', 'ridge', " +
                                 "'ridgeClassifier', 'randomforest', or 'randomforestClassifier'.")

            # Import the submodules explicitly: a bare `import sklearn` does not
            # guarantee that sklearn.svm etc. are loaded.
            from sklearn import ensemble, linear_model, svm
            estimators = {
                'svm': svm.SVC,
                'logistic': linear_model.LogisticRegression,
                'ridgeClassifier': linear_model.RidgeClassifier,
                'randomforestClassifier': ensemble.RandomForestClassifier,
                'svr': svm.SVR,
                'linear': linear_model.LinearRegression,
                'lasso': linear_model.Lasso,
                'ridge': linear_model.Ridge,
                'randomforest': ensemble.RandomForestRegressor,
            }
            estimator_class = estimators[algorithm]
        else:
            # A custom estimator class: its type cannot be inferred from a name.
            prediction_type = None
            estimator_class = algorithm

        self.algorithm = algorithm
        self.prediction_type = prediction_type
        self.predicter = estimator_class(**kwargs)

    def set_cv(self, cv, **kwargs):
        """ Set the CV algorithm to use in subsequent prediction analyses.
        Args:
            cv: Type of cross_validation to use. Either a string or an (uninitialized)
                scikit-learn cv object. If string, must be one of 'kfold' or 'loso'.
            **kwargs: Additional keyword arguments to pass onto the scikit-learn cv object.
                For 'kfold', `n_folds` (or its alias `n_fold`) is required and `y`
                defaults to self.Y. For 'loso', `labels` defaults to self.subject_id.
        Raises:
            ValueError: for an unknown cv name, a missing fold count, or 'loso'
                without any subject labels.
        """
        cv_class = cv

        if isinstance(cv, self._STRING_TYPES):
            if cv not in ('kfold', 'loso'):
                raise ValueError("Invalid cv name. Valid options are 'kfold' or 'loso'.")

            if cv == 'kfold':
                # The error message advertises `n_fold`; the constructor wants `n_folds`.
                if 'n_fold' in kwargs and 'n_folds' not in kwargs:
                    kwargs['n_folds'] = kwargs.pop('n_fold')
                if 'n_folds' not in kwargs:
                    raise ValueError("Make sure you specify n_fold when using 'kfold' cv.")
                if 'y' not in kwargs and getattr(self, 'Y', None) is not None:
                    kwargs['y'] = self.Y
            elif 'labels' not in kwargs:
                labels = getattr(self, 'subject_id', None)
                if labels is None:
                    raise ValueError("'loso' cv requires subject_id or a `labels` "
                                     "keyword argument.")
                kwargs['labels'] = labels

            # Legacy API: these iterables take the labels at construction time.
            from sklearn import cross_validation
            cv_class = {
                'kfold': cross_validation.StratifiedKFold,
                'loso': cross_validation.LeaveOneLabelOut,
            }[cv]

        self.cv_type = cv
        self.cv = cv_class(**kwargs)

    def _algorithm_label(self):
        """ Return a string naming the current algorithm, for file names and titles. """
        algorithm = getattr(self, 'algorithm', None)
        # Classes expose __name__; plain name strings do not.
        label = getattr(algorithm, '__name__', None)
        return label if label is not None else '%s' % (algorithm,)

    def _ensure_output_dir(self):
        """ Create self.output_dir if it does not exist yet. """
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)

    def _coef_image(self, predicter):
        """ Map predicter.coef_ back to image space, or return None when unavailable. """
        import warnings

        try:
            coef = predicter.coef_
        except (AttributeError, ValueError):
            # Not every estimator exposes coef_; skip the map instead of crashing.
            warnings.warn("The fitted predicter exposes no coef_; skipping weight map.")
            return None
        return self.nifti_masker.inverse_transform(coef)

    def _save_image(self, predicter):
        """ Write out weight map to Nifti image.
        Args:
            predicter: predicter instance
        Outputs:
            predicter_weightmap.nii.gz: Will output a nifti image of weightmap
        """
        coef_img = self._coef_image(predicter)
        if coef_img is None:
            return

        self._ensure_output_dir()
        nib.save(coef_img, os.path.join(self.output_dir,
                                        self._algorithm_label() + '_weightmap.nii.gz'))

    def _save_stats_output(self, stats_output):
        """ Write stats output to csv file.
        Args:
            stats_output: a pandas file with prediction output
        Outputs:
            predicter_stats_output.csv: Will output a csv file of stats output
        """
        self._ensure_output_dir()
        stats_output.to_csv(os.path.join(self.output_dir,
                                         self._algorithm_label() + '_Stats_Output.csv'))

    def _save_plot(self, predicter, stats_output=None):
        """ Save Plots.
        Args:
            predicter: predicter instance
            stats_output: optional pandas DataFrame of prediction output; defaults to
                the table built by the last predict() call.
        Outputs:
            predicter_weightmap_montage.png: Will output a montage of axial slices of weightmap
            predicter_prediction.png: Will output a plot of prediction
        """
        if stats_output is None:
            stats_output = getattr(self, 'stats_output', None)

        self._ensure_output_dir()
        label = self._algorithm_label()

        coef_img = self._coef_image(predicter)
        if coef_img is not None:
            overlay_img = nib.load(os.path.join(resource_dir, 'MNI152_T1_2mm_brain.nii.gz'))
            fig1 = plot_stat_map(coef_img, overlay_img, title=label + " weights",
                                 cut_coords=list(range(-40, 40, 10)), display_mode='z')
            fig1.savefig(os.path.join(self.output_dir, label + '_weightmap_axial.png'))

        # The remaining plots need cross-validated columns in the stats table.
        if stats_output is None:
            return

        prediction_type = getattr(self, 'prediction_type', None)
        if prediction_type == 'classification':
            if self.algorithm == 'svm' and 'xval_dist_from_hyperplane' in stats_output:
                fig2 = self._dist_from_hyperplane_plot(stats_output)
                fig2.savefig(os.path.join(self.output_dir, label +
                                          '_xVal_Distance_from_Hyperplane.png'))
        elif prediction_type == 'prediction' and 'yfit' in stats_output:
            fig2 = self._scatterplot(stats_output)
            fig2.savefig(os.path.join(self.output_dir, label + '_scatterplot.png'))

    def _dist_from_hyperplane_plot(self, stats_output):
        """ Plot cross-validated distance from the hyperplane per subject.
        Args:
            stats_output: a pandas file with prediction output
        Returns:
            fig: Will return a seaborn plot of distance from hyperplane
        """
        fig = sns.factorplot("SubID", "xval_dist_from_hyperplane", hue="Y", data=stats_output,
                             kind='point')
        # Label through the grid's own axes: the visible imports define no `plt`.
        ax = fig.axes.flat[0]
        ax.set_xlabel('Subject')
        ax.set_ylabel('Distance from Hyperplane')
        ax.set_title(self._algorithm_label() + ' Classification')
        return fig

    def _scatterplot(self, stats_output):
        """ Plot cross-validated predictions against the true values.
        Args:
            stats_output: a pandas file with prediction output
        Returns:
            fig: Will return a seaborn scatterplot
        """
        fig = sns.lmplot("Y", "yfit", data=stats_output)
        # Label through the grid's own axes: the visible imports define no `plt`.
        ax = fig.axes.flat[0]
        ax.set_xlabel('Y')
        ax.set_ylabel('yfit')
        ax.set_title(self._algorithm_label() + ' Prediction')
        return fig

dist/nltools-0.1-py2.7.egg

8.17 KB
Binary file not shown.

nltools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__all__ = ['analysis']

0 commit comments

Comments
 (0)