# evaluate.py (forked from NVlabs/NVAE)
# ---------------------------------------------------------------
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for NVAE. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
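
# Evaluation script for a trained NVAE checkpoint: computes the validation
# NELBO / negative log p (in nats and in bits per dim) in 'evaluate' mode,
# or draws and displays image samples from the prior in 'sample' mode.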
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt

from time import time
from torch.multiprocessing import Process
from torch.cuda.amp import autocast

from model import AutoEncoder
import utils
import datasets
from train import test, init_processes
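

# Readjust BatchNorm running statistics by sampling from the model: when
# bn_eval_mode is False, the model stays in train mode while drawing samples
# so the BN running means/variances are re-estimated at the given temperature.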
def set_bn(model, bn_eval_mode, num_samples=1, t=1.0, iter=100):
    if bn_eval_mode:
        model.eval()
    else:
        model.train()
        with autocast():
            for i in range(iter):
                if i % 10 == 0:
                    print('setting BN statistics iter %d out of %d' % (i + 1, iter))
                model.sample(num_samples, t)
        model.eval()
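

# Rebuild the AutoEncoder from a saved checkpoint and run the requested
# evaluation mode: 'evaluate' (test likelihood) or 'sample' (visualization).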
def main(eval_args):
    # set up the logger; it writes to eval_args.save
    logging = utils.Logger(eval_args.local_rank, eval_args.save)

    # load a checkpoint
    logging.info('loading the model at:')
    logging.info(eval_args.checkpoint)
    checkpoint = torch.load(eval_args.checkpoint, map_location='cpu')
    args = checkpoint['args']

    logging.info('loaded the model at epoch %d', checkpoint['epoch'])
    arch_instance = utils.get_arch_cells(args.arch_instance)
    model = AutoEncoder(args, None, arch_instance)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    logging.info('args = %s', args)
    logging.info('num conv layers: %d', len(model.all_conv_layers))
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))

    if eval_args.eval_mode == 'evaluate':
        # load train and valid queues
        args.data = eval_args.data
        train_queue, valid_queue, num_classes = datasets.get_loaders(args)

        if eval_args.eval_on_train:
            logging.info('Using the training data for eval.')
            valid_queue = train_queue

        # convert nats to bits per dim: divide by ln(2) and by the number of
        # output dimensions of the dataset
        num_output = utils.num_output(args.dataset)
        bpd_coeff = 1. / np.log(2.) / num_output

        valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=eval_args.num_iw_samples,
                                            args=args, logging=logging)
        logging.info('final valid nelbo %f', valid_nelbo)
        logging.info('final valid neg log p %f', valid_neg_log_p)
        logging.info('final valid nelbo in bpd %f', valid_nelbo * bpd_coeff)
        logging.info('final valid neg log p in bpd %f', valid_neg_log_p * bpd_coeff)
    else:
        bn_eval_mode = not eval_args.readjust_bn
        num_samples = 16
        with torch.no_grad():
            n = int(np.floor(np.sqrt(num_samples)))
            set_bn(model, bn_eval_mode, num_samples=36, t=eval_args.temp, iter=500)
            for ind in range(5):   # sampling is repeated.
                torch.cuda.synchronize()
                start = time()
                with autocast():
                    logits = model.sample(num_samples, eval_args.temp)
                output = model.decoder_output(logits)
                output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) \
                    else output.sample()
                torch.cuda.synchronize()
                end = time()

                output_tiled = utils.tile_image(output_img, n).cpu().numpy().transpose(1, 2, 0)
                logging.info('sampling time per batch: %0.3f sec', (end - start))
                output_tiled = np.asarray(output_tiled * 255, dtype=np.uint8)
                output_tiled = np.squeeze(output_tiled)

                plt.imshow(output_tiled)
                plt.show()
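

# Command-line entry point. With --world_size > 1, one process per GPU is
# spawned via torch.multiprocessing; otherwise a single process runs main().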
if __name__ == '__main__':
    parser = argparse.ArgumentParser('encoder decoder examiner')
    # experimental results
    parser.add_argument('--checkpoint', type=str, default='/tmp/expr/checkpoint.pt',
                        help='location of the checkpoint')
    parser.add_argument('--save', type=str, default='/tmp/expr',
                        help='location of the logging / experiment directory')
    parser.add_argument('--eval_mode', type=str, default='sample', choices=['sample', 'evaluate'],
                        help='evaluation mode. you can choose between sample or evaluate.')
    parser.add_argument('--eval_on_train', action='store_true', default=False,
                        help='Setting this to true will evaluate the model on training data.')
    parser.add_argument('--data', type=str, default='/tmp/data',
                        help='location of the data corpus')
    parser.add_argument('--readjust_bn', action='store_true', default=False,
                        help='adding this flag will enable readjusting BN statistics.')
    parser.add_argument('--temp', type=float, default=0.7,
                        help='The temperature used for sampling.')
    parser.add_argument('--num_iw_samples', type=int, default=1000,
                        help='The number of IW samples used in test_ll mode.')
    # DDP.
    parser.add_argument('--local_rank', type=int, default=0,
                        help='rank of process')
    parser.add_argument('--world_size', type=int, default=1,
                        help='number of gpus')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed used for initialization')
    parser.add_argument('--master_address', type=str, default='127.0.0.1',
                        help='address for master')
    args = parser.parse_args()
    utils.create_exp_dir(args.save)

    size = args.world_size
    if size > 1:
        args.distributed = True
        processes = []
        for rank in range(size):
            args.local_rank = rank
            p = Process(target=init_processes, args=(rank, size, main, args))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
    else:
        # for debugging
        print('starting in debug mode')
        args.distributed = True
        init_processes(0, size, main, args)
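
# Example invocations (paths and sample counts are illustrative, not shipped
# defaults beyond the argparse values above):
#
#   evaluate test likelihood with 100 importance-weighted samples:
#     python evaluate.py --checkpoint /tmp/expr/checkpoint.pt --data /tmp/data \
#         --eval_mode evaluate --num_iw_samples 100
#
#   draw samples at temperature 0.7 with BN readjustment:
#     python evaluate.py --checkpoint /tmp/expr/checkpoint.pt \
#         --eval_mode sample --temp 0.7 --readjust_bn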