From 0cbbbe7824e9718aaacf8bc80717a6c0b2ac9c28 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 14 Oct 2024 12:43:07 +0200 Subject: [PATCH 1/5] initial uld loss --- trl/trainer/gkd_trainer.py | 87 +++++++++++++++++++++++++++++++++----- trl/trainer/utils.py | 79 +++++++++++++++++++++++++++++++++- 2 files changed, 155 insertions(+), 11 deletions(-) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 1b7c77557d..72f54775ef 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -73,6 +73,9 @@ def __init__( processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, + teacher_processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, @@ -83,7 +86,10 @@ def __init__( ): # add remove_unused_columns=False to the the dataclass args args.remove_unused_columns = False - data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length) + data_collator = DataCollatorForChatML( + tokenizer=processing_class, max_length=args.max_seq_length, teacher_tokenizer=teacher_processing_class + ) + self.uld_loss = teacher_processing_class is not None super().__init__( model, @@ -215,6 +221,54 @@ def generalized_jsd_loss( else: return jsd + @staticmethod + def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels=None, beta=0.5, reduction="sum"): + """ + Compute the Universal Logit Distillation (ULD) loss. + + Args: + student_logits: Tensor of shape (batch_size, student_sequence_length, vocab_size) + teacher_logits: Tensor of shape (batch_size, teacher_sequence_length, vocab_size) + student_labels: Tensor of shape (batch_size, student_sequence_length) with -100 for padding tokens + teacher_labels: Tensor of shape (batch_size, teacher_sequence_length) with -100 for padding tokens + beta: Weight for the Wasserstein distance (default: 0.5) + reduction: Specifies the reduction to apply to the output (default: 'sum') + + Returns: + loss: Scalar tensor with the ULD loss + """ + # mask out logits via the student and teacher labels + mask = (student_labels != -100).float() + student_logits = student_logits * mask + mask = (teacher_labels != -100).float() + teacher_logits = teacher_logits * mask + + # Convert logits to probabilities + student_probs = F.softmax(student_logits, dim=-1) + teacher_probs = F.softmax(teacher_logits, dim=-1) + + # Sort probabilities in descending order + student_probs_sorted, _ = torch.sort(student_probs, dim=-1, descending=True) + teacher_probs_sorted, _ = torch.sort(teacher_probs, dim=-1, descending=True) + + # pad the smaller tensor to the same size as the larger tensor by zeros + max_length = max(student_probs_sorted.size(1), teacher_probs_sorted.size(1)) + student_probs_sorted = F.pad(student_probs_sorted, (0, max_length - student_probs_sorted.size(1))) + teacher_probs_sorted = F.pad(teacher_probs_sorted, (0, max_length - teacher_probs_sorted.size(1))) + + # Compute weighted Wasserstein distance + wasserstein_distance = beta * torch.abs(student_probs_sorted - teacher_probs_sorted).sum(dim=-1) + + # Apply reduction + if reduction == "batchmean": + return wasserstein_distance.sum() / (wasserstein_distance.size(0) * wasserstein_distance.size(1)) + elif reduction == "sum": + 
return wasserstein_distance.sum() + elif reduction == "mean": + return wasserstein_distance.mean() + else: + return wasserstein_distance + def compute_loss(self, model, inputs, return_outputs=False): # compute student output outputs_student = model( @@ -224,25 +278,38 @@ def compute_loss(self, model, inputs, return_outputs=False): # compute teacher output in eval mode self.teacher_model.eval() + teacher_input_ids = inputs.get("teacher_input_ids", inputs["input_ids"]) + teacher_attention_mask = inputs.get("teacher_attention_mask", inputs["attention_mask"]) with torch.no_grad(): outputs_teacher = self.teacher_model( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, ) # slice the logits for the generated tokens using the inputs["prompts"] lengths prompt_lengths = inputs["prompts"].shape[1] + teacher_prompt_lengths = inputs.get("teacher_prompts", inputs["prompts"]).shape[1] shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :] - shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :] + shifted_teacher_logits = outputs_teacher.logits[:, teacher_prompt_lengths - 1 : -1, :] shifted_labels = inputs["labels"][:, prompt_lengths:] # compute loss - loss = self.generalized_jsd_loss( - student_logits=shifted_student_logits, - teacher_logits=shifted_teacher_logits, - labels=shifted_labels, - beta=self.beta, - ) + if self.uld_loss: + shifted_teacher_labels = inputs.get("teacher_labels", inputs["labels"])[:, teacher_prompt_lengths:] + loss = self.uld_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + student_labels=shifted_labels, + teacher_labels=shifted_teacher_labels, + beta=self.beta, + ) + else: + loss = self.generalized_jsd_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + labels=shifted_labels, + beta=self.beta, + ) # empty cache empty_cache() diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 81e826f971..cc401f00d0 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -242,6 +242,7 @@ class DataCollatorForChatML: """ tokenizer: PreTrainedTokenizerBase + teacher_tokenizer: Optional[PreTrainedTokenizerBase] = None ignore_index: int = -100 max_length: int = None prompt_key: str = "prompt" @@ -250,9 +251,15 @@ class DataCollatorForChatML: def __post_init__(self): if self.tokenizer.pad_token_id is None: raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.") + if self.teacher_tokenizer is not None and self.teacher_tokenizer.pad_token_id is None: + raise ValueError( + "The teacher tokenizer does not have a pad token. Please set `pad_token_id` in the teacher tokenizer." 
+ ) if self.max_length is None: # set a sensible default self.max_length = min(self.tokenizer.model_max_length, 1024) + if self.teacher_tokenizer is not None: + self.max_length = min(self.teacher_tokenizer.model_max_length, self.max_length) def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: input_ids = [] @@ -261,6 +268,13 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: prompt_attention_mask = [] labels = [] + if self.teacher_tokenizer is not None: + teacher_input_ids = [] + teacher_attention_mask = [] + teacher_labels = [] + teacher_prompts_input_ids = [] + teacher_prompt_attention_mask = [] + for example in examples: formatted_prompt = example.get(self.prompt_key, None) if formatted_prompt is None: @@ -306,6 +320,42 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: label[completion_start_idx:] = input_ids[-1][completion_start_idx:] labels.append(label) + if self.teacher_tokenizer is not None: + message = example[self.messages_key] + formatted_teacher_message = self.teacher_tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True + ) + tokenized_teacher_message = self.teacher_tokenizer( + formatted_teacher_message, + truncation=True, + max_length=self.max_length, + padding=False, + return_tensors=None, + add_special_tokens=False, + ) + teacher_input_ids.append(tokenized_teacher_message["input_ids"]) + teacher_attention_mask.append(tokenized_teacher_message["attention_mask"]) + + formatted_teacher_prompt = self.teacher_tokenizer.apply_chat_template( + example[self.messages_key][:-1], tokenize=False, add_generation_prompt=True + ) + + teacher_tokenized_prompt = self.teacher_tokenizer( + formatted_teacher_prompt, + truncation=True, + max_length=len(teacher_input_ids[-1]), + padding=False, + return_tensors=None, + add_special_tokens=False, + ) + teacher_prompts_input_ids.append(teacher_tokenized_prompt["input_ids"]) + teacher_prompt_attention_mask.append(teacher_tokenized_prompt["attention_mask"]) + + teacher_label = [self.ignore_index] * len(teacher_input_ids[-1]) + teacher_completion_start_idx = len(teacher_tokenized_prompt["input_ids"]) + teacher_label[teacher_completion_start_idx:] = teacher_input_ids[-1][teacher_completion_start_idx:] + teacher_labels.append(teacher_label) + # convert to list of tensors and pad input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids] attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask] @@ -319,7 +369,25 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0) - return { + if self.teacher_tokenizer is not None: + teacher_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in teacher_input_ids] + teacher_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in teacher_attention_mask] + teacher_labels = [torch.tensor(label, dtype=torch.long) for label in teacher_labels] + teacher_input_ids = pad( + teacher_input_ids, padding_side="left", padding_value=self.teacher_tokenizer.pad_token_id + ) + teacher_attention_mask = pad(teacher_attention_mask, padding_side="left", padding_value=0) + teacher_labels = pad(teacher_labels, padding_side="left", padding_value=self.ignore_index) + teacher_prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in 
teacher_prompts_input_ids] + teacher_prompt_attention_mask = [ + torch.tensor(mask, dtype=torch.long) for mask in teacher_prompt_attention_mask + ] + teacher_prompts_input_ids = pad( + teacher_prompts_input_ids, padding_side="left", padding_value=self.teacher_tokenizer.pad_token_id + ) + teacher_prompt_attention_mask = pad(teacher_prompt_attention_mask, padding_side="left", padding_value=0) + + batch = { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, @@ -327,6 +395,15 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: "prompt_attention_mask": prompt_attention_mask, } + if self.teacher_tokenizer is not None: + batch["teacher_input_ids"] = teacher_input_ids + batch["teacher_attention_mask"] = teacher_attention_mask + batch["teacher_labels"] = teacher_labels + batch["teacher_prompts"] = teacher_prompts_input_ids + batch["teacher_prompt_attention_mask"] = teacher_prompt_attention_mask + + return batch + @dataclass class RewardDataCollatorWithPadding: From 34520c38b69d411ba21f02df9318b13112cc19c1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Oct 2024 11:43:46 +0200 Subject: [PATCH 2/5] remove the beta from the loss --- trl/trainer/gkd_trainer.py | 39 +++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 72f54775ef..8c5a5dad0c 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -89,7 +89,7 @@ def __init__( data_collator = DataCollatorForChatML( tokenizer=processing_class, max_length=args.max_seq_length, teacher_tokenizer=teacher_processing_class ) - self.uld_loss = teacher_processing_class is not None + self.use_uld_loss = teacher_processing_class is not None super().__init__( model, @@ -222,42 +222,44 @@ def generalized_jsd_loss( return jsd @staticmethod - def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels=None, beta=0.5, reduction="sum"): + def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels=None, reduction="sum"): """ Compute the Universal Logit Distillation (ULD) loss. 
Args: - student_logits: Tensor of shape (batch_size, student_sequence_length, vocab_size) - teacher_logits: Tensor of shape (batch_size, teacher_sequence_length, vocab_size) + student_logits: Tensor of shape (batch_size, student_sequence_length, student_vocab_size) + teacher_logits: Tensor of shape (batch_size, teacher_sequence_length, teacher_vocab_size) student_labels: Tensor of shape (batch_size, student_sequence_length) with -100 for padding tokens teacher_labels: Tensor of shape (batch_size, teacher_sequence_length) with -100 for padding tokens - beta: Weight for the Wasserstein distance (default: 0.5) reduction: Specifies the reduction to apply to the output (default: 'sum') Returns: loss: Scalar tensor with the ULD loss """ - # mask out logits via the student and teacher labels - mask = (student_labels != -100).float() - student_logits = student_logits * mask - mask = (teacher_labels != -100).float() - teacher_logits = teacher_logits * mask - # Convert logits to probabilities student_probs = F.softmax(student_logits, dim=-1) teacher_probs = F.softmax(teacher_logits, dim=-1) + # mask out logits via the student and teacher labels + student_mask = student_labels != -100 + # student_logits = student_logits * student_mask + teacher_mask = teacher_labels != -100 + # teacher_logits = teacher_logits * teacher_mask + # Sort probabilities in descending order - student_probs_sorted, _ = torch.sort(student_probs, dim=-1, descending=True) - teacher_probs_sorted, _ = torch.sort(teacher_probs, dim=-1, descending=True) + student_probs_sorted, _ = torch.sort(student_probs[student_mask], dim=-1, descending=True) + teacher_probs_sorted, _ = torch.sort(teacher_probs[teacher_mask], dim=-1, descending=True) # pad the smaller tensor to the same size as the larger tensor by zeros - max_length = max(student_probs_sorted.size(1), teacher_probs_sorted.size(1)) - student_probs_sorted = F.pad(student_probs_sorted, (0, max_length - student_probs_sorted.size(1))) - teacher_probs_sorted = F.pad(teacher_probs_sorted, (0, max_length - teacher_probs_sorted.size(1))) + min_tokens = min(student_probs_sorted.size(0), teacher_probs_sorted.size(0)) + max_vocab_size = max(student_probs_sorted.size(1), teacher_probs_sorted.size(1)) + student_probs_sorted = F.pad(student_probs_sorted, (0, max_vocab_size - student_probs_sorted.size(1))) + teacher_probs_sorted = F.pad(teacher_probs_sorted, (0, max_vocab_size - teacher_probs_sorted.size(1))) # Compute weighted Wasserstein distance - wasserstein_distance = beta * torch.abs(student_probs_sorted - teacher_probs_sorted).sum(dim=-1) + wasserstein_distance = torch.abs(student_probs_sorted[:min_tokens] - teacher_probs_sorted[:min_tokens]).sum( + dim=-1 + ) # Apply reduction if reduction == "batchmean": @@ -294,14 +296,13 @@ def compute_loss(self, model, inputs, return_outputs=False): shifted_labels = inputs["labels"][:, prompt_lengths:] # compute loss - if self.uld_loss: + if self.use_uld_loss: shifted_teacher_labels = inputs.get("teacher_labels", inputs["labels"])[:, teacher_prompt_lengths:] loss = self.uld_loss( student_logits=shifted_student_logits, teacher_logits=shifted_teacher_logits, student_labels=shifted_labels, teacher_labels=shifted_teacher_labels, - beta=self.beta, ) else: loss = self.generalized_jsd_loss( From 7c5502c50777e40f5bfec758bbca4a262d594921 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Oct 2024 11:49:16 +0200 Subject: [PATCH 3/5] fix comments --- trl/trainer/gkd_trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git 
a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 8c5a5dad0c..4babd1c56a 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -240,23 +240,21 @@ def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels student_probs = F.softmax(student_logits, dim=-1) teacher_probs = F.softmax(teacher_logits, dim=-1) - # mask out logits via the student and teacher labels + # mask via the student and teacher labels student_mask = student_labels != -100 - # student_logits = student_logits * student_mask teacher_mask = teacher_labels != -100 - # teacher_logits = teacher_logits * teacher_mask - # Sort probabilities in descending order + # Sort probabilities in descending order of only the non-padding tokens student_probs_sorted, _ = torch.sort(student_probs[student_mask], dim=-1, descending=True) teacher_probs_sorted, _ = torch.sort(teacher_probs[teacher_mask], dim=-1, descending=True) - # pad the smaller tensor to the same size as the larger tensor by zeros - min_tokens = min(student_probs_sorted.size(0), teacher_probs_sorted.size(0)) + # pad the probabilities to the max vocab size by zero padding max_vocab_size = max(student_probs_sorted.size(1), teacher_probs_sorted.size(1)) student_probs_sorted = F.pad(student_probs_sorted, (0, max_vocab_size - student_probs_sorted.size(1))) teacher_probs_sorted = F.pad(teacher_probs_sorted, (0, max_vocab_size - teacher_probs_sorted.size(1))) # Compute weighted Wasserstein distance + min_tokens = min(student_probs_sorted.size(0), teacher_probs_sorted.size(0)) wasserstein_distance = torch.abs(student_probs_sorted[:min_tokens] - teacher_probs_sorted[:min_tokens]).sum( dim=-1 ) From 4eb781fb424b5d51b1a50f8cd026c040790469d3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 15 Oct 2024 14:36:32 +0200 Subject: [PATCH 4/5] align masks --- examples/scripts/gkd.py | 28 ++++++++++++++++++++++++++-- trl/trainer/gkd_trainer.py | 31 +++++++++++++++++++++---------- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 7c37d811c5..fabe5ebc7b 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -42,6 +42,21 @@ --use_peft \ --lora_r 64 \ --lora_alpha 16 + +# ULD +python examples/scripts/gkd.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --teacher_model_name_or_path google/gemma-2-2b-it \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-5 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --output_dir gkd-model \ + --logging_steps 10 \ + --num_train_epochs 1 \ + --push_to_hub \ + --gradient_checkpointing \ + --torch_dtype bfloat16 """ from accelerate import PartialState @@ -75,7 +90,7 @@ attn_implementation=model_config.attn_implementation, torch_dtype=model_config.torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, + device_map=get_kbit_device_map() if quantization_config is not None else "auto", quantization_config=quantization_config, ) training_args.model_init_kwargs = model_kwargs @@ -86,7 +101,7 @@ attn_implementation=model_config.attn_implementation, torch_dtype=model_config.torch_dtype, use_cache=True, - device_map=get_kbit_device_map() if quantization_config is not None else None, + device_map=get_kbit_device_map() if quantization_config is not None else "auto", quantization_config=quantization_config, ) training_args.teacher_model_init_kwargs = teacher_model_kwargs @@ 
-100,6 +115,14 @@ if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + teacher_tokenizer = AutoTokenizer.from_pretrained( + training_args.teacher_model_name_or_path, + revision=model_config.model_revision, + trust_remote_code=model_config.trust_remote_code, + ) + if teacher_tokenizer.pad_token is None: + teacher_tokenizer.pad_token = teacher_tokenizer.eos_token + ################ # Dataset ################ @@ -123,6 +146,7 @@ train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split], processing_class=tokenizer, + teacher_processing_class=teacher_tokenizer, peft_config=get_peft_config(model_config), ) completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 4babd1c56a..6cf67e1b0d 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -240,24 +240,35 @@ def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels student_probs = F.softmax(student_logits, dim=-1) teacher_probs = F.softmax(teacher_logits, dim=-1) - # mask via the student and teacher labels + # Create masks for non-padding tokens student_mask = student_labels != -100 teacher_mask = teacher_labels != -100 - # Sort probabilities in descending order of only the non-padding tokens - student_probs_sorted, _ = torch.sort(student_probs[student_mask], dim=-1, descending=True) - teacher_probs_sorted, _ = torch.sort(teacher_probs[teacher_mask], dim=-1, descending=True) + # Ensure the masks have the same shape as their corresponding probabilities + student_mask = student_mask.unsqueeze(-1).expand_as(student_probs) + teacher_mask = teacher_mask.unsqueeze(-1).expand_as(teacher_probs) - # pad the probabilities to the max vocab size by zero padding + # Apply masks + student_probs_masked = student_probs.masked_select(student_mask).view(-1, student_probs.size(-1)) + teacher_probs_masked = teacher_probs.masked_select(teacher_mask).view(-1, teacher_probs.size(-1)) + + # Ensure we have the same number of tokens for both student and teacher + min_tokens = min(student_probs_masked.size(0), teacher_probs_masked.size(0)) + student_probs_masked = student_probs_masked[:min_tokens] + teacher_probs_masked = teacher_probs_masked[:min_tokens] + + # Sort probabilities in descending order + student_probs_sorted, _ = torch.sort(student_probs_masked, dim=-1, descending=True) + teacher_probs_sorted, _ = torch.sort(teacher_probs_masked, dim=-1, descending=True) + + # Pad the probabilities to the max vocab size max_vocab_size = max(student_probs_sorted.size(1), teacher_probs_sorted.size(1)) student_probs_sorted = F.pad(student_probs_sorted, (0, max_vocab_size - student_probs_sorted.size(1))) teacher_probs_sorted = F.pad(teacher_probs_sorted, (0, max_vocab_size - teacher_probs_sorted.size(1))) - # Compute weighted Wasserstein distance - min_tokens = min(student_probs_sorted.size(0), teacher_probs_sorted.size(0)) - wasserstein_distance = torch.abs(student_probs_sorted[:min_tokens] - teacher_probs_sorted[:min_tokens]).sum( - dim=-1 - ) + # Compute Wasserstein distance + wasserstein_distance = torch.abs(student_probs_sorted - teacher_probs_sorted).sum(dim=-1) + # Apply reduction if reduction == "batchmean": From 3b22fd747382dbc0f5feeb45424b908f38d6116e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 22 Oct 2024 20:24:05 +0200 Subject: [PATCH 5/5] add doc --- docs/source/gkd_trainer.md | 4 ++++ trl/trainer/gkd_trainer.py | 1 - 2 files 
changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md
index c4f82ff160..78704ab1aa 100644
--- a/docs/source/gkd_trainer.md
+++ b/docs/source/gkd_trainer.md
@@ -89,6 +89,10 @@ The dataset should be formatted as a list of "messages" where each message is a
 
 * `content`: the message content
 
+## Universal Logit Distillation Loss
+
+When the student and teacher vocabularies differ, we can use the Universal Logit Distillation (ULD) loss proposed in [Towards Cross-Tokenizer Distillation](https://huggingface.co/papers/2402.12030) to distill the teacher model into the student model via the Wasserstein distance between the teacher and student logit distributions. In the offline setting, this loss can be used in the `GKDTrainer` by passing a `teacher_processing_class` argument and setting `lmbda=0.0` in the `GKDConfig`. For non-zero `lmbda`, the trainer performs (partial) on-policy distillation, which, as far as we are aware, is not covered in the ULD paper.
+
 ## GKDTrainer
 
 [[autodoc]] GKDTrainer
 
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 6cf67e1b0d..fbd4983170 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -269,7 +269,6 @@ def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels
 
         # Compute Wasserstein distance
         wasserstein_distance = torch.abs(student_probs_sorted - teacher_probs_sorted).sum(dim=-1)
-
         # Apply reduction
         if reduction == "batchmean":
             return wasserstein_distance.sum() / (wasserstein_distance.size(0) * wasserstein_distance.size(1))
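To make the final form of the loss easier to follow, the snippet below is a minimal, self-contained sketch of what `uld_loss` computes after the patches above: softmax, sort each distribution, zero-pad to a common vocabulary width, and take the per-token L1 distance between the sorted distributions. The function name and toy shapes are illustrative only; the trainer version additionally masks padding positions via the `-100` labels, as shown in the diff.

```python
import torch
import torch.nn.functional as F


def uld_loss_sketch(student_logits: torch.Tensor, teacher_logits: torch.Tensor, reduction: str = "sum") -> torch.Tensor:
    # student_logits: (num_tokens, student_vocab), teacher_logits: (num_tokens, teacher_vocab);
    # the vocabulary sizes may differ because the tokenizers differ.
    student_probs = F.softmax(student_logits, dim=-1)
    teacher_probs = F.softmax(teacher_logits, dim=-1)

    # Sorting each distribution removes the dependence on token identity, so the two
    # vocabularies become comparable position by position.
    student_sorted, _ = torch.sort(student_probs, dim=-1, descending=True)
    teacher_sorted, _ = torch.sort(teacher_probs, dim=-1, descending=True)

    # Zero-pad the smaller vocabulary so both sorted distributions have the same width.
    vocab = max(student_sorted.size(-1), teacher_sorted.size(-1))
    student_sorted = F.pad(student_sorted, (0, vocab - student_sorted.size(-1)))
    teacher_sorted = F.pad(teacher_sorted, (0, vocab - teacher_sorted.size(-1)))

    # Per-token term as in the patch: L1 distance between the sorted distributions.
    per_token = torch.abs(student_sorted - teacher_sorted).sum(dim=-1)
    if reduction == "sum":
        return per_token.sum()
    if reduction == "mean":
        return per_token.mean()
    return per_token


# Toy check with mismatched vocabulary sizes: 5 completion tokens, 32k vs 50k vocab.
student = torch.randn(5, 32_000)
teacher = torch.randn(5, 50_000)
print(uld_loss_sketch(student, teacher))  # scalar; 0 when the sorted distributions coincide
```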
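The collator side of the change can be exercised directly. The following is a sketch that assumes the patched `trl` is installed and that the two checkpoints (taken from the example script) are accessible; the key list in the final comment mirrors the `batch` dictionary assembled in `trl/trainer/utils.py` above.

```python
from transformers import AutoTokenizer
from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
for tok in (tokenizer, teacher_tokenizer):
    if tok.pad_token is None:  # mirrors the guard in examples/scripts/gkd.py
        tok.pad_token = tok.eos_token

collator = DataCollatorForChatML(tokenizer=tokenizer, teacher_tokenizer=teacher_tokenizer, max_length=512)

batch = collator(
    [
        {
            "messages": [
                {"role": "user", "content": "What is knowledge distillation?"},
                {"role": "assistant", "content": "Training a small model to imitate a larger one."},
            ]
        }
    ]
)

# Student-side tensors plus the teacher-side views added by this patch series.
print(sorted(batch.keys()))
# ['attention_mask', 'input_ids', 'labels', 'prompt_attention_mask', 'prompts',
#  'teacher_attention_mask', 'teacher_input_ids', 'teacher_labels',
#  'teacher_prompt_attention_mask', 'teacher_prompts']
```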
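Finally, a hedged end-to-end sketch of the API introduced here: `teacher_processing_class` plus `lmbda=0.0` for purely offline ULD distillation. The model names and dataset come from the ULD example added to `examples/scripts/gkd.py`; the remaining hyperparameters are illustrative, not a recommendation.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

student_name = "Qwen/Qwen2-0.5B-Instruct"
teacher_name = "google/gemma-2-2b-it"  # uses a different tokenizer than the student

tokenizer = AutoTokenizer.from_pretrained(student_name)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_name)
for tok in (tokenizer, teacher_tokenizer):
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

model = AutoModelForCausalLM.from_pretrained(student_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name)

dataset = load_dataset("trl-lib/chatbot_arena_completions", split="train")

# lmbda=0.0 keeps training purely offline, the setting studied in the ULD paper.
training_args = GKDConfig(
    output_dir="gkd-uld-model",
    lmbda=0.0,
    max_seq_length=1024,
    per_device_train_batch_size=4,
)

trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    teacher_processing_class=teacher_tokenizer,  # enables the ULD loss path
)
trainer.train()
```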