-
Notifications
You must be signed in to change notification settings - Fork 118
/
test.py
63 lines (58 loc) · 2.28 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import cv2
import os
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from models import DnCNN
from utils import *
# Enumerate CUDA devices in PCI bus order and expose only device "0" to this
# process; both variables are set before any CUDA work happens in main().
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Command-line options for the test script. They are parsed at import time and
# read by main() through the module-level ``opt``.
parser = argparse.ArgumentParser(description="DnCNN_Test")
parser.add_argument(
    "--num_of_layers",
    type=int,
    default=17,
    help="Number of total layers",
)
parser.add_argument(
    "--logdir",
    type=str,
    default="logs",
    help="path of log files",
)
parser.add_argument(
    "--test_data",
    type=str,
    default="Set12",
    help="test on Set12 or Set68",
)
parser.add_argument(
    "--test_noiseL",
    type=float,
    default=25,
    help="noise level used on test set",
)
opt = parser.parse_args()
def normalize(data):
    """Return ``data`` divided by 255.

    In this script it maps 8-bit pixel values read by OpenCV into [0, 1].
    Works on anything that supports ``/`` with a float (scalars, numpy arrays).
    """
    max_intensity = 255.
    scaled = data / max_intensity
    return scaled
def _load_model():
    """Build DnCNN on GPU 0 inside ``nn.DataParallel`` and load ``<logdir>/net.pth``.

    Reads ``opt.num_of_layers`` and ``opt.logdir``. The state dict is loaded into
    the DataParallel wrapper, so its keys must carry the wrapper's ``module.``
    prefix (presumably the checkpoint was saved from a wrapped model during
    training — confirm against the training script).

    Returns:
        The wrapped model, on CUDA and in eval mode.
    """
    # single-channel input: only one image channel is fed to the net below
    net = DnCNN(channels=1, num_of_layers=opt.num_of_layers)
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    model.load_state_dict(torch.load(os.path.join(opt.logdir, 'net.pth')))
    model.eval()
    return model


def _load_image(path):
    """Read ``path`` with OpenCV and return its first channel as a 1x1xHxW tensor in [0, 1].

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError("could not read image %s" % path)
    # Default imread returns a 3-channel 8-bit array; only channel 0 is kept
    # (identical to the others for grayscale files such as Set12/Set68 —
    # NOTE(review): colour inputs would silently use only that one channel).
    img = normalize(np.float32(img[:, :, 0]))
    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    img = np.expand_dims(np.expand_dims(img, 0), 1)
    return torch.Tensor(img)


def main():
    """Evaluate a trained DnCNN on ``data/<opt.test_data>/*.png`` and print PSNR.

    For every image (in sorted filename order) it does the following:
    - adds Gaussian noise with std ``opt.test_noiseL / 255``;
    - denoises the result on GPU;
    - prints the PSNR against the clean image.

    Finally it prints the mean PSNR over all images. No RNG seed is set in
    this function, so the noise, and therefore the reported PSNR, can differ
    between runs.

    Raises:
        IOError: if no ``.png`` files match, or an image cannot be read.
    """
    # Build model
    print('Loading model ...\n')
    model = _load_model()
    # load data info
    print('Loading data info ...\n')
    pattern = os.path.join('data', opt.test_data, '*.png')
    files_source = sorted(glob.glob(pattern))
    if not files_source:
        # Without this check the final average divides by zero.
        raise IOError("no test images found matching %s" % pattern)
    # process data
    noise_std = opt.test_noiseL / 255.  # noise level given on the 0-255 scale
    psnr_test = 0
    for f in files_source:
        clean = _load_image(f)
        noise = torch.FloatTensor(clean.size()).normal_(mean=0, std=noise_std)
        noisy = clean + noise
        clean, noisy = clean.cuda(), noisy.cuda()
        with torch.no_grad():  # inference only: skip autograd bookkeeping, saves memory
            # The model output is treated as the noise estimate and subtracted
            # from the noisy input; the result is clipped to the valid range.
            out = torch.clamp(noisy - model(noisy), 0., 1.)
        # batch_PSNR comes from utils; the third argument is presumably the
        # data range (1.0 for [0, 1] images) — confirm in utils.
        psnr = batch_PSNR(out, clean, 1.)
        psnr_test += psnr
        print("%s PSNR %f" % (f, psnr))
    psnr_test /= len(files_source)
    print("\nPSNR on test data %f" % psnr_test)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()