diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md
index c4f82ff160..78704ab1aa 100644
--- a/docs/source/gkd_trainer.md
+++ b/docs/source/gkd_trainer.md
@@ -89,6 +89,10 @@ The dataset should be formatted as a list of "messages" where each message is a
 * `content`: the message content
 
+## Universal Logit Distillation Loss
+
+When the student and teacher vocabularies differ, we can use the Universal Logit Distillation (ULD) loss proposed in [Towards Cross-Tokenizer Distillation](https://huggingface.co/papers/2402.12030) to distill the teacher model into the student model; it minimizes the Wasserstein distance between the teacher and student logit distributions. In the offline setting, this loss can be used in the `GKDTrainer` by passing a `teacher_processing_class` argument and setting `lmbda=0.0` in the `GKDConfig`. For non-zero `lmbda`, the trainer performs (partial) on-policy distillation, which, as far as we are aware, is not covered in the ULD paper.
+
 ## GKDTrainer
 
 [[autodoc]] GKDTrainer
 
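As a quick illustration of the documented workflow, a minimal offline-ULD training script might look like the sketch below (not part of the patch; the model, dataset, and split names are illustrative and mirror the example script updated further down):

```python
# Minimal sketch: offline distillation across different tokenizers with the ULD loss.
from datasets import load_dataset
from transformers import AutoTokenizer

from trl import GKDConfig, GKDTrainer

student = "Qwen/Qwen2-0.5B-Instruct"  # student and teacher use different vocabularies
teacher = "google/gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(student)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher)

dataset = load_dataset("trl-lib/chatbot_arena_completions")

training_args = GKDConfig(
    output_dir="gkd-uld-model",
    lmbda=0.0,  # fully offline; ULD is used because a teacher tokenizer is provided
    per_device_train_batch_size=4,
)

trainer = GKDTrainer(
    model=student,
    teacher_model=teacher,
    args=training_args,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    teacher_processing_class=teacher_tokenizer,  # new argument introduced by this patch
)
trainer.train()
```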
diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py
index 7c37d811c5..fabe5ebc7b 100644
--- a/examples/scripts/gkd.py
+++ b/examples/scripts/gkd.py
@@ -42,6 +42,21 @@
     --use_peft \
     --lora_r 64 \
     --lora_alpha 16
+
+# ULD
+python examples/scripts/gkd.py \
+    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+    --teacher_model_name_or_path google/gemma-2-2b-it \
+    --dataset_name trl-lib/chatbot_arena_completions \
+    --learning_rate 2e-5 \
+    --per_device_train_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --output_dir gkd-model \
+    --logging_steps 10 \
+    --num_train_epochs 1 \
+    --push_to_hub \
+    --gradient_checkpointing \
+    --torch_dtype bfloat16
 """
 
 from accelerate import PartialState
@@ -75,7 +90,7 @@
         attn_implementation=model_config.attn_implementation,
         torch_dtype=model_config.torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        device_map=get_kbit_device_map() if quantization_config is not None else "auto",
         quantization_config=quantization_config,
     )
     training_args.model_init_kwargs = model_kwargs
@@ -86,7 +101,7 @@
         attn_implementation=model_config.attn_implementation,
         torch_dtype=model_config.torch_dtype,
         use_cache=True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        device_map=get_kbit_device_map() if quantization_config is not None else "auto",
         quantization_config=quantization_config,
     )
     training_args.teacher_model_init_kwargs = teacher_model_kwargs
@@ -100,6 +115,14 @@
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    teacher_tokenizer = AutoTokenizer.from_pretrained(
+        training_args.teacher_model_name_or_path,
+        revision=model_config.model_revision,
+        trust_remote_code=model_config.trust_remote_code,
+    )
+    if teacher_tokenizer.pad_token is None:
+        teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
+
     ################
     # Dataset
     ################
@@ -123,6 +146,7 @@
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split],
         processing_class=tokenizer,
+        teacher_processing_class=teacher_tokenizer,
         peft_config=get_peft_config(model_config),
     )
     completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 1b7c77557d..fbd4983170 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -73,6 +73,9 @@ def __init__(
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ] = None,
+        teacher_processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
         model_init: Optional[Callable[[], PreTrainedModel]] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
@@ -83,7 +86,10 @@ def __init__(
     ):
         # add remove_unused_columns=False to the the dataclass args
         args.remove_unused_columns = False
-        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
+        data_collator = DataCollatorForChatML(
+            tokenizer=processing_class, max_length=args.max_seq_length, teacher_tokenizer=teacher_processing_class
+        )
+        self.use_uld_loss = teacher_processing_class is not None
 
         super().__init__(
             model,
@@ -215,6 +221,64 @@ def generalized_jsd_loss(
         else:
             return jsd
 
+    @staticmethod
+    def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels=None, reduction="sum"):
+        """
+        Compute the Universal Logit Distillation (ULD) loss.
+
+        Args:
+            student_logits: Tensor of shape (batch_size, student_sequence_length, student_vocab_size)
+            teacher_logits: Tensor of shape (batch_size, teacher_sequence_length, teacher_vocab_size)
+            student_labels: Tensor of shape (batch_size, student_sequence_length) with -100 for padding tokens
+            teacher_labels: Tensor of shape (batch_size, teacher_sequence_length) with -100 for padding tokens
+            reduction: Specifies the reduction to apply to the output (default: 'sum')
+
+        Returns:
+            loss: Scalar tensor with the ULD loss
+        """
+        # Convert logits to probabilities
+        student_probs = F.softmax(student_logits, dim=-1)
+        teacher_probs = F.softmax(teacher_logits, dim=-1)
+
+        # Create masks for non-padding tokens
+        student_mask = student_labels != -100
+        teacher_mask = teacher_labels != -100
+
+        # Ensure the masks have the same shape as their corresponding probabilities
+        student_mask = student_mask.unsqueeze(-1).expand_as(student_probs)
+        teacher_mask = teacher_mask.unsqueeze(-1).expand_as(teacher_probs)
+
+        # Apply masks
+        student_probs_masked = student_probs.masked_select(student_mask).view(-1, student_probs.size(-1))
+        teacher_probs_masked = teacher_probs.masked_select(teacher_mask).view(-1, teacher_probs.size(-1))
+
+        # Ensure we have the same number of tokens for both student and teacher
+        min_tokens = min(student_probs_masked.size(0), teacher_probs_masked.size(0))
+        student_probs_masked = student_probs_masked[:min_tokens]
+        teacher_probs_masked = teacher_probs_masked[:min_tokens]
+
+        # Sort probabilities in descending order
+        student_probs_sorted, _ = torch.sort(student_probs_masked, dim=-1, descending=True)
+        teacher_probs_sorted, _ = torch.sort(teacher_probs_masked, dim=-1, descending=True)
+
+        # Pad the probabilities to the max vocab size
+        max_vocab_size = max(student_probs_sorted.size(1), teacher_probs_sorted.size(1))
+        student_probs_sorted = F.pad(student_probs_sorted, (0, max_vocab_size - student_probs_sorted.size(1)))
+        teacher_probs_sorted = F.pad(teacher_probs_sorted, (0, max_vocab_size - teacher_probs_sorted.size(1)))
+
+        # Compute the Wasserstein distance per selected token (1-D tensor of length min_tokens)
+        wasserstein_distance = torch.abs(student_probs_sorted - teacher_probs_sorted).sum(dim=-1)
+
+        # Apply reduction
+        if reduction == "batchmean":
+            return wasserstein_distance.sum() / wasserstein_distance.size(0)
+        elif reduction == "sum":
+            return wasserstein_distance.sum()
+        elif reduction == "mean":
+            return wasserstein_distance.mean()
+        else:
+            return wasserstein_distance
+
     def compute_loss(self, model, inputs, return_outputs=False):
         # compute student output
         outputs_student = model(
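For intuition, the new `uld_loss` can be exercised standalone on random tensors with mismatched sequence lengths and vocabulary sizes, which is exactly the situation it is designed for. A small sketch (not part of the patch; shapes are arbitrary):

```python
import torch

from trl import GKDTrainer

batch_size, student_len, teacher_len = 2, 5, 7
student_vocab, teacher_vocab = 32_000, 256_000  # different tokenizers -> different vocab sizes

student_logits = torch.randn(batch_size, student_len, student_vocab)
teacher_logits = torch.randn(batch_size, teacher_len, teacher_vocab)

# -100 marks positions that should not contribute to the loss (prompt/padding)
student_labels = torch.randint(0, student_vocab, (batch_size, student_len))
teacher_labels = torch.randint(0, teacher_vocab, (batch_size, teacher_len))
student_labels[:, 0] = -100
teacher_labels[:, :2] = -100

# uld_loss is a staticmethod, so no trainer instance is needed
loss = GKDTrainer.uld_loss(
    student_logits=student_logits,
    teacher_logits=teacher_logits,
    student_labels=student_labels,
    teacher_labels=teacher_labels,
)
print(loss)  # scalar: sum over the shared tokens of the L1 distance between sorted probabilities
```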
@@ -224,25 +288,37 @@ def compute_loss(self, model, inputs, return_outputs=False):
 
         # compute teacher output in eval mode
         self.teacher_model.eval()
+        teacher_input_ids = inputs.get("teacher_input_ids", inputs["input_ids"])
+        teacher_attention_mask = inputs.get("teacher_attention_mask", inputs["attention_mask"])
         with torch.no_grad():
             outputs_teacher = self.teacher_model(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
+                input_ids=teacher_input_ids,
+                attention_mask=teacher_attention_mask,
             )
 
         # slice the logits for the generated tokens using the inputs["prompts"] lengths
         prompt_lengths = inputs["prompts"].shape[1]
+        teacher_prompt_lengths = inputs.get("teacher_prompts", inputs["prompts"]).shape[1]
         shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
-        shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
+        shifted_teacher_logits = outputs_teacher.logits[:, teacher_prompt_lengths - 1 : -1, :]
         shifted_labels = inputs["labels"][:, prompt_lengths:]
 
         # compute loss
-        loss = self.generalized_jsd_loss(
-            student_logits=shifted_student_logits,
-            teacher_logits=shifted_teacher_logits,
-            labels=shifted_labels,
-            beta=self.beta,
-        )
+        if self.use_uld_loss:
+            shifted_teacher_labels = inputs.get("teacher_labels", inputs["labels"])[:, teacher_prompt_lengths:]
+            loss = self.uld_loss(
+                student_logits=shifted_student_logits,
+                teacher_logits=shifted_teacher_logits,
+                student_labels=shifted_labels,
+                teacher_labels=shifted_teacher_labels,
+            )
+        else:
+            loss = self.generalized_jsd_loss(
+                student_logits=shifted_student_logits,
+                teacher_logits=shifted_teacher_logits,
+                labels=shifted_labels,
+                beta=self.beta,
+            )
 
         # empty cache
         empty_cache()
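The slicing above keeps only the positions whose logits predict completion tokens, using each model's own prompt length, so the student and teacher completion lengths generally differ. A toy sketch of the index arithmetic (not part of the patch; shapes are made up):

```python
import torch

batch_size = 2
prompt_len, seq_len, student_vocab = 3, 9, 11  # student-side lengths
teacher_prompt_len, teacher_seq_len, teacher_vocab = 5, 12, 13  # teacher tokenizer yields different lengths

student_logits = torch.randn(batch_size, seq_len, student_vocab)
teacher_logits = torch.randn(batch_size, teacher_seq_len, teacher_vocab)

# Position i of the logits predicts token i + 1, so the completion tokens
# (positions prompt_len .. seq_len - 1) are predicted by logits prompt_len - 1 .. seq_len - 2.
shifted_student_logits = student_logits[:, prompt_len - 1 : -1, :]
shifted_teacher_logits = teacher_logits[:, teacher_prompt_len - 1 : -1, :]

print(shifted_student_logits.shape)  # torch.Size([2, 6, 11]) -> 6 student completion positions
print(shifted_teacher_logits.shape)  # torch.Size([2, 7, 13]) -> 7 teacher completion positions
# The completion lengths generally differ; uld_loss flattens the unmasked positions
# and truncates both sides to the shorter count before comparing the distributions.
```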
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 81e826f971..cc401f00d0 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -242,6 +242,7 @@ class DataCollatorForChatML:
     """
 
     tokenizer: PreTrainedTokenizerBase
+    teacher_tokenizer: Optional[PreTrainedTokenizerBase] = None
     ignore_index: int = -100
     max_length: int = None
     prompt_key: str = "prompt"
@@ -250,9 +251,15 @@ class DataCollatorForChatML:
     def __post_init__(self):
         if self.tokenizer.pad_token_id is None:
             raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.")
+        if self.teacher_tokenizer is not None and self.teacher_tokenizer.pad_token_id is None:
+            raise ValueError(
+                "The teacher tokenizer does not have a pad token. Please set `pad_token_id` in the teacher tokenizer."
+            )
         if self.max_length is None:
             # set a sensible default
             self.max_length = min(self.tokenizer.model_max_length, 1024)
+        if self.teacher_tokenizer is not None:
+            self.max_length = min(self.teacher_tokenizer.model_max_length, self.max_length)
 
     def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         input_ids = []
@@ -261,6 +268,13 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         prompt_attention_mask = []
         labels = []
 
+        if self.teacher_tokenizer is not None:
+            teacher_input_ids = []
+            teacher_attention_mask = []
+            teacher_labels = []
+            teacher_prompts_input_ids = []
+            teacher_prompt_attention_mask = []
+
         for example in examples:
             formatted_prompt = example.get(self.prompt_key, None)
             if formatted_prompt is None:
@@ -306,6 +320,42 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
             label[completion_start_idx:] = input_ids[-1][completion_start_idx:]
             labels.append(label)
 
+            if self.teacher_tokenizer is not None:
+                message = example[self.messages_key]
+                formatted_teacher_message = self.teacher_tokenizer.apply_chat_template(
+                    message, tokenize=False, add_generation_prompt=True
+                )
+                tokenized_teacher_message = self.teacher_tokenizer(
+                    formatted_teacher_message,
+                    truncation=True,
+                    max_length=self.max_length,
+                    padding=False,
+                    return_tensors=None,
+                    add_special_tokens=False,
+                )
+                teacher_input_ids.append(tokenized_teacher_message["input_ids"])
+                teacher_attention_mask.append(tokenized_teacher_message["attention_mask"])
+
+                formatted_teacher_prompt = self.teacher_tokenizer.apply_chat_template(
+                    example[self.messages_key][:-1], tokenize=False, add_generation_prompt=True
+                )
+
+                teacher_tokenized_prompt = self.teacher_tokenizer(
+                    formatted_teacher_prompt,
+                    truncation=True,
+                    max_length=len(teacher_input_ids[-1]),
+                    padding=False,
+                    return_tensors=None,
+                    add_special_tokens=False,
+                )
+                teacher_prompts_input_ids.append(teacher_tokenized_prompt["input_ids"])
+                teacher_prompt_attention_mask.append(teacher_tokenized_prompt["attention_mask"])
+
+                teacher_label = [self.ignore_index] * len(teacher_input_ids[-1])
+                teacher_completion_start_idx = len(teacher_tokenized_prompt["input_ids"])
+                teacher_label[teacher_completion_start_idx:] = teacher_input_ids[-1][teacher_completion_start_idx:]
+                teacher_labels.append(teacher_label)
+
         # convert to list of tensors and pad
         input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
         attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
@@ -319,7 +369,25 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
         prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)
 
-        return {
+        if self.teacher_tokenizer is not None:
+            teacher_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in teacher_input_ids]
+            teacher_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in teacher_attention_mask]
+            teacher_labels = [torch.tensor(label, dtype=torch.long) for label in teacher_labels]
+            teacher_input_ids = pad(
+                teacher_input_ids, padding_side="left", padding_value=self.teacher_tokenizer.pad_token_id
+            )
+            teacher_attention_mask = pad(teacher_attention_mask, padding_side="left", padding_value=0)
+            teacher_labels = pad(teacher_labels, padding_side="left", padding_value=self.ignore_index)
+            teacher_prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in teacher_prompts_input_ids]
+            teacher_prompt_attention_mask = [
+                torch.tensor(mask, dtype=torch.long) for mask in teacher_prompt_attention_mask
+            ]
+            teacher_prompts_input_ids = pad(
+                teacher_prompts_input_ids, padding_side="left", padding_value=self.teacher_tokenizer.pad_token_id
+            )
+            teacher_prompt_attention_mask = pad(teacher_prompt_attention_mask, padding_side="left", padding_value=0)
+
+        batch = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "labels": labels,
@@ -327,6 +395,15 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
             "prompt_attention_mask": prompt_attention_mask,
         }
 
+        if self.teacher_tokenizer is not None:
+            batch["teacher_input_ids"] = teacher_input_ids
+            batch["teacher_attention_mask"] = teacher_attention_mask
+            batch["teacher_labels"] = teacher_labels
+            batch["teacher_prompts"] = teacher_prompts_input_ids
+            batch["teacher_prompt_attention_mask"] = teacher_prompt_attention_mask
+
+        return batch
+
 
 @dataclass
 class RewardDataCollatorWithPadding:
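Finally, a usage sketch of the extended collator (not part of the patch; the tokenizer choices are illustrative, and `google/gemma-2-2b-it` may require Hub authentication). The `teacher_*` keys only appear when a `teacher_tokenizer` is supplied:

```python
from transformers import AutoTokenizer

from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

collator = DataCollatorForChatML(tokenizer=tokenizer, teacher_tokenizer=teacher_tokenizer, max_length=512)

examples = [
    {
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ]
    }
]
batch = collator(examples)
print(sorted(batch.keys()))
# ['attention_mask', 'input_ids', 'labels', 'prompt_attention_mask', 'prompts',
#  'teacher_attention_mask', 'teacher_input_ids', 'teacher_labels',
#  'teacher_prompt_attention_mask', 'teacher_prompts']
```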