-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
113 lines (105 loc) · 4.5 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# GLOBAL PARAMETERS
import argparse
# Dataset names known to this configuration module.
DATASETS = ['mnist', 'fmnist', 'cifar10', 'cifar100']

# Maps an ``--algo`` value to a trainer name (presumably the trainer class
# resolved elsewhere -- TODO confirm against the trainer loader).
TRAINERS = dict(fedmeta='FedMeta')

# Accepted ``--algo`` values. This is a live dict view over TRAINERS' keys,
# so any trainer registered in TRAINERS later also becomes a valid choice.
OPTIMIZERS = TRAINERS.keys()

# Per-model settings. Keys appear to be '<dataset>.<model>' (each DATASETS
# entry paired with the default model name 'cnn') -- confirm at the lookup
# site, which is not in this file section.
MODEL_CONFIG = {
    'mnist.cnn': dict(num_classes=10, image_size=28),
    'fmnist.cnn': dict(num_classes=10, image_size=28),
    'cifar10.cnn': dict(num_classes=10, image_size=32),
    'cifar100.cnn': dict(num_classes=100, image_size=32),
}
def base_options():
    """Create the command-line parser holding the options shared by all trainers.

    Only the common flags are registered here (for example ``--algo``,
    ``--data``, ``--model``, ``--num_rounds``, ``--lr``, ``--seed``).
    Algorithm-specific flags are not added by this function.

    Returns:
        argparse.ArgumentParser: a fresh parser with the shared flags.
    """
    # Each entry is (flag, keyword arguments for add_argument). Flags are
    # registered in list order, which is also the order ``--help`` shows.
    shared_flags = [
        ('--algo', dict(help='name of trainer;', type=str,
                        choices=OPTIMIZERS)),
        # The dataset name is the only mandatory shared flag.
        ('--data', dict(help='name of data;', type=str, required=True)),
        ('--model', dict(help='name of model;', type=str, default='cnn')),
        ('--wd', dict(help='weight decay parameter;', type=float,
                      default=0.001)),
        ('--device', dict(help='device', default='cpu:0', type=str)),
        ('--num_rounds', dict(help='number of rounds to simulate;',
                              type=int, default=200)),
        ('--eval_on_test_every', dict(help='evaluate every ____ rounds;',
                                      type=int, default=1)),
        ('--eval_on_train_every', dict(help='evaluate every ____ rounds;',
                                       type=int, default=1)),
        ('--eval_on_validation_every',
         dict(help='evaluate every ____ rounds;', type=int, default=1)),
        ('--save_every', dict(help='save global model every ____ rounds;',
                              type=int, default=50)),
        ('--clients_per_round',
         dict(help='number of clients trained per round;', type=int,
              default=10)),
        ('--batch_size',
         dict(help='batch size when clients train on data;', type=int,
              default=10)),
        ('--num_epochs',
         dict(help='number of epochs when clients train on data;', type=int,
              default=20)),
        ('--lr', dict(help='learning rate for inner solver;', type=float,
                      default=0.01)),
        ('--seed', dict(help='seed for randomness;', type=int, default=0)),
        # The remaining flags carry no help text.
        ('--quiet', dict(type=int, default=0)),
        ('--result_prefix', dict(type=str, default='./result')),
        ('--train_val_test', dict(action='store_true')),
        ('--result_dir', dict(type=str, default='')),
        ('--data_format', dict(type=str, default='pkl')),
        ('--train_inner_step', dict(default=0, type=int)),
        ('--test_inner_step', dict(default=0, type=int)),
        ('--same_mini_batch', dict(action='store_true', default=False)),
    ]

    parser = argparse.ArgumentParser()
    for flag, settings in shared_flags:
        parser.add_argument(flag, **settings)
    return parser
def add_dynamic_options(argparser, args=None):
    """Register the extra options that apply only to the selected ``--algo``.

    The parser is first pre-parsed with ``parse_known_args`` to read
    ``--algo``. Tokens for flags not registered yet are set aside rather
    than rejected. The matching algorithm-specific options are then added
    to the same parser. Options already marked ``required`` on the parser
    are still enforced by this pre-parse.

    Args:
        argparser: parser that already defines ``--algo`` (for example the
            one returned by ``base_options()``).
        args: optional list of command-line tokens to inspect. ``None`` (the
            default) reads ``sys.argv[1:]``, matching the previous
            behaviour. Pass the same list you will later give to
            ``parse_args`` so both parses see identical input.

    Returns:
        The same ``argparser`` object, mutated in place.
    """
    # Previously this always read sys.argv, which disagreed with any
    # later parse_args(custom_list) call.
    params = argparser.parse_known_args(args)[0]
    algo = params.algo

    if algo == 'fedmeta':
        argparser.add_argument('--meta_algo', type=str, default='maml',
                               choices=['maml', 'reptile', 'meta_sgd'])
        # Ignored by the pre-parse above (it was not registered yet), but
        # enforced by the caller's final parse.
        argparser.add_argument('--outer_lr', type=float, required=True)
        argparser.add_argument('--meta_train_test_split', type=int,
                               default=-1)
        argparser.add_argument('--store_to_cpu', action='store_true',
                               default=False)
        argparser.add_argument('--use_pppfl', action='store_true',
                               default=False)
        argparser.add_argument('--eps_smooth_factor', type=float,
                               default=10.0)
    elif algo == 'fedavg_adv':
        # NOTE(review): 'fedavg_adv' is not among the TRAINERS keys, so with
        # a parser from base_options() (choices=OPTIMIZERS) argparse rejects
        # this value and the branch is unreachable. It is only reachable
        # with a parser whose --algo accepts 'fedavg_adv'.
        argparser.add_argument('--use_all_data', action='store_true',
                               default=False)
    return argparser