-
Notifications
You must be signed in to change notification settings - Fork 24
/
demo.py
57 lines (45 loc) · 1.74 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# RUN THIS AS INIT
from data.dataset import *
from models.net import *
from utilities.main_postprocessing import *
# CONFIG: resolve the result/model root directories from the parsed project config.
config = utils.get_parsed_config()
test_results_root_path, tr_results_root_path, models_root_path = (
    utils.read_from_config(config, key)
    for key in ('test_results_root_path', 'tr_results_root_path', 'models_root_path')
)
dataset = TinyImgnetDataset()
model = SmallVGG9(models_root_path, dataset.input_size)

# Turn on/off the SI method section.
plot_SI = True

# PARAMS
img_extention = 'png'  # switch to 'eps' for LaTeX figures
save_img = True
plot_seq_acc, plot_seq_forgetting = True, False
hyperparams_selection = []
label_segment_idxs = [0]
exp_name_contains = None

# INIT: accumulators that each enabled method section appends to.
method_names = []
method_data_entries = []
#############################################
# SI METHOD
# When enabled, register the SI method's name and append the experiment
# entries of its "reproduce" grid search to `method_data_entries`.
if plot_SI:
    method = SI()
    method_names.append(method.name)
    label = None  # NOTE(review): assigned but never read in this script
    tuning_selection = []
    gridsearch_name = "reproduce"
    method_data_entries.extend(
        collect_gridsearch_exp_entries(
            test_results_root_path, tr_results_root_path, dataset, method, gridsearch_name,
            model, tuning_selection,
            label_segment_idxs=label_segment_idxs,
            exp_name_contains=exp_name_contains,
        )
    )
#############################################
# ANALYZE
#############################################
print(method_data_entries)

# Figure output directory name, shaped DEMO_<dataset>_(<method names>)_<model>;
# left as None when save_img is False.
out_name = (
    '_'.join(['DEMO', dataset.name, "(" + '_'.join(method_names) + ")", model.name])
    if save_img
    else None
)

analyze_experiments(
    method_data_entries,
    hyperparams_selection=hyperparams_selection,
    plot_seq_acc=plot_seq_acc,
    plot_seq_forgetting=plot_seq_forgetting,
    save_img_parent_dir=out_name,
    img_extention=img_extention,
)