tune_with_ax.py
import argparse
import train
from ax.service.ax_client import AxClient
import time
import sys
import util as u
import random
import json
import tensorflow as tf
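# not yet ported to the pod version of training; bail out immediately so the
# rest of the script never runs until that port is done.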
raise Exception("need to port to pod version")
# tf.config.experimental.set_visible_devices([], "GPU")
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--group', type=str, default=None,
                    help='wandb group. if none, no logging')
parser.add_argument('--mode', type=str, required=True,
                    help='mode; one of siso, simo, simo_ld or mimo')
parser.add_argument('--num-models', type=int, default=1)
parser.add_argument('--run-time-sec', type=int, default=60 * 10)
parser.add_argument('--epochs', type=int, default=60)
cmd_line_opts = parser.parse_args()
print(cmd_line_opts, file=sys.stderr)
if cmd_line_opts.mode not in ['siso', 'simo', 'simo_ld', 'mimo']:
    raise Exception("invalid --mode")
# note: if running on the display GPU you probably want to set an env var
# like XLA_PYTHON_CLIENT_MEM_FRACTION=.8 so that jobs tuned too large
# fail cleanly with an OOM.
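# one way to do that from inside the process (an illustration, not part of
# the original script) would be, before any XLA allocation happens:
# import os
# os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", ".8")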
# we tune for 4 major configuration combos; see train.py for more info.
#
# a) SISO    --input-mode=single --num-models=1
#    the baseline non-ensemble config.
#
# b) SIMO    --input-mode=single --num-models=M
#    a single set of inputs and labels; the ensemble's per-model logits are
#    summed to produce a single output.
#
# c) SIMO_LD --input-mode=single --num-models=M
#    a single set of inputs and labels; the ensemble's per-model logits are
#    summed, after logit dropout, to produce a single output.
#
# d) MIMO    --input-mode=multiple --num-models=M
#    multiple inputs (with multiple labels) going through multiple models;
#    the loss is still averaged over all models though.
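# as a rough sketch (hypothetical names; the real implementation lives in
# train.py), the SIMO / SIMO_LD logit reduction described above might look
# like:
#
#   per_model_logits = tf.stack([m(x) for m in models])  # (M, B, num_classes)
#   if model_dropout:
#       per_model_logits = tf.nn.dropout(per_model_logits, rate=0.5)
#   logits = tf.reduce_sum(per_model_logits, axis=0)      # (B, num_classes)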
ax_params = [
    # {
    #     "name": "max_conv_size",
    #     "type": "range",
    #     "bounds": [8, 256],
    # },
    # {
    #     "name": "dense_kernel_size",
    #     "type": "range",
    #     "bounds": [8, 128],
    # },
    {
        "name": "learning_rate",
        "type": "range",
        "bounds": [1e-4, 1e-1],
        "log_scale": True,
    },
    # {
    #     "name": "batch_size",
    #     "type": "choice",
    #     "values": [32, 64],
    # },
]
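# note: "log_scale": True makes Ax sample learning_rate uniformly in log
# space, the usual choice for a range spanning several orders of magnitude.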
# if cmd_line_opts.mode in ['simo', 'mimo']:
#     ax_params.append({
#         "name": "num_models",
#         "type": "range",
#         "bounds": [2, 8],
#     })
ax = AxClient()
ax.create_experiment(
    name="ensemble_net_tuning",
    parameters=ax_params,
    objective_name="final_loss",
    minimize=True,
)
u.ensure_dir_exists("logs/%s" % cmd_line_opts.group)
log = open("logs/%s/ax_trials.tsv" % cmd_line_opts.group, "w")
print("trial_index\tparameters\truntime\tfinal_loss", file=log)
end_time = time.time() + cmd_line_opts.run_time_sec
while time.time() < end_time:
    parameters, trial_index = ax.get_next_trial()
    log_record = [trial_index, json.dumps(parameters)]
    print("starting", log_record)

    class Opts(object):
        pass

    opts = Opts()
    opts.group = cmd_line_opts.group
    opts.seed = random.randint(0, int(1e9))
    if cmd_line_opts.mode == 'siso':
        opts.input_mode = 'single'
        opts.num_models = 1
        opts.model_dropout = False  # N/A for multi_input
    elif cmd_line_opts.mode == 'simo':
        opts.input_mode = 'single'
        opts.num_models = cmd_line_opts.num_models
        opts.model_dropout = False
    elif cmd_line_opts.mode == 'simo_ld':
        opts.input_mode = 'single'
        opts.num_models = cmd_line_opts.num_models
        opts.model_dropout = True
    else:  # mimo
        opts.input_mode = 'multiple'
        opts.num_models = cmd_line_opts.num_models
        opts.model_dropout = False  # N/A for multi_input
    # max_conv_size and dense_kernel_size are commented out of ax_params
    # above, so fall back to fixed defaults when Ax doesn't supply them
    # (the 64 / 32 fallbacks here are assumed values, not tuned ones).
    opts.max_conv_size = parameters.get('max_conv_size', 64)
    opts.dense_kernel_size = parameters.get('dense_kernel_size', 32)
    opts.batch_size = 64  # parameters['batch_size']
    opts.learning_rate = parameters['learning_rate']
    opts.epochs = cmd_line_opts.epochs  # max to run; we also use early stopping
    # run
    start_time = time.time()
    # final_loss = train.train_in_subprocess(opts)
    final_loss = train.train(opts)
    log_record.append(time.time() - start_time)
    log_record.append(final_loss)
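    # note: complete_trial's raw_data maps metric name -> (mean, SEM); the
    # SEM of 0 below tells Ax to treat the observed loss as noiseless.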
    # complete trial
    if final_loss is None:
        print("ax trial", trial_index, "failed?")
        ax.log_trial_failure(trial_index=trial_index)
    else:
        ax.complete_trial(trial_index=trial_index,
                          raw_data={'final_loss': (final_loss, 0)})
        print("CURRENT_BEST", ax.get_best_parameters())
    # flush log
    log_msg = "\t".join(map(str, log_record))
    print(log_msg, file=log)
    print(log_msg)
    log.flush()

    # save ax state
    ax.save_to_json_file()
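# the snapshot written above can be restored later with something like:
# restored_ax = AxClient.load_from_json_file()  # default "ax_client_snapshot.json"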