
Commit d69a96e

Merge pull request #14 from airaria/v0.1.10dev

V0.1.10dev

2 parents: 153be24 + 2c50bbc

11 files changed: +134, -31 lines

README.md

+6-1
@@ -28,6 +28,11 @@ Paper: [https://arxiv.org/abs/2002.12620](https://arxiv.org/abs/2002.12620)
 
 ## Update
 
+**Jul 14, 2020**
+* Updated to 0.1.10:
+  * Now supports mixed precision training with Apex! Just set `fp16` to `True` in `TrainingConfig`. See the documentation of `TrainingConfig` for details.
+  * Added a `data_parallel` option in `TrainingConfig` so that data parallel training and mixed precision training can work together.
+
 **Apr 26, 2020**
 
 * Added Chinese NER task (MSRA NER) results.
@@ -116,6 +121,7 @@ See [Full Documentation](https://textbrewer.readthedocs.io/) for detailed usages
 * NumPy
 * tqdm
 * Transformers >= 2.0 (optional, used by some examples)
+* Apex == 0.1.0 (optional, mixed precision training)
 
 * Install from PyPI
 
@@ -392,7 +398,6 @@ We recommend that users use pre-trained student models whenever possible to full
 
 ## Known Issues
 
-* Compatibility with FP16 training has not been tested.
 * Multi-GPU training support is only available through `DataParallel` currently.
 
 ## Citation
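As a quick illustration of the update note above, here is a minimal usage sketch of the two new `TrainingConfig` options. It is not part of the diff itself and assumes textbrewer 0.1.10 and Apex are installed.

```python
# Minimal sketch, assuming textbrewer >= 0.1.10 is installed.
from textbrewer import TrainingConfig

train_config = TrainingConfig(
    fp16=True,            # mixed precision training via Apex
    data_parallel=True,   # let TextBrewer wrap the models with torch.nn.DataParallel
    output_dir="./saved_models",
)
# train_config is then passed to a distiller (e.g. GeneralDistiller) as usual.
```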

README_ZH.md

+6-1
@@ -27,6 +27,11 @@ Paper: [https://arxiv.org/abs/2002.12620](https://arxiv.org/abs/2002.12620)
 
 ## Update
 
+**Jul 14, 2020**
+* **Updated to 0.1.10**:
+  * Supports Apex mixed precision training: enable it by setting `fp16=True` in `TrainingConfig`. See the documentation of `TrainingConfig` for details.
+  * Added a `data_parallel` option to `TrainingConfig` so that data parallel training and mixed precision training can be enabled together.
+
 **Apr 26, 2020**
 
 * Added results on the Chinese NER task (MSRA NER).
@@ -115,6 +120,7 @@ Paper: [https://arxiv.org/abs/2002.12620](https://arxiv.org/abs/2002.12620)
 * NumPy
 * tqdm
 * Transformers >= 2.0 (optional, required by the Transformers-related examples)
+* Apex == 0.1.0 (optional, for mixed precision training)
 
 ### Installation
 
@@ -381,7 +387,6 @@ Distillers perform the actual distillation. The following distillers are currently implemented:
 
 ## Known Issues
 
-* Compatibility with FP16 training has not been tested yet.
 * Multi-GPU training strategies other than DataParallel are not yet supported.
 
 ## Citation

setup.py

+1-1
@@ -29,7 +29,7 @@
 
 setup(
     name="textbrewer",
-    version="0.1.9",
+    version="0.1.10",
     author="ziqingyang",
     author_email="[email protected]",
     description="PyTorch-based knowledge distillation toolkit for natural language processing",

src/textbrewer/__init__.py

+1-1
@@ -1,4 +1,4 @@
-__version__ = "0.1.9"
+__version__ = "0.1.10"
 
 from .distillers import BasicTrainer
 from .distillers import BasicDistiller

src/textbrewer/compatibility.py

+9-1
@@ -3,4 +3,12 @@
 if torch.__version__ < '1.2':
     mask_dtype = torch.uint8
 else:
-    mask_dtype = torch.bool
+    mask_dtype = torch.bool
+
+def is_apex_available():
+    try:
+        from apex import amp
+        _has_apex = True
+    except ImportError:
+        _has_apex = False
+    return _has_apex
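A brief usage sketch for the new helper (not part of the diff): `is_apex_available()` can be used to decide whether it is safe to request `fp16` at all.

```python
# Sketch: only request fp16 when Apex can actually be imported.
from textbrewer import TrainingConfig
from textbrewer.compatibility import is_apex_available

use_fp16 = is_apex_available()
train_config = TrainingConfig(fp16=use_fp16)
print("Apex available:", use_fp16)
```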

src/textbrewer/configurations.py

+18-6
@@ -43,18 +43,24 @@ class TrainingConfig(Config):
         ckpt_steps (int): if *num_steps* is passed to ``distiller.train()``, saves the model every **ckpt_steps**, meanwhile ignore `ckpt_frequency` and `ckpt_epoch_frequency`.
         log_dir (str): directory to save the tensorboard log file. Set it to ``None`` to disable tensorboard.
         output_dir (str): directory to save model weights.
-        device (str or torch.device) : training on CPU or GPU.
-
+        device (str or torch.device): training on CPU or GPU.
+        fp16 (bool): if ``True``, enables mixed precision training using Apex.
+        fp16_opt_level (str): Pure or mixed precision optimization level. Accepted values are "O0", "O1", "O2", and "O3". See the Apex documentation for details.
+        data_parallel (bool): If ``True``, wraps the models with ``torch.nn.DataParallel``.
+    Note:
+        * To perform data parallel training, you could either wrap the models with ``torch.nn.DataParallel`` outside TextBrewer by yourself, or leave the work to TextBrewer by setting **data_parallel** to ``True``.
+        * To enable both data parallel training and mixed precision training, you should set **data_parallel** to ``True``, and DO NOT wrap the models by yourself.
+        * In some experiments, we have observed slower training with ``torch.nn.DataParallel``. In the future we will move to DistributedDataParallel.
     Example::
 
         # Usually just need to set log_dir and output_dir and leave others default
         train_config = TrainingConfig(log_dir=my_log_dir, output_dir=my_output_dir)
 
-        # Stores model at the end of each epoch
+        # Stores the model at the end of each epoch
         train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=1)
-        # Stores model twice (at the middle and at the end) in each epoch
+        # Stores the model twice (at the middle and at the end) in each epoch
         train_config = TrainingConfig(ckpt_frequency=2, ckpt_epoch_frequency=1)
-        # Stores model once every two epochs
+        # Stores the model once every two epochs
         train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=2)
 
     """
@@ -64,7 +70,10 @@ def __init__(self,gradient_accumulation_steps = 1,
                 ckpt_steps = None,
                 log_dir = None,
                 output_dir = './saved_models',
-                device = 'cuda'
+                device = 'cuda',
+                fp16 = False,
+                fp16_opt_level = 'O1',
+                data_parallel = False
                 ):
         super(TrainingConfig, self).__init__()
7079

@@ -75,6 +84,9 @@ def __init__(self,gradient_accumulation_steps = 1,
         self.log_dir = log_dir
         self.output_dir = output_dir
         self.device = device
+        self.fp16 = fp16
+        self.fp16_opt_level = fp16_opt_level
+        self.data_parallel = data_parallel
 
         if not os.path.exists(self.output_dir):
             os.makedirs(self.output_dir)
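To make the docstring's Note concrete, here is a small sketch (placeholder values, not part of the diff) of the two supported combinations:

```python
# Sketch of the configurations described in the Note above.
from textbrewer import TrainingConfig

# Mixed precision only; "O1" is the default opt level introduced here.
cfg_fp16 = TrainingConfig(fp16=True, fp16_opt_level="O1")

# Mixed precision + data parallel: set data_parallel=True and do NOT wrap
# the models with torch.nn.DataParallel yourself; the distiller wraps them
# after amp.initialize has run.
cfg_both = TrainingConfig(fp16=True, data_parallel=True)
```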

src/textbrewer/distillation.py

+1-2
@@ -120,8 +120,7 @@ def __init__(self, train_config,
 
     def save_and_callback(self,global_step, step, epoch, callback):
         logger.info(f"Saving at global step {global_step}, epoch step {step + 1} epoch {epoch+1}")
-        coreModel = self.model_S.module if \
-            'DataParallel' in self.model_S.__class__.__name__ else self.model_S
+        coreModel = self.model_S.module if hasattr(self.model_S, "module") else self.model_S
         state_dict = coreModel.state_dict()
         torch.save(state_dict, os.path.join(self.t_config.output_dir, f"gs{global_step}.pkl"))
         if callback is not None:

src/textbrewer/distiller_basic.py

+36-6
@@ -25,8 +25,7 @@ def __init__(self, train_config,
 
     def save_and_callback(self,global_step, step, epoch, callback):
         logger.info(f"Saving at global step {global_step}, epoch step {step + 1} epoch {epoch+1}")
-        coreModel = self.model_S.module if \
-            'DataParallel' in self.model_S.__class__.__name__ else self.model_S
+        coreModel = self.model_S.module if hasattr(self.model_S, "module") else self.model_S
         state_dict = coreModel.state_dict()
         torch.save(state_dict, os.path.join(self.t_config.output_dir, f"gs{global_step}.pkl"))
         if callback is not None:
@@ -77,6 +76,23 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
             # overwrite scheduler
             scheduler = scheduler_class(**{'optimizer':optimizer},**scheduler_args)
 
+        if self.t_config.fp16:
+            if not has_apex:
+                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+            if isinstance(self.model_T,(list,tuple)):
+                models = [self.model_S] + list(self.model_T)
+                models, optimizer = amp.initialize(models, optimizer, opt_level=self.t_config.fp16_opt_level)
+                self.model_S = models[0]
+                self.model_T = models[1:]
+            else:
+                (self.model_S, self.model_T), optimizer = amp.initialize([self.model_S, self.model_T], optimizer, opt_level=self.t_config.fp16_opt_level)
+        if self.t_config.data_parallel:
+            self.model_S = torch.nn.DataParallel(self.model_S)
+            if isinstance(self.model_T,(list,tuple)):
+                self.model_T = [torch.nn.DataParallel(model_t) for model_t in self.model_T]
+            else:
+                self.model_T = torch.nn.DataParallel(self.model_T)
+
         if num_steps is not None:
             if self.d_config.is_caching_logits is True:
                 logger.warning("is_caching_logits is True, but num_steps is not None!")
@@ -96,14 +112,21 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
                     batch = batch_postprocessor(batch)
                 total_loss = self.train_on_batch(batch,args)
                 total_loss /= self.t_config.gradient_accumulation_steps
-                total_loss.backward()
+                if self.t_config.fp16:
+                    with amp.scale_loss(total_loss,optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    total_loss.backward()
 
                 self.write_loss(total_loss, writer_step)
                 writer_step += 1
 
                 if (step+1)%self.t_config.gradient_accumulation_steps == 0:
                     if max_grad_norm > 0:
-                        torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
+                        if self.t_config.fp16:
+                            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
+                        else:
+                            torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
                     optimizer.step()
                     if scheduler is not None:
                         scheduler.step()
@@ -153,14 +176,21 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
                     batch = batch_postprocessor(batch)
                 total_loss = self.train_on_batch(batch,args)
                 total_loss /= self.t_config.gradient_accumulation_steps
-                total_loss.backward()
+                if self.t_config.fp16:
+                    with amp.scale_loss(total_loss,optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    total_loss.backward()
 
                 self.write_loss(total_loss, writer_step)
                 writer_step += 1
 
                 if (step+1)%self.t_config.gradient_accumulation_steps == 0:
                     if max_grad_norm > 0:
-                        torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
+                        if self.t_config.fp16:
+                            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
+                        else:
+                            torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
                     optimizer.step()
                     if scheduler is not None:
                         scheduler.step()
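Stripped of the distiller plumbing, the Apex pattern these hunks apply looks like the stand-alone sketch below. It is a generic illustration (not TextBrewer code) and assumes Apex and a CUDA device are available; the toy student and teacher models are placeholders.

```python
# Generic sketch of the fp16 pattern above: amp.initialize first, DataParallel
# afterwards, backprop through amp.scale_loss, clip amp.master_params.
import torch
from apex import amp  # assumes NVIDIA Apex is installed

student = torch.nn.Linear(16, 4).cuda()
teacher = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)

# 1) Register every model that runs a forward pass, plus the optimizer, with amp.
(student, teacher), optimizer = amp.initialize([student, teacher], optimizer, opt_level="O1")

# 2) Wrap with DataParallel only after amp.initialize.
student = torch.nn.DataParallel(student)
teacher = torch.nn.DataParallel(teacher)

for _ in range(10):
    x = torch.randn(8, 16).cuda()
    with torch.no_grad():
        target = teacher(x)
    loss = torch.nn.functional.mse_loss(student(x), target)

    # 3) Scale the loss so fp16 gradients stay in a representable range.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # 4) Clip the fp32 master parameters held by the optimizer, not the fp16 copies.
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
```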

src/textbrewer/distiller_multitask.py

+21-3
@@ -56,6 +56,18 @@ def train(self, optimizer, dataloaders, num_steps, scheduler_class=None, schedul
             # overwrite scheduler
             scheduler = scheduler_class(**{'optimizer':optimizer},**scheduler_args)
 
+        if self.t_config.fp16:
+            if not has_apex:
+                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+            tasknames, model_Ts = zip(*self.model_T.items())
+            models = [self.model_S] + list(model_Ts)
+            models, optimizer = amp.initialize(models, optimizer, opt_level=self.t_config.fp16_opt_level)
+            self.model_S = models[0]
+            self.model_T = dict(zip(tasknames,models[1:]))
+        if self.t_config.data_parallel:
+            self.model_S = torch.nn.DataParallel(self.model_S)
+            self.model_T = {k:torch.nn.DataParallel(v) for k,v in self.model_T.items()}
+
         total_global_steps = num_steps
         ckpt_steps =self.t_config.ckpt_steps
         print_every = ckpt_steps // self.print_freq
@@ -93,13 +105,19 @@ def train(self, optimizer, dataloaders, num_steps, scheduler_class=None, schedul
                 batch_taskname = (batch, taskname)
                 total_loss = self.train_on_batch(batch_taskname, args)
                 total_loss /= self.t_config.gradient_accumulation_steps
-                total_loss.backward()
-
+                if self.t_config.fp16:
+                    with amp.scale_loss(total_loss,optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    total_loss.backward()
                 scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
                 self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
                 writer_step += 1
             if max_grad_norm > 0:
-                torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
+                if self.t_config.fp16:
+                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
+                else:
+                    torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
             optimizer.step()
             if scheduler is not None:
                 scheduler.step()
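The multitask distiller keeps `model_T` as a taskname-to-model dict, while `amp.initialize` wants a flat list, hence the zip round-trip above. A plain-Python sketch of that round-trip (placeholder strings stand in for the models):

```python
# Sketch of the dict -> flat list -> dict round-trip used for the teachers.
model_T = {"ner": "teacher_ner", "cls": "teacher_cls"}  # placeholders for nn.Modules
model_S = "student"

tasknames, model_Ts = zip(*model_T.items())   # ("ner", "cls"), ("teacher_ner", "teacher_cls")
models = [model_S] + list(model_Ts)           # flat list that amp.initialize can accept
# models, optimizer = amp.initialize(models, optimizer, opt_level="O1") would run here
model_S = models[0]
model_T = dict(zip(tasknames, models[1:]))    # rebuild the taskname -> model mapping
print(model_T)                                # {'ner': 'teacher_ner', 'cls': 'teacher_cls'}
```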

src/textbrewer/distiller_train.py

+29-8
@@ -41,6 +41,15 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
             # overwrite scheduler
             scheduler = scheduler_class(**{'optimizer':optimizer},**scheduler_args)
 
+        if self.t_config.fp16:
+            if not has_apex:
+                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+            self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=self.t_config.fp16_opt_level)
+
+        # dataparallel multi-gpu training
+        if self.t_config.data_parallel:
+            self.model = torch.nn.DataParallel(self.model)
+
         if num_steps is not None:
             total_global_steps = num_steps
             ckpt_steps =self.t_config.ckpt_steps
@@ -58,15 +67,22 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
                     batch = batch_postprocessor(batch)
                 total_loss = self.train_on_batch(batch,args)
                 total_loss /= self.t_config.gradient_accumulation_steps
-                total_loss.backward()
+                if self.t_config.fp16:
+                    with amp.scale_loss(total_loss,optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    total_loss.backward()
 
                 scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
                 self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
                 writer_step += 1
 
                 if (step+1)%self.t_config.gradient_accumulation_steps == 0:
                     if max_grad_norm > 0:
-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
+                        if self.t_config.fp16:
+                            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
+                        else:
+                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                     optimizer.step()
                     if scheduler is not None:
                         scheduler.step()
@@ -76,8 +92,7 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
                 logger.info(f"Global step: {global_step}, epoch step:{step+1}")
                 if (global_step%ckpt_steps==0) or global_step==total_global_steps:
                     logger.info(f"Saving at global step {global_step}")
-                    coreModel = self.model.module if \
-                        'DataParallel' in self.model.__class__.__name__ else self.model
+                    coreModel = self.model.module if hasattr(self.model, "module") else self.model
                     state_dict = coreModel.state_dict()
                     torch.save(state_dict, os.path.join(self.t_config.output_dir,f"gs{global_step}.pkl"))
                     if callback is not None:
@@ -106,15 +121,22 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
                     batch = batch_postprocessor(batch)
                 total_loss = self.train_on_batch(batch,args)
                 total_loss /= self.t_config.gradient_accumulation_steps
-                total_loss.backward()
+                if self.t_config.fp16:
+                    with amp.scale_loss(total_loss,optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    total_loss.backward()
 
                 scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
                 self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
                 writer_step += 1
 
                 if (step+1)%self.t_config.gradient_accumulation_steps == 0:
                     if max_grad_norm > 0:
-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
+                        if self.t_config.fp16:
+                            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
+                        else:
+                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                     optimizer.step()
                     if scheduler is not None:
                         scheduler.step()
@@ -125,8 +147,7 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
                 if (global_step%train_steps_per_epoch in checkpoints) \
                         and ((current_epoch+1)%self.t_config.ckpt_epoch_frequency==0 or current_epoch+1==num_epochs):
                     logger.info(f"Saving at global step {global_step}, epoch step {step+1} epoch {current_epoch+1}")
-                    coreModel = self.model.module if \
-                        'DataParallel' in self.model.__class__.__name__ else self.model
+                    coreModel = self.model.module if hasattr(self.model, "module") else self.model
                     state_dict = coreModel.state_dict()
                     torch.save(state_dict, os.path.join(self.t_config.output_dir,f"gs{global_step}.pkl"))
                     if callback is not None:
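Both the old class-name check and the new `hasattr(..., "module")` test serve the same purpose: saving the wrapped model's own weights rather than the `DataParallel` wrapper's, so checkpoint keys are not prefixed with `module.`. A stand-alone sketch (plain PyTorch, not TextBrewer code):

```python
# Sketch: unwrap a (possibly) DataParallel-wrapped model before saving.
import torch

model = torch.nn.DataParallel(torch.nn.Linear(4, 2))

core_model = model.module if hasattr(model, "module") else model
torch.save(core_model.state_dict(), "gs1.pkl")
print(list(core_model.state_dict().keys()))   # ['weight', 'bias'], no "module." prefix
```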

src/textbrewer/distiller_utils.py

+6-1
@@ -13,7 +13,12 @@
 from .presets import *
 from .configurations import TrainingConfig, DistillationConfig
 import random
-from .compatibility import mask_dtype
+from .compatibility import mask_dtype, is_apex_available
+
+has_apex = is_apex_available()
+if has_apex:
+    from apex import amp
+
 
 logger = logging.getLogger("Distillation")
 #logger.setLevel(logging.INFO)
