-
Notifications
You must be signed in to change notification settings - Fork 0
/
qgnn_infer.py
40 lines (30 loc) · 1.09 KB
/
qgnn_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from acorn_pennylane import GraphDataset, InteractionGNN
from qgnn_infer_utils import predict_step
import sys
import yaml
from torch_geometric.loader import DataLoader
import time
import torch
import yappi
### command-line arguments:
###   argv[1] -> YAML hyper-parameter file
###   argv[2] -> model_path: trained weights, should be a .pt or .pth file (loaded as a state_dict)
###   argv[3] -> scored_graph_path: folder to store scored graphs in
# Fail with a readable usage message instead of an IndexError traceback.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} HPARAMS_YAML MODEL_PATH SCORED_GRAPH_DIR")

# NOTE(review): yaml FullLoader is not safe for untrusted input; the config is
# presumably a local, trusted file. yaml.safe_load would be stricter but may
# reject tags FullLoader accepts, so FullLoader is kept here.
with open(sys.argv[1], "r") as stream:
    hparams = yaml.load(stream, Loader=yaml.FullLoader)

model_path = sys.argv[2]
scored_graph_path = sys.argv[3]

# Inference is pinned to the CPU; the commented-out line selects a GPU when available.
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')
print(device, '\n')
### loading data and model
# Test graphs come from ../module_map/testset unless an optional 4th
# command-line argument names a different directory.
test_dir = sys.argv[4] if len(sys.argv) > 4 else '../module_map/testset'
test_set = GraphDataset(input_dir = test_dir, hparams = hparams)
print(test_set[0])
# batch_size=1: one graph per batch; num_workers=0: load in the main process.
test_loader = DataLoader(test_set, batch_size = 1, num_workers= 0)
model = InteractionGNN(hparams, qnn=True).to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA training run)
# onto the inference device; without it a GPU checkpoint fails to load on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location=device))
print(model)
# eval() switches layers such as dropout / batch-norm to inference behaviour.
model.eval()
### score and save each test event
# Scoring never needs gradients, so autograd tracking is disabled to save
# memory and time during inference.
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = batch.to(device)
        # Use the named output folder (set from sys.argv[3]) rather than re-reading argv.
        predict_step(model, batch, test_loader, scored_graph_path)
        print(f'test graph {i+1} scored')