# NbX_predict.py
import glob
import os

import joblib
import numpy as np
import pandas as pd
# Score every row of NbX_feature.csv with each saved model found under
# ./model/, add one probability column per model plus their row-wise mean,
# and write the combined table to NbX_prediction.csv.

# Load
feature_df = pd.read_csv("NbX_feature.csv")
# Label-based .loc slice: both endpoint columns are included. The array is
# built once here, before any prediction columns are inserted into feature_df.
features = np.array(
    feature_df.loc[:, "epitope_positive_count":"chainH_MSWHIM3.1"]
)

model_list = glob.glob("./model/model*")
if not model_list:
    # Without this check the loop below is skipped and the mean column is
    # written as all-NaN, which silently looks like a valid result.
    raise FileNotFoundError("No model files found matching ./model/model*")
# Reverse order combined with insert-at-position-1 below leaves the per-model
# columns in ascending (lexicographic) name order in the output.
model_list.sort(reverse=True)

# Names of the per-model columns added below; tracked explicitly so the mean
# only covers columns produced by this run.
proba_col = []
for model in model_list:
    # Predict
    # NOTE: joblib.load unpickles the file and can execute arbitrary code;
    # only load model files from a trusted source.
    xgb = joblib.load(model)
    # Column 1 of predict_proba (presumably the positive-class probability;
    # confirm against how the models were trained).
    pred_proba = xgb.predict_proba(features)[:, 1]

    # Assign
    # basename instead of split("/")[2]: independent of the path separator
    # and of how many directory components the glob pattern has.
    model_name = os.path.basename(model)
    column_name = model_name + "_predicted_CAPRI_binary_proba"
    feature_df.insert(1, column_name, pred_proba)
    proba_col.append(column_name)

# Assign (mean)
mean_proba = feature_df[proba_col].mean(axis=1)
feature_df.insert(1, "mean_predicted_CAPRI_binary_proba", mean_proba)

# Save
feature_df.to_csv("NbX_prediction.csv")