Skip to content

Commit 8a0ef0e

Browse files
authored
Merge pull request #137 from ljchang/bug_fix
fixed bug with cross-validation in predict Former-commit-id: 9de625c
2 parents a7126b6 + a50c72f commit 8a0ef0e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

nltools/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def predict(self, algorithm=None, cv_dict=None, plot=True, **kwargs):
745745

746746
for train, test in cv:
747747
predictor_cv.fit(self.data[train], self.Y.loc[train])
748-
output['yfit_xval'][test] = predictor_cv.predict(self.data[test])
748+
output['yfit_xval'][test] = predictor_cv.predict(self.data[test]).ravel()
749749
if predictor_settings['prediction_type'] == 'classification':
750750
if predictor_settings['algorithm'] not in ['svm', 'ridgeClassifier', 'ridgeClassifierCV']:
751751
output['prob_xval'][test] = predictor_cv.predict_proba(self.data[test])[:, 1]

0 commit comments

Comments
 (0)