-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtest_dagan_with_matchingclassifier_for_generation.py
98 lines (79 loc) · 6.02 KB
/
test_dagan_with_matchingclassifier_for_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from generation_builder_with_matchingclassifier import ExperimentBuilder
from utils.parser_util import get_args
batch_size, num_gpus, support_num, args = get_args()

# Choose the data-provider module for this run. Whichever module is picked is
# bound to the name `dataset`, and the code below uses it the same way in both
# cases (it assumes both modules expose the dataset classes it references).
_all_test_categories = args.is_all_test_categories > 0
if _all_test_categories:
    # Generate images for every test category.
    import data_with_matchingclassifier_for_quality_and_classifier as dataset
else:
    # Generate images only for the explicitly selected categories.
    import data_for_selecteddata_for_generation as dataset
# Supported values of `args.dataset`, each mapped to:
#   (name of the dataset class inside the selected `dataset` module,
#    last_training_class_index passed to that class's constructor).
# Class names are kept as strings and resolved with getattr() only for the
# requested dataset, so the chosen module only needs to define the class that
# is actually used -- the same lazy lookup the original if/elif chain had.
_DATASET_SPECS = {
    'omniglot': ('OmniglotDAGANDataset', 900),
    'vggface': ('VGGFaceDAGANDataset', 1600),
    'miniimagenet': ('miniImagenetDAGANDataset', 900),
    'emnist': ('emnistDAGANDataset', 900),
    'figr': ('FIGRDAGANDataset', 900),
    'fc100': ('FC100DAGANDataset', 900),
    'animals': ('animalsDAGANDataset', 900),
    'flowers': ('flowersDAGANDataset', 900),
    'flowersselected': ('flowersselectedDAGANDataset', 900),
    'birds': ('birdsDAGANDataset', 900),
}

if args.dataset not in _DATASET_SPECS:
    # Without this check an unknown name skips every branch, leaves `data`
    # unbound, and only surfaces later as a NameError when the experiment is
    # built. Fail immediately with the list of accepted names instead.
    raise ValueError('Unsupported dataset {!r}; expected one of: {}'.format(
        args.dataset, ', '.join(sorted(_DATASET_SPECS))))

_dataset_class_name, _last_training_class_index = _DATASET_SPECS[args.dataset]
# Report which dataset is being used (previously only some datasets printed).
print(args.dataset)
data = getattr(dataset, _dataset_class_name)(
    batch_size=batch_size,
    last_training_class_index=_last_training_class_index,
    reverse_channels=True,
    num_of_gpus=num_gpus,
    gen_batches=1000,
    support_number=support_num,
    is_training=args.is_training,
    general_classification_samples=args.general_classification_samples,
    selected_classes=args.selected_classes,
    image_size=args.image_width,
)
# Assemble the experiment from the parsed command-line arguments and the data
# provider constructed above, then launch it.
experiment = ExperimentBuilder(args, data=data)
experiment.run_experiment()