-
Notifications
You must be signed in to change notification settings - Fork 239
/
main.distill.py
212 lines (183 loc) · 9.51 KB
/
main.distill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import logging
# Script-wide logging: timestamped "<time> - <LEVEL> - <logger name> - <message>"
# lines at INFO level and above, emitted through a logger named "Main".
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y/%m/%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)
logger = logging.getLogger("Main")
import os,random
import numpy as np
import torch
from processing import convert_examples_to_features, read_squad_examples
from processing import ChineseFullTokenizer
from pytorch_pretrained_bert.my_modeling import BertConfig
from optimization import BERTAdam
import config
from utils import read_and_convert, divide_parameters
from modeling import BertForQASimple, BertForQASimpleAdaptor
from textbrewer import DistillationConfig, TrainingConfig, GeneralDistiller
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from functools import partial
from train_eval import predict
def args_check(args):
    """Validate the run arguments and select the compute device.

    Args:
        args: parsed argument namespace. Reads ``output_dir``,
            ``gradient_accumulation_steps``, ``do_train``, ``do_predict``,
            ``local_rank`` and ``no_cuda``.

    Returns:
        ``(device, n_gpu)``. The same values are also stored on ``args`` as
        ``args.device`` and ``args.n_gpu``.

    Raises:
        ValueError: if ``gradient_accumulation_steps < 1`` or if neither
            ``do_train`` nor ``do_predict`` is set.

    Side effects:
        When ``local_rank != -1`` and CUDA is not disabled, initialises the
        ``torch.distributed`` process group with the NCCL backend.
    """
    # isdir rather than exists: a plain file at output_dir would make listdir raise.
    if os.path.isdir(args.output_dir) and os.listdir(args.output_dir):
        logger.warning("Output directory (%s) already exists and is not empty.", args.output_dir)
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.local_rank == -1 or args.no_cuda:
        # Single process: use every visible GPU, or the CPU when CUDA is off or unavailable.
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count() if not args.no_cuda else 0
    else:
        # One process per GPU, identified by local_rank.
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
    args.n_gpu = n_gpu
    args.device = device
    return device, n_gpu
