* Now supports mixed precision training with Apex! Just set `fp16` to `True` in `TrainingConfig`. See the documentation of `TrainingConfig` for details.
* Added the `data_parallel` option in `TrainingConfig` so that data parallel training and mixed precision training can work together (see the sketch below).
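
For example, a minimal sketch of a configuration enabling both features (assuming TextBrewer and NVIDIA Apex are installed; all other fields are left at their defaults):

```python
from textbrewer import TrainingConfig

# Turn on Apex mixed precision and let TextBrewer wrap the models with DataParallel;
# do not wrap the models with torch.nn.DataParallel yourself when combining the two.
train_config = TrainingConfig(fp16=True, data_parallel=True)
```
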
**Apr 26, 2020**
* Added Chinese NER task (MSRA NER) results.
See [Full Documentation](https://textbrewer.readthedocs.io/) for detailed usage.
* NumPy
* tqdm
* Transformers >= 2.0 (optional, used by some examples)
`src/textbrewer/configurations.py` (+18 −6)
class TrainingConfig(Config):
ckpt_steps (int): if *num_steps* is passed to ``distiller.train()``, saves the model every **ckpt_steps** steps; ``ckpt_frequency`` and ``ckpt_epoch_frequency`` are ignored.
log_dir (str): directory to save the tensorboard log file. Set it to ``None`` to disable tensorboard.
output_dir (str): directory to save model weights.
device (str or torch.device): training on CPU or GPU.
fp16 (bool): if ``True``, enables mixed precision training using Apex.
fp16_opt_level (str): pure or mixed precision optimization level. Accepted values are "O0", "O1", "O2", and "O3". See the Apex documentation for details.
data_parallel (bool): If ``True``, wraps the models with ``torch.nn.DataParallel``.
Note:
* To perform data parallel training, you can either wrap the models with ``torch.nn.DataParallel`` yourself outside TextBrewer, or let TextBrewer do it by setting **data_parallel** to ``True``.
* To enable both data parallel training and mixed precision training, set **data_parallel** to ``True`` and do **not** wrap the models yourself.
* In some experiments, we have observed a slowdown in training speed with ``torch.nn.DataParallel``. We plan to move to ``DistributedDataParallel`` in the future.
Example::
    # Usually you only need to set log_dir and output_dir and leave the others at their defaults
    train_config = TrainingConfig(log_dir='logs', output_dir='saved_models')
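
Building on the fields documented above, here is a hedged sketch of a fuller configuration; the values are illustrative placeholders, not recommendations:

```python
from textbrewer import TrainingConfig

# A fuller configuration using only the fields documented above (placeholder values).
train_config = TrainingConfig(
    output_dir='saved_models',  # where model weights are saved
    log_dir='logs',             # tensorboard log directory; set to None to disable
    ckpt_steps=1000,            # used when num_steps is passed to distiller.train()
    device='cuda',              # train on GPU
    fp16=True,                  # Apex mixed precision
    fp16_opt_level='O1',        # mixed precision optimization level
    data_parallel=True,         # let TextBrewer wrap the models with DataParallel
)
```
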