forked from seokhokang/graphvae_approx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
48 lines (32 loc) · 930 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import pickle as pkl
from sklearn.preprocessing import StandardScaler
from GVAE import Model
import sys
# Training entry point: `python train.py {QM9|ZINC}`.
# Loads ./<DATA>_graph.pkl, standardizes the targets, builds the GVAE model
# and trains it, saving the checkpoint to ./<DATA>_model.ckpt.

# Atom vocabulary per supported dataset. The order presumably has to match the
# one used when ./<DATA>_graph.pkl was built (node feature encoding) — TODO
# confirm against the preprocessing script.
ATOM_LISTS = {
    'QM9': ['C', 'N', 'O', 'F'],
    'ZINC': ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'],
}

# Validate the dataset name before doing any expensive work. Previously an
# unknown/missing name left `atom_list` unbound (NameError at train time) or
# raised a bare IndexError.
if len(sys.argv) < 2 or sys.argv[1] not in ATOM_LISTS:
    sys.exit('usage: python train.py {%s}' % '|'.join(ATOM_LISTS))
data = sys.argv[1]
atom_list = ATOM_LISTS[data]

data_path = './' + data + '_graph.pkl'
save_path = './' + data + '_model.ckpt'

print(':: load data')
# NOTE(review): pickle.load can execute arbitrary code; only load dataset
# files you generated yourself or otherwise trust.
with open(data_path, 'rb') as f:
    [DV, DE, DY, Dsmi] = pkl.load(f)

# DV/DE are stored in a sparse container (they expose .todense()) — presumably
# to keep the pickle small; densify them for the model.
DV = DV.todense()
DE = DE.todense()

# Axis meanings as consumed below: DV axis 1 = node slots, axis 2 = node
# feature size; DE axis 3 = edge feature size; DY axis 1 = number of targets.
n_node = DV.shape[1]
dim_node = DV.shape[2]
dim_edge = DE.shape[3]
dim_y = DY.shape[1]

print(':: preprocess data')
# Standardize targets column-wise. The prior over y is then the empirical
# mean/covariance of the standardized targets.
scaler = StandardScaler()
scaler.fit(DY)
DY = scaler.transform(DY)
mu_prior = np.mean(DY, 0)
cov_prior = np.cov(DY.T)

model = Model(n_node, dim_node, dim_edge, dim_y, mu_prior, cov_prior)

print(':: train model')
with model.sess:
    # None presumably means "train from scratch" rather than resuming from a
    # checkpoint — confirm in GVAE.Model.train.
    load_path = None
    model.train(DV, DE, DY, Dsmi, atom_list, load_path, save_path)