def _read_data(args):
    """Read and featurize the train and/or eval files requested by ``args``.

    Returns:
        ``(train_examples, train_features, eval_examples, eval_features,
        num_train_steps)``. The train entries and ``num_train_steps`` are None
        unless ``args.do_train``; the eval entries are None unless
        ``args.do_predict``.
    """
    train_examples = train_features = None
    eval_examples = eval_features = None
    num_train_steps = None

    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)
    # Every file goes through the same reader/converter; only the path and is_training vary.
    read = partial(read_and_convert, do_lower_case=args.do_lower_case,
                   read_fn=read_squad_examples, convert_fn=convert_fn)

    if args.do_train:
        train_examples, train_features = read(args.train_file, is_training=True)
        # Optional extra training files are appended to the main training data.
        for extra_file in (args.fake_file_1, args.fake_file_2):
            if extra_file:
                extra_examples, extra_features = read(extra_file, is_training=True)
                train_examples += extra_examples
                train_features += extra_features
        # Passed to the optimizer as t_total: one step per train_batch_size features, per epoch.
        num_train_steps = int(len(train_features) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_examples, eval_features = read(args.predict_file, is_training=False)
    return train_examples, train_features, eval_examples, eval_features, num_train_steps


def _build_models(args, bert_config_T, bert_config_S, device, n_gpu):
    """Create the teacher and student QA models, load their weights and move them to ``device``.

    Returns:
        ``(model_T, model_S)``. Both are wrapped in ``torch.nn.DataParallel``
        when ``n_gpu > 1``.

    Raises:
        ValueError: if training is requested without a teacher checkpoint, or
            if a required student checkpoint path is missing.
        RuntimeError: if a ``'bert'`` student checkpoint lacks encoder weights.
        NotImplementedError: for distributed runs (``local_rank != -1``).
    """
    model_T = BertForQASimple(bert_config_T, args)
    model_S = BertForQASimple(bert_config_S, args)

    # Teacher: required for distillation. Predict-only runs evaluate the student
    # and may omit it.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    elif args.do_train:
        raise ValueError("tuned_checkpoint_T must be provided when do_train is set "
                         "(the teacher would otherwise be untrained).")

    # Student
    if args.load_model_type == 'bert':
        if args.init_checkpoint_S is None:
            raise ValueError("init_checkpoint_S must be provided when load_model_type is 'bert'.")
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Load only the encoder: keep 'bert.'-prefixed keys with the prefix stripped.
        prefix = 'bert.'
        state_weight = {k[len(prefix):]: v for k, v in state_dict_S.items() if k.startswith(prefix)}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        if missing_keys:
            raise RuntimeError(
                f"Checkpoint {args.init_checkpoint_S} lacks student encoder weights: {missing_keys}")
    elif args.load_model_type == 'all':
        if args.tuned_checkpoint_S is None:
            raise ValueError("tuned_checkpoint_S must be provided when load_model_type is 'all'.")
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")

    model_T.to(device)
    model_S.to(device)

    if args.local_rank != -1:
        raise NotImplementedError("Distributed training (local_rank != -1) is not supported.")
    if n_gpu > 1:
        model_T = torch.nn.DataParallel(model_T)
        model_S = torch.nn.DataParallel(model_S)
    return model_T, model_S


def _distill(args, model_T, model_S, train_examples, train_features,
             eval_examples, eval_features, num_train_steps):
    """Distill ``model_T`` into ``model_S`` with TextBrewer's ``GeneralDistiller``.

    When eval data was loaded (``args.do_predict``), ``predict`` is passed to
    ``distiller.train`` as its callback. Otherwise no callback is given.
    """
    params = list(model_S.named_parameters())
    all_trainable_params = divide_parameters(params, lr=args.learning_rate)
    logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
    optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                         warmup=args.warmup_proportion, t_total=num_train_steps,
                         schedule=args.schedule,
                         s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

    logger.info("***** Running training *****")
    logger.info(" Num orig examples = %d", len(train_examples))
    logger.info(" Num split examples = %d", len(train_features))
    logger.info(" Forward batch size = %d", args.forward_batch_size)
    logger.info(" Num backward steps = %d", num_train_steps)

    train_config = TrainingConfig(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        ckpt_frequency=args.ckpt_frequency,
        log_dir=args.output_dir,
        output_dir=args.output_dir,
        device=args.device)

    # Intermediate-layer matches are looked up by name in the project's `matches` table.
    from matches import matches
    intermediate_matches = None
    if isinstance(args.matches, (list, tuple)):
        intermediate_matches = [m for name in args.matches for m in matches[name]]
    logger.info(f"{intermediate_matches}")

    distill_config = DistillationConfig(
        temperature=args.temperature,
        intermediate_matches=intermediate_matches)
    # Teacher and student share the BertForQASimple class, so one adaptor serves both.
    adaptor = partial(BertForQASimpleAdaptor, no_logits=args.no_logits, no_mask=args.no_inputs_mask)
    distiller = GeneralDistiller(train_config=train_config,
                                 distill_config=distill_config,
                                 model_T=model_T, model_S=model_S,
                                 adaptor_T=adaptor,
                                 adaptor_S=adaptor)

    def as_tensor(attr, dtype):
        # Stack one feature attribute across all training features.
        return torch.tensor([getattr(f, attr) for f in train_features], dtype=dtype)

    # Tensor order defines the batch layout seen by the adaptor; keep it fixed.
    train_dataset = TensorDataset(as_tensor('input_ids', torch.long),
                                  as_tensor('segment_ids', torch.long),
                                  as_tensor('input_mask', torch.long),
                                  as_tensor('doc_mask', torch.float),
                                  as_tensor('start_position', torch.long),
                                  as_tensor('end_position', torch.long))
    # Distributed runs were already rejected in _build_models, so plain random sampling applies.
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.forward_batch_size, drop_last=True)

    # Without do_predict, eval_examples/eval_features are None; predict must not receive them.
    callback_func = None
    if args.do_predict:
        callback_func = partial(predict,
                                eval_examples=eval_examples,
                                eval_features=eval_features,
                                args=args)
    with distiller:
        distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                        num_epochs=args.num_train_epochs, callback=callback_func)


def main():
    """Distill a fine-tuned BERT QA teacher into a student and/or evaluate the student.

    Arguments are read from ``config.args`` after ``config.parse()``.

    - With ``do_train``, the student is distilled from the teacher.
    - With only ``do_predict``, the student is evaluated once and the result is printed.
    """
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # Seed every RNG source used by the script.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # The batch actually fed per forward pass; gradients are accumulated up to train_batch_size.
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    if args.do_train and forward_batch_size < 1:
        raise ValueError(
            f"train_batch_size ({args.train_batch_size}) must be >= gradient_accumulation_steps "
            f"({args.gradient_accumulation_steps}).")
    args.forward_batch_size = forward_batch_size

    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    for role, bert_config in (('teacher', bert_config_T), ('student', bert_config_S)):
        if args.max_seq_length > bert_config.max_position_embeddings:
            raise ValueError(
                f"max_seq_length ({args.max_seq_length}) exceeds the {role} model's "
                f"max_position_embeddings ({bert_config.max_position_embeddings}).")

    train_examples, train_features, eval_examples, eval_features, num_train_steps = _read_data(args)
    model_T, model_S = _build_models(args, bert_config_T, bert_config_S, device, n_gpu)

    if args.do_train:
        _distill(args, model_T, model_S, train_examples, train_features,
                 eval_examples, eval_features, num_train_steps)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)


if __name__ == "__main__":
    main()