-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
executable file
·51 lines (35 loc) · 1.34 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
import sys
import argparse
import itertools
import numpy as np
import torch
from train import NNet
def _parse_features(line):
    """Parse one line of ``index:value`` tokens into a list of floats.

    Tokens are separated by whitespace. For each token the text between the
    first and second ``:`` is converted with ``float``; the part before the
    first colon is ignored, so values are returned in the order they appear.

    Raises:
        ValueError: if a token contains no ``:`` or its value is not a
            valid float.
    """
    values = []
    # str.split() with no argument tolerates tabs and repeated spaces, which
    # split(" ") would turn into empty tokens.
    for token in line.split():
        parts = token.split(":")
        if len(parts) < 2:
            raise ValueError(
                f"malformed feature {token!r}: expected 'index:value'"
            )
        values.append(float(parts[1]))
    return values


def main():
    """Load an ``NNet`` checkpoint and print predictions for samples on stdin.

    Command line:
        --model PATH   checkpoint holding the ``state_dict`` for ``NNet``.

    Every non-blank line read from stdin is treated as one sample made of
    whitespace-separated ``index:value`` tokens. For each sample the script
    prints the input size, the raw model output, the indices of the top
    scoring outputs (up to three), and the indices of all outputs whose
    score is strictly positive.
    """
    parser = argparse.ArgumentParser(
        description='Run inference using a give PyTorch model on Yeast dataset',
    )
    # Without a path, torch.load(None) fails with an obscure error; let
    # argparse reject the invocation with a usage message instead.
    parser.add_argument('--model', required=True, help="path to the model")
    args = parser.parse_args()

    print(f"Loading the model from checkoint at '{args.model}'")
    model = NNet()
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    # map_location="cpu" lets a checkpoint saved on a GPU load on a
    # CPU-only machine.
    state_dict = torch.load(args.model, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    for name, param in model.named_parameters():
        if param.requires_grad:
            print("\t", name, param.data.size())

    max_top_k = 3
    for line in sys.stdin:
        if not line.strip():
            # Blank lines (e.g. a trailing newline) carry no sample.
            continue

        # Wrap the sample in an outer list to get a batch of size one.
        features = torch.tensor([_parse_features(line)], dtype=torch.float32)
        print(f"input: {features.size()}")

        # Call the module itself (not .forward) so registered hooks run, and
        # skip autograd bookkeeping since this is inference only.
        with torch.no_grad():
            prediction = model(features)
        print(f"output: {prediction.data}")

        # topk raises if k exceeds the number of outputs, so clamp it.
        k = min(max_top_k, prediction.size(1))
        _, classes = torch.topk(prediction, k=k, dim=1)
        print(f"\nTop {k} predicted class: {classes[0].tolist()}")

        # Indices of every output with a strictly positive score, converted
        # to plain ints so the list prints without NumPy scalar wrappers.
        scores = prediction[0].numpy()
        positive = np.argwhere(scores > 0).ravel().tolist()
        print(f"\nPredicted classes \\w positive prob: {positive}")
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()