-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
57 lines (41 loc) · 1.45 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import torch
# from config import get_parse_args
# from utils.logger import Logger
# from utils.random_seed import set_seed
# from utils.sat_utils import solve_sat_iteratively
from satb.data.dataset_factory import _dataset_factory
# from detectors.detector_factory import detector_factory
def test():
    """Load the circuit SAT test set, print its size, and stop.

    Builds the dataset registered under ``'ckt'`` in ``_dataset_factory``
    from ``data/random_sr3_10_100``, prints how many problems it holds, and
    then terminates by raising ``SystemExit``.

    The solving/accuracy section after the early exit is unreachable as the
    function stands. It is kept for when evaluation is switched back on,
    which also requires ``args``, ``detector_factory`` and
    ``solve_sat_iteratively`` to be defined again (their imports/assignment
    are commented out in this file).

    Raises:
        SystemExit: Always, right after the dataset size is printed.
    """
    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str
    # print(args)
    # args.num_rounds = args.test_num_rounds
    dataset = _dataset_factory['ckt']('data/random_sr3_10_100')
    # dataset = _dataset_factory[args.dataset](args.dataset_dir, args)

    # Do the shuffle
    # perm = torch.randperm(len(dataset))
    # dataset = dataset[perm]
    # split = args.test_split
    # dataset = dataset[:100]

    data_len = len(dataset)
    print('Total # Test SAT problems: ', data_len)

    # Early stop: only the dataset size is reported for now. SystemExit is
    # raised directly instead of calling exit(), which is a helper injected
    # by the site module for interactive sessions and is not guaranteed to
    # exist in every interpreter start-up mode.
    raise SystemExit

    # ------------------------------------------------------------------
    # Disabled evaluation path (unreachable until the early stop above is
    # removed and the commented-out imports / args parsing are restored).
    # ------------------------------------------------------------------
    detector = detector_factory['base'](args)
    print('Start Solving the SAT problem using DeepGate with Logic Implication...')
    correct = 0
    total = 0
    for ind, g in enumerate(dataset):
        # Graphs whose name contains 'Mask' are skipped and not counted.
        if 'Mask' in g.name:
            continue
        total += 1
        # Only the satisfiable flag is used; the returned solution is ignored.
        _, sat = solve_sat_iteratively(g, detector)
        print('# {} SAT: '.format(ind), sat)
        if sat:
            correct += 1

    # NOTE(review): the summary is assumed to belong after the loop; the
    # original indentation was lost, so confirm it was not a per-item print.
    # Guard against an empty (or fully skipped) dataset to avoid dividing by 0.
    accuracy = 100 * correct / total if total else 0.0
    print('ACC: {:.2f}% ({}/{})'.format(accuracy, correct, total))
# Script entry point: run the test routine only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    # Command-line argument parsing is currently disabled; test() is called
    # without arguments.
    # args = get_parse_args()
    test()