Classes and Functions
class textbrewer.TrainingConfig (gradient_accumulation_steps=1, ckpt_frequency=1, ckpt_epoch_frequency=1, ckpt_steps=None, log_dir=None, output_dir='./saved_models', device='cuda')
- gradient_accumulation_steps (int): accumulates gradients from several steps before performing an optimization step, to reduce GPU memory consumption. optimizer.step() is called once every gradient_accumulation_steps backward steps. Setting it larger than 1 helps reduce GPU memory usage, especially when the batch size is large.
- ckpt_frequency (int): the frequency of storing model weights, i.e. the number of times the model weights are stored in each epoch.
- ckpt_epoch_frequency (int): stores the model once every this many epochs. For example:
  - ckpt_frequency=1, ckpt_epoch_frequency=1: stores once at the end of each epoch (default).
  - ckpt_frequency=2, ckpt_epoch_frequency=1: stores twice per epoch (at the middle and at the end).
  - ckpt_frequency=1, ckpt_epoch_frequency=2: stores once every two epochs.
- ckpt_steps (int): if num_steps in distiller.train is set, saves the model every ckpt_steps steps; ckpt_frequency and ckpt_epoch_frequency are then ignored.
- log_dir (str): directory to save the TensorBoard log file. Set it to None to disable TensorBoard.
- output_dir (str): directory to save trained model weights.
- device (str, torch.device): training on CPU or GPU.
Example

```python
# Usually you only need to set log_dir and output_dir and leave the others at their defaults
train_config = TrainingConfig(log_dir=my_log_dir, output_dir=my_output_dir)
```
- (classmethod) TrainingConfig.from_json_file(json_file: str): reads the configuration from a JSON file.
- (classmethod) TrainingConfig.from_dict(dict_object: Dict): reads the configuration from a dict.
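A minimal sketch of building the configuration from a dict; the keys simply mirror the constructor arguments listed above, and the values are illustrative:

```python
from textbrewer import TrainingConfig

# Illustrative values; any constructor argument of TrainingConfig can be used as a key
config_dict = {
    'gradient_accumulation_steps': 2,
    'ckpt_frequency': 2,
    'log_dir': './logs',
    'output_dir': './saved_models',
    'device': 'cuda'
}
train_config = TrainingConfig.from_dict(config_dict)
```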
class textbrewer.DistillationConfig (temperature=4, temperature_scheduler='none', hard_label_weight=0, hard_label_weight_scheduler='none', kd_loss_type='ce', kd_loss_weight=1, kd_loss_weight_scheduler='none', probability_shift=False, intermediate_matches=None)
- temperature (float): temperature for distillation. When computing the loss on logits, the teacher's and the student's logits are divided by the temperature.
- temperature_scheduler (str): dynamically adjusts the temperature. See TEMPERATURE_SCHEDULER under Presets for all available options.
- kd_loss_weight (float): the weight of the loss on the 'logits' term.
- hard_label_weight (float): the weight of the sum of the 'losses' term. Usually 'losses' includes the losses on the ground-truth labels. If hard_label_weight > 0 and the adaptor provides 'losses', the final total loss is:

  kd_loss_weight * kd_loss + hard_label_weight * sum(losses)

- kd_loss_weight_scheduler (str) and hard_label_weight_scheduler (str): dynamically adjust the loss weights. See WEIGHT_SCHEDULER under Presets for all available options.
- kd_loss_type (str): the loss function for the logits. See KD_LOSS_MAP under Presets for all available options. Available options are:
  - 'ce': computes the cross-entropy loss between the teacher's and the student's logits.
  - 'mse': computes the mean squared error between the teacher's and the student's logits.
- intermediate_matches (List[Dict] or None): optional. Configuration for intermediate feature matching. Each element of the list is a dict describing one matching pair. The keys and values of the dict are:
  - 'layer_T': layer_T (int): selects the layer_T-th layer of the teacher model.
  - 'layer_S': layer_S (int): selects the layer_S-th layer of the student model.

    Note:
    - layer_T and layer_S index the 'attention' or 'hidden' lists of the dict returned by the adaptor, not the actual layers of the model. See the distillation config below for an example.
    - If the loss is the FSP loss or the NST loss, two layers have to be chosen from the teacher and the student respectively. In this case, layer_T and layer_S are lists of two ints. See the example below.
  - 'feature': feature (str): the feature of the intermediate layers. See FEATURES under Presets for all options. Currently supports:
    - 'attention': attention matrix, of shape (batch_size, num_heads, length, length) or (batch_size, length, length).
    - 'hidden': hidden states, of shape (batch_size, length, hidden_dim).
  - 'loss': loss (str): the loss function. See MATCH_LOSS_MAP under Presets for available losses. Currently includes:
    - 'attention_mse'
    - 'attention_ce'
    - 'hidden_mse'
    - 'nst'
    - ...
  - 'weight': weight (float): the weight of the loss.
  - 'proj': proj (List, optional): optional if the teacher and the student share the same feature dimension; required otherwise. It is the mapping function that matches the teacher's and the student's intermediate feature dimensions. It is a list with these elements:
    - proj[0] (str): the mapping function, one of 'linear', 'relu', 'tanh'. See PROJ_MAP under Presets.
    - proj[1] (int): the feature dimension of the student model.
    - proj[2] (int): the feature dimension of the teacher model.
    - proj[3] (dict): optional, provides configurations such as the learning rate. If not provided, the learning rate and optimizer configurations follow the defaults of the optimizer; otherwise the values specified here are used.
Example

```python
from textbrewer import DistillationConfig

# Basic configuration: use default values, or try different temperatures
distill_config = DistillationConfig(temperature=8)

# Adding intermediate feature matching.
# Under this setting, the dicts results_T/results_S returned by adaptor_T/adaptor_S should contain the 'hidden' key.
# The MSE loss between the teacher's results_T['hidden'][10] and the student's results_S['hidden'][3] will be computed.
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[{'layer_T': 10, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}]
)

# Multiple intermediate feature matches. The teacher and the student have hidden_dim 768 and 384 respectively.
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[
        {'layer_T': 0,  'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 4,  'layer_S': 1, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 8,  'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 12, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]}
    ]
)
```
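As a further sketch (not one of the original examples), the logits loss and the hard-label loss described above can be combined; this assumes the adaptors return a 'losses' entry:

```python
from textbrewer import DistillationConfig

# Sketch: combine the distillation loss on logits with the ground-truth losses
# provided by the adaptor under the 'losses' key.
# Total loss = kd_loss_weight * kd_loss + hard_label_weight * sum(losses)
distill_config = DistillationConfig(
    temperature=8,
    kd_loss_type='ce',       # cross-entropy between teacher and student logits
    kd_loss_weight=1,
    hard_label_weight=0.5    # weight of the adaptor-provided 'losses'
)
```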
- (classmethod) DistillationConfig.from_json_file(json_file: str): reads the configuration from a JSON file.
- (classmethod) DistillationConfig.from_dict(dict_object: Dict): reads the configuration from a dict.
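A minimal sketch of loading a distillation configuration from a JSON file; the file name is hypothetical, and the keys are assumed to mirror the constructor arguments above:

```python
from textbrewer import DistillationConfig

# Hypothetical contents of distill_config.json (keys mirror the constructor arguments):
# {
#   "temperature": 8,
#   "kd_loss_type": "ce",
#   "intermediate_matches": [
#     {"layer_T": 10, "layer_S": 3, "feature": "hidden", "loss": "hidden_mse", "weight": 1}
#   ]
# }
distill_config = DistillationConfig.from_json_file('distill_config.json')
```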
Initialize a distiller class and call its train method to start training. The train methods of the different distillers share the same interface.

Recommended for single-model, single-task distillation.
class textbrewer.GeneralDistiller (train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S, custom_matches=None)
- train_config (TrainingConfig): training configuration.
- distill_config (DistillationConfig): distillation configuration.
- model_T (torch.nn.Module): teacher model.
- model_S (torch.nn.Module): student model.
- adaptor_T (Callable, function): teacher model's adaptor.
- adaptor_S (Callable, function): student model's adaptor.

  adaptor(batch, model_outputs) -> Dict

  To adapt to the different inputs and outputs of different models, users need to provide adaptors. An adaptor is a callable that takes two arguments, batch and model_outputs, and returns a dict.
Example

```python
'''
Suppose the model outputs are: logits, sequence_output, total_loss:

    class MyModel():
        def forward(self, input_ids, attention_mask, labels, ...):
            ...
            return logits, sequence_output, total_loss

    logits: Tensor of shape (batch_size, num_classes)
    sequence_output: list of tensors of shape (batch_size, length, hidden_dim)
    total_loss: scalar tensor

The model inputs are:

    input_ids      = batch[0]  # input_ids      (batch_size, length)
    attention_mask = batch[1]  # attention_mask (batch_size, length)
    labels         = batch[2]  # labels         (batch_size, num_classes)
'''
def BertForGLUESimpleAdaptor(batch, model_outputs):
    return {'logits': (model_outputs[0],),
            'hidden': model_outputs[1],
            'inputs_mask': batch[1]}
```
- custom_matches (List): supports more flexible, user-defined matches (in testing).

textbrewer.GeneralDistiller.train (optimizer, scheduler, dataloader, num_epochs, num_steps=None, callback=None, batch_postprocessor=None, **args)
- optimizer: the optimizer.
- scheduler: adjusts the learning rate; optional (can be None).
- dataloader: dataset iterator.
- num_epochs (int): number of training epochs.
- num_steps (int): number of training steps. If it is not None, the distiller ignores num_epochs and trains for num_steps. The dataloader may have an unknown size, i.e. no __len__ attribute; it is cycled automatically after the whole dataset has been iterated over.
- callback (Callable): function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step) and can be used to evaluate the model at each checkpoint.
- batch_postprocessor (Callable): a function for post-processing batches. It should take a batch and return a batch. Inside the distiller, it works like:

      for batch in dataloader:
          if batch_postprocessor is not None:
              batch = batch_postprocessor(batch)
          # check batch datatype
          # pass batch to the model and adaptors

- **args: additional arguments fed to the model.
Note:
- If the batch is a list or tuple, the model is called as model(*batch, **args). Make sure the order of elements in the batch matches the order of the arguments of model.forward.
- If the batch is a dict, the model is called as model(**batch, **args). Make sure the keys of the batch match the arguments of model.forward.
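A minimal end-to-end sketch of a typical call. Here teacher_model, student_model, adaptor_T, adaptor_S and my_dataloader are placeholders assumed to be defined elsewhere, not part of the library:

```python
import torch
from textbrewer import TrainingConfig, DistillationConfig, GeneralDistiller

# teacher_model, student_model, adaptor_T, adaptor_S and my_dataloader
# are placeholders defined elsewhere.
train_config = TrainingConfig(log_dir='./logs', output_dir='./saved_models')
distill_config = DistillationConfig(temperature=8)

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=adaptor_T, adaptor_S=adaptor_S)

optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)

def evaluate_callback(model, step):
    # Called at every checkpoint; put your own evaluation code here.
    pass

distiller.train(optimizer=optimizer, scheduler=None, dataloader=my_dataloader,
                num_epochs=3, callback=evaluate_callback)
```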
BasicTrainer performs supervised training rather than distillation. It can be used to train the teacher model.
class BasicTrainer (train_config, model, adaptor)
- train_config (TrainingConfig): training configuration.
- model (torch.nn.Module): the model to be trained.
- adaptor (Callable, function): adaptor of the model.

BasicTrainer.train: same interface as GeneralDistiller.train.
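A brief sketch of training a teacher model with BasicTrainer; teacher_model, adaptor_T and my_dataloader are placeholders, and the adaptor is assumed to expose the supervised loss (e.g. under the 'losses' key):

```python
import torch
from textbrewer import TrainingConfig, BasicTrainer

# teacher_model, adaptor_T and my_dataloader are placeholders defined elsewhere.
# The adaptor is assumed to return the supervised training loss (e.g. via 'losses').
train_config = TrainingConfig(output_dir='./teacher_model')
trainer = BasicTrainer(train_config=train_config, model=teacher_model, adaptor=adaptor_T)

optimizer = torch.optim.AdamW(teacher_model.parameters(), lr=2e-5)
trainer.train(optimizer=optimizer, scheduler=None, dataloader=my_dataloader, num_epochs=3)
```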
BasicDistiller performs single-model, single-task distillation. It does not support intermediate feature matching and can be used for debugging or testing.
class BasicDistiller (train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)
- train_config (TrainingConfig): training configuration.
- distill_config (DistillationConfig): distillation configuration.
- model_T (torch.nn.Module): teacher model.
- model_S (torch.nn.Module): student model.
- adaptor_T (Callable, function): teacher model's adaptor.
- adaptor_S (Callable, function): student model's adaptor.

BasicDistiller.train: same interface as GeneralDistiller.train.
MultiTeacherDistiller performs multi-teacher distillation: it distills multiple teacher models (of the same task) into a single student model. It does not support intermediate feature matching.
class MultiTeacherDistiller (train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)
- train_config (TrainingConfig): training configuration.
- distill_config (DistillationConfig): distillation configuration.
- model_T (List[torch.nn.Module]): list of teacher models.
- model_S (torch.nn.Module): student model.
- adaptor_T (Callable, function): teacher models' adaptor.
- adaptor_S (Callable, function): student model's adaptor.

MultiTeacherDistiller.train: same interface as GeneralDistiller.train.
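A minimal construction sketch; teacher_1, teacher_2, student_model and the adaptors are placeholders. The only structural difference from GeneralDistiller is that model_T is a list of teachers:

```python
from textbrewer import TrainingConfig, DistillationConfig, MultiTeacherDistiller

# teacher_1, teacher_2, student_model, adaptor_T and adaptor_S are placeholders.
distiller = MultiTeacherDistiller(
    train_config=TrainingConfig(), distill_config=DistillationConfig(),
    model_T=[teacher_1, teacher_2], model_S=student_model,
    adaptor_T=adaptor_T, adaptor_S=adaptor_S)
# distiller.train(...) has the same interface as GeneralDistiller.train
```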
MultiTaskDistiller distills multiple teacher models (of different tasks) into a single student model. It does not support intermediate feature matching.
class textbrewer.MultiTaskDistiller (train_config, distill_config, model_T, model_S, adaptor_T, adaptor_S)
- train_config (TrainingConfig): training configuration.
- distill_config (DistillationConfig): distillation configuration.
- model_T (Dict[str, torch.nn.Module]): dict of teacher models. Keys are task names, values are teacher models.
- model_S (torch.nn.Module): student model.
- adaptor_T (Dict[str, Callable]): dict of teacher adaptors. Keys are task names, values are the corresponding adaptors.
- adaptor_S (Dict[str, Callable]): dict of student adaptors. Keys are task names, values are the corresponding adaptors.
Its train method differs from that of the other distillers:
textbrewer.MultiTaskDistiller.train (optimizer, scheduler, dataloaders, num_steps, tau=1, callback=None, batch_postprocessors=None, **args)
- optimizer: the optimizer.
- scheduler: adjusts the learning rate; optional (can be None).
- dataloaders: dict of dataset iterators. Keys are task names, values are the corresponding dataloaders.
- num_steps: number of training steps.
- tau: the probability of training on an example of task d is proportional to |d|^tau, where |d| is the size of d's training set. If the size of any dataset is unknown, tau is ignored and examples are sampled uniformly from each dataset.
- callback (Callable): function called after each epoch; can be None. It is called as callback(model=self.model_S, step=global_step) and can be used to evaluate the model at each checkpoint.
- batch_postprocessors (Dict[str, Callable]): a dict of batch postprocessors. Keys are task names, values are the corresponding batch postprocessors. Each should take a batch and return a batch. Inside the distiller, it works like:

      batch = next(dataloaders[taskname])
      if batch_postprocessors is not None:
          batch = batch_postprocessors[taskname](batch)
      # check batch datatype
      # pass batch to the model and adaptors

- **args: additional arguments fed to the model.
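A minimal sketch of the multi-task call; the per-task teachers, adaptors, dataloaders and the task names are placeholders:

```python
import torch
from textbrewer import TrainingConfig, DistillationConfig, MultiTaskDistiller

# teacher_a, teacher_b, student_model and the per-task adaptors/dataloaders
# are placeholders defined elsewhere; 'task_a' and 'task_b' are arbitrary task names.
distiller = MultiTaskDistiller(
    train_config=TrainingConfig(), distill_config=DistillationConfig(),
    model_T={'task_a': teacher_a, 'task_b': teacher_b},
    model_S=student_model,
    adaptor_T={'task_a': adaptor_T_a, 'task_b': adaptor_T_b},
    adaptor_S={'task_a': adaptor_S_a, 'task_b': adaptor_S_b})

optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)
distiller.train(optimizer=optimizer, scheduler=None,
                dataloaders={'task_a': dataloader_a, 'task_b': dataloader_b},
                num_steps=10000, tau=1)
```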
function textbrewer.utils.display_parameters(model, max_level=None)

Displays the number of parameters and the memory usage of each module.
- model (torch.nn.Module): the model to be analyzed.
- max_level (int or None): the maximum module depth to display. If max_level is None, all levels are shown.
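A short usage sketch, assuming the import path matches the fully qualified name above; student_model is a placeholder torch.nn.Module:

```python
from textbrewer.utils import display_parameters

# student_model is a placeholder torch.nn.Module defined elsewhere.
# Show parameter counts and memory usage down to two levels of submodules.
display_parameters(student_model, max_level=2)
```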
This module provides the following data augmentation methods; a combined usage sketch follows the list.
- function textbrewer.data_utils.masking(tokens: List, p=0.1, mask='[MASK]') -> List

  Returns a new list in which each element of tokens has been replaced by mask with probability p.

- function textbrewer.data_utils.deleting(tokens: List, p=0.1) -> List

  Returns a new list in which each element of tokens has been deleted with probability p.

- function textbrewer.data_utils.n_gram_sampling(tokens: List, p_ng=[0.2,0.2,0.2,0.2,0.2], l_ng=[1,2,3,4,5]) -> List

  Samples a length l from l_ng with probability distribution p_ng, then returns a random span of length l from tokens.

- function textbrewer.data_utils.short_disorder(tokens: List, p=[0.9,0.1,0,0,0]) -> List

  Returns a new list by disordering tokens with probability distribution p at every possible position. Let abc be a 3-gram in tokens; there are five ways to disorder it, corresponding to the five probability values:
  - abc -> abc
  - abc -> bac
  - abc -> cba
  - abc -> cab
  - abc -> bca

- function textbrewer.data_utils.long_disorder(tokens: List, p=0.1, length=20) -> List

  Performs long-range disordering. If length > 1, swaps the two halves of each span of length length in tokens; if length <= 1, treats length as the relative length of the span with respect to tokens. For example:

      long_disorder([0,1,2,3,4,5,6,7,8,9,10], p=1, length=0.4)
      # [2, 3, 0, 1, 6, 7, 4, 5, 8, 9]