#!/usr/bin/env python3
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torch.utils.data as Data
#from sklearn.model_selection import KFold
import numpy as np
import pandas as pd
import pickle as pkl
import argparse
from model import Transformer
#from sklearn.metrics import classification_report
#from sklearn.metrics import precision_score, recall_score
parser = argparse.ArgumentParser(description="""PhaMer is a Python library for identifying bacteriophages from metagenomic data.
PhaMer is based on a Transformer model and relies on a protein-based vocabulary to convert DNA sequences into sentences.""")
parser.add_argument('--dbdir', help='database directory (optional)', default='database')
parser.add_argument('--out', help='name of the output file', type=str, default='out/example_prediction.csv')
parser.add_argument('--reject', help='threshold to reject prophage', type=float, default=0.3)
parser.add_argument('--midfolder', help='folder to store the intermediate files', type=str, default='phamer/')
parser.add_argument('--threads', help='number of threads to use', type=int, default=8)
inputs = parser.parse_args()
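
# Illustrative invocation (paths are the defaults above; it assumes the PhaMer
# preprocessing step has already written its intermediate files to --midfolder):
#   python PhaMer.py --dbdir database --midfolder phamer/ \
#       --out out/example_prediction.csv --reject 0.3 --threads 8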
db_dir = inputs.dbdir
if not os.path.exists(db_dir):
    print(f'Database directory {db_dir} missing or unreadable')
    exit(1)

out_dir = os.path.dirname(inputs.out)
if out_dir != '':
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
transformer_fn = inputs.midfolder
pcs2idx = pkl.load(open(f'{transformer_fn}/pc2wordsid.dict', 'rb'))
num_pcs = len(set(pcs2idx.keys()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == 'cpu':
    print("running with cpu")
    torch.set_num_threads(inputs.threads)
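
# token id 0 is reserved for padding; the model's vocabulary covers every protein
# cluster in pc2wordsid.dict plus that padding token, hence num_pcs + 1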
src_pad_idx = 0
src_vocab_size = num_pcs+1
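

# reset_model builds a fresh Transformer plus an Adam optimizer and a BCE-with-logits
# loss; after the pre-trained weights are loaded below, only the model is used here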
def reset_model():
    model = Transformer(
        src_vocab_size,
        src_pad_idx,
        device=device,
        max_length=300,
        dropout=0.1
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.BCEWithLogitsLoss()
    return model, optimizer, loss_func
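

# return_batch wraps numpy features/labels into a DataLoader (batch size 200,
# optionally shuffled); it is kept for training workflows and is not called
# during the prediction pass below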
def return_batch(train_sentence, label, flag):
    X_train = torch.from_numpy(train_sentence).to(device)
    y_train = torch.from_numpy(label).float().to(device)
    train_dataset = Data.TensorDataset(X_train, y_train)
    training_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=200,
        shuffle=flag,
        num_workers=0,
    )
    return training_loader
def return_tensor(var, device):
    return torch.from_numpy(var).to(device)
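

# reject_prophage zeroes out any prediction whose weight (the per-contig value
# loaded from sentence_proportion.feat) falls below the --reject threshold, so
# those contigs end up reported as non-phage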
def reject_prophage(all_pred, weight):
    all_pred = np.array(all_pred.detach().cpu())
    all_pred[weight < inputs.reject] = 0
    return all_pred
# load the pre-trained Transformer (this script only runs prediction, no training)
model, optimizer, loss_func = reset_model()
try:
    pretrained_dict = torch.load(f'{db_dir}/transformer.pth', map_location=device)
    model.load_state_dict(pretrained_dict)
except Exception as e:
    print(f'cannot load pre-trained model from {db_dir}/transformer.pth: {e}')
    exit(1)
####################################################################################
#########################  predict with contigs  ##################################
####################################################################################
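# intermediate files produced by the preprocessing step: the encoded "sentences"
# (protein-cluster token ids), the sentence-index -> contig-name mapping, and the
# per-contig proportion feature used by reject_prophage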
sentence = pkl.load(open(f'{transformer_fn}/sentence.feat', 'rb'))
id2contig = pkl.load(open(f'{transformer_fn}/sentence_id2contig.dict', 'rb'))
proportion = pkl.load(open(f'{transformer_fn}/sentence_proportion.feat', 'rb'))
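
# score contigs in batches of 500: run the Transformer, apply a sigmoid to get a
# phage probability, zero out low-weight predictions, and call scores above 0.5 "phage"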
all_pred = []
all_score = []
with torch.no_grad():
    _ = model.eval()
    for idx in range(0, len(sentence), 500):
        # slicing past the end of the array simply yields the final, smaller batch
        batch_x = sentence[idx: idx+500]
        weight = proportion[idx: idx+500]
        batch_x = return_tensor(batch_x, device).long()
        logit = model(batch_x)
        logit = torch.sigmoid(logit.squeeze(1))
        logit = reject_prophage(logit, weight)
        pred = ['phage' if item > 0.5 else 'non-phage' for item in logit]
        all_pred += pred
        all_score += [float('{:.3f}'.format(i)) for i in logit]
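
# one row per contig: contig name, phage/non-phage call, and the rounded score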
pred_csv = pd.DataFrame({"Contig": id2contig.values(), "Pred": all_pred, "Score": all_score})
pred_csv.to_csv(inputs.out, index=False)