[GKD] add ULD type loss to GKD Trainer #2263

Open · wants to merge 6 commits into main

4 changes: 4 additions & 0 deletions docs/source/gkd_trainer.md
@@ -89,6 +89,10 @@ The dataset should be formatted as a list of "messages" where each message is a
* `content`: the message content


## Universal Logit Distillation Loss

When the student and teacher vocabularies differ, the Universal Logit Distillation (ULD) loss proposed in [Towards Cross-Tokenizer Distillation](https://huggingface.co/papers/2402.12030) can be used to distill the teacher into the student by minimizing the Wasserstein distance between the teacher and student token probability distributions. In the offline setting, this loss is enabled in the `GKDTrainer` by passing a `teacher_processing_class` argument and setting `lmbda=0.0` in the `GKDConfig`. For non-zero `lmbda`, the trainer performs (partially) on-policy distillation, which, as far as we are aware, is not covered in the ULD paper.
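
A minimal sketch of this setup, assuming the student/teacher pair from the ULD example in `examples/scripts/gkd.py` and a toy conversational dataset (names and hyperparameters here are illustrative):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

# The student and teacher use different tokenizers/vocabularies, so the ULD loss is required.
student_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it")

# Toy dataset in the expected "messages" format.
train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great, thanks!"},
            ]
        ]
        * 16
    }
)

# lmbda=0.0 keeps distillation fully offline, matching the setting studied in the ULD paper.
training_args = GKDConfig(output_dir="gkd-uld-model", lmbda=0.0, per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=student_tokenizer,
    teacher_processing_class=teacher_tokenizer,  # passing this enables the ULD loss
    train_dataset=train_dataset,
)
trainer.train()
```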

## GKDTrainer

[[autodoc]] GKDTrainer
28 changes: 26 additions & 2 deletions examples/scripts/gkd.py
@@ -42,6 +42,21 @@
--use_peft \
--lora_r 64 \
--lora_alpha 16

# ULD
python examples/scripts/gkd.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--teacher_model_name_or_path google/gemma-2-2b-it \
--dataset_name trl-lib/chatbot_arena_completions \
--learning_rate 2e-5 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--output_dir gkd-model \
--logging_steps 10 \
--num_train_epochs 1 \
--push_to_hub \
--gradient_checkpointing \
--torch_dtype bfloat16
"""

from accelerate import PartialState
@@ -75,7 +90,7 @@
attn_implementation=model_config.attn_implementation,
torch_dtype=model_config.torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
device_map=get_kbit_device_map() if quantization_config is not None else "auto",
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
@@ -86,7 +101,7 @@
attn_implementation=model_config.attn_implementation,
torch_dtype=model_config.torch_dtype,
use_cache=True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
device_map=get_kbit_device_map() if quantization_config is not None else "auto",
quantization_config=quantization_config,
)
training_args.teacher_model_init_kwargs = teacher_model_kwargs
@@ -100,6 +115,14 @@
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

teacher_tokenizer = AutoTokenizer.from_pretrained(
training_args.teacher_model_name_or_path,
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
)
if teacher_tokenizer.pad_token is None:
teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

################
# Dataset
################
@@ -123,6 +146,7 @@
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split],
processing_class=tokenizer,
teacher_processing_class=teacher_tokenizer,
peft_config=get_peft_config(model_config),
)
completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
96 changes: 86 additions & 10 deletions trl/trainer/gkd_trainer.py
@@ -73,6 +73,9 @@ def __init__(
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
teacher_processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
@@ -83,7 +86,10 @@
):
# add remove_unused_columns=False to the dataclass args
args.remove_unused_columns = False
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
data_collator = DataCollatorForChatML(
tokenizer=processing_class, max_length=args.max_seq_length, teacher_tokenizer=teacher_processing_class
)
self.use_uld_loss = teacher_processing_class is not None

super().__init__(
model,
@@ -215,6 +221,64 @@ def generalized_jsd_loss(
else:
return jsd

@staticmethod
def uld_loss(student_logits, teacher_logits, student_labels=None, teacher_labels=None, reduction="sum"):
"""
Compute the Universal Logit Distillation (ULD) loss.

Args:
student_logits: Tensor of shape (batch_size, student_sequence_length, student_vocab_size)
teacher_logits: Tensor of shape (batch_size, teacher_sequence_length, teacher_vocab_size)
student_labels: Tensor of shape (batch_size, student_sequence_length) with -100 for padding tokens
teacher_labels: Tensor of shape (batch_size, teacher_sequence_length) with -100 for padding tokens
reduction: Specifies the reduction to apply to the output (default: 'sum')

Returns:
loss: Scalar tensor with the ULD loss
"""
# Convert logits to probabilities
student_probs = F.softmax(student_logits, dim=-1)
teacher_probs = F.softmax(teacher_logits, dim=-1)

# Create masks for non-padding tokens
student_mask = student_labels != -100
teacher_mask = teacher_labels != -100

# Ensure the masks have the same shape as their corresponding probabilities
student_mask = student_mask.unsqueeze(-1).expand_as(student_probs)
teacher_mask = teacher_mask.unsqueeze(-1).expand_as(teacher_probs)

# Apply masks
student_probs_masked = student_probs.masked_select(student_mask).view(-1, student_probs.size(-1))
teacher_probs_masked = teacher_probs.masked_select(teacher_mask).view(-1, teacher_probs.size(-1))

# Ensure we have the same number of tokens for both student and teacher
min_tokens = min(student_probs_masked.size(0), teacher_probs_masked.size(0))
student_probs_masked = student_probs_masked[:min_tokens]
teacher_probs_masked = teacher_probs_masked[:min_tokens]

# Sort probabilities in descending order
student_probs_sorted, _ = torch.sort(student_probs_masked, dim=-1, descending=True)
teacher_probs_sorted, _ = torch.sort(teacher_probs_masked, dim=-1, descending=True)

# Pad the probabilities to the max vocab size
max_vocab_size = max(student_probs_sorted.size(1), teacher_probs_sorted.size(1))
student_probs_sorted = F.pad(student_probs_sorted, (0, max_vocab_size - student_probs_sorted.size(1)))
teacher_probs_sorted = F.pad(teacher_probs_sorted, (0, max_vocab_size - teacher_probs_sorted.size(1)))

# Compute Wasserstein distance
wasserstein_distance = torch.abs(student_probs_sorted - teacher_probs_sorted).sum(dim=-1)

# Apply reduction
if reduction == "batchmean":
# wasserstein_distance is 1-D (one value per kept token) after masking, so normalize by the token count
return wasserstein_distance.sum() / wasserstein_distance.size(0)
elif reduction == "sum":
return wasserstein_distance.sum()
elif reduction == "mean":
return wasserstein_distance.mean()
else:
return wasserstein_distance

def compute_loss(self, model, inputs, return_outputs=False):
# compute student output
outputs_student = model(
@@ -224,25 +288,37 @@ def compute_loss(self, model, inputs, return_outputs=False):

# compute teacher output in eval mode
self.teacher_model.eval()
teacher_input_ids = inputs.get("teacher_input_ids", inputs["input_ids"])
teacher_attention_mask = inputs.get("teacher_attention_mask", inputs["attention_mask"])
with torch.no_grad():
outputs_teacher = self.teacher_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
input_ids=teacher_input_ids,
attention_mask=teacher_attention_mask,
)

# slice the logits for the generated tokens using the inputs["prompts"] lengths
prompt_lengths = inputs["prompts"].shape[1]
teacher_prompt_lengths = inputs.get("teacher_prompts", inputs["prompts"]).shape[1]
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
shifted_teacher_logits = outputs_teacher.logits[:, teacher_prompt_lengths - 1 : -1, :]
shifted_labels = inputs["labels"][:, prompt_lengths:]

# compute loss
loss = self.generalized_jsd_loss(
student_logits=shifted_student_logits,
teacher_logits=shifted_teacher_logits,
labels=shifted_labels,
beta=self.beta,
)
if self.use_uld_loss:
shifted_teacher_labels = inputs.get("teacher_labels", inputs["labels"])[:, teacher_prompt_lengths:]
loss = self.uld_loss(
student_logits=shifted_student_logits,
teacher_logits=shifted_teacher_logits,
student_labels=shifted_labels,
teacher_labels=shifted_teacher_labels,
)
else:
loss = self.generalized_jsd_loss(
student_logits=shifted_student_logits,
teacher_logits=shifted_teacher_logits,
labels=shifted_labels,
beta=self.beta,
)

# empty cache
empty_cache()
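For reference, `uld_loss` is exposed as a static method, so it can be sanity-checked directly on toy tensors with mismatched sequence lengths and vocabulary sizes; a rough sketch with made-up shapes:

```python
import torch

from trl import GKDTrainer

# Made-up shapes: student and teacher disagree on both sequence length and vocab size.
student_logits = torch.randn(2, 5, 32)  # (batch, student_seq_len, student_vocab)
teacher_logits = torch.randn(2, 7, 48)  # (batch, teacher_seq_len, teacher_vocab)
student_labels = torch.zeros(2, 5, dtype=torch.long)  # no -100 entries, so every position is kept
teacher_labels = torch.zeros(2, 7, dtype=torch.long)

loss = GKDTrainer.uld_loss(
    student_logits,
    teacher_logits,
    student_labels=student_labels,
    teacher_labels=teacher_labels,
    reduction="sum",
)
print(loss)  # scalar: summed Wasserstein distance over the aligned (truncated-to-min) tokens
```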
79 changes: 78 additions & 1 deletion trl/trainer/utils.py
@@ -242,6 +242,7 @@ class DataCollatorForChatML:
"""

tokenizer: PreTrainedTokenizerBase
teacher_tokenizer: Optional[PreTrainedTokenizerBase] = None
ignore_index: int = -100
max_length: int = None
prompt_key: str = "prompt"
@@ -250,9 +251,15 @@ class DataCollatorForChatML:
def __post_init__(self):
if self.tokenizer.pad_token_id is None:
raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.")
if self.teacher_tokenizer is not None and self.teacher_tokenizer.pad_token_id is None:
raise ValueError(
"The teacher tokenizer does not have a pad token. Please set `pad_token_id` in the teacher tokenizer."
)
if self.max_length is None:
# set a sensible default
self.max_length = min(self.tokenizer.model_max_length, 1024)
if self.teacher_tokenizer is not None:
self.max_length = min(self.teacher_tokenizer.model_max_length, self.max_length)

def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
input_ids = []
@@ -261,6 +268,13 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
prompt_attention_mask = []
labels = []

if self.teacher_tokenizer is not None:
teacher_input_ids = []
teacher_attention_mask = []
teacher_labels = []
teacher_prompts_input_ids = []
teacher_prompt_attention_mask = []

for example in examples:
formatted_prompt = example.get(self.prompt_key, None)
if formatted_prompt is None:
@@ -306,6 +320,42 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
label[completion_start_idx:] = input_ids[-1][completion_start_idx:]
labels.append(label)

if self.teacher_tokenizer is not None:
message = example[self.messages_key]
formatted_teacher_message = self.teacher_tokenizer.apply_chat_template(
message, tokenize=False, add_generation_prompt=True
)
tokenized_teacher_message = self.teacher_tokenizer(
formatted_teacher_message,
truncation=True,
max_length=self.max_length,
padding=False,
return_tensors=None,
add_special_tokens=False,
)
teacher_input_ids.append(tokenized_teacher_message["input_ids"])
teacher_attention_mask.append(tokenized_teacher_message["attention_mask"])

formatted_teacher_prompt = self.teacher_tokenizer.apply_chat_template(
example[self.messages_key][:-1], tokenize=False, add_generation_prompt=True
)

teacher_tokenized_prompt = self.teacher_tokenizer(
formatted_teacher_prompt,
truncation=True,
max_length=len(teacher_input_ids[-1]),
padding=False,
return_tensors=None,
add_special_tokens=False,
)
teacher_prompts_input_ids.append(teacher_tokenized_prompt["input_ids"])
teacher_prompt_attention_mask.append(teacher_tokenized_prompt["attention_mask"])

teacher_label = [self.ignore_index] * len(teacher_input_ids[-1])
teacher_completion_start_idx = len(teacher_tokenized_prompt["input_ids"])
teacher_label[teacher_completion_start_idx:] = teacher_input_ids[-1][teacher_completion_start_idx:]
teacher_labels.append(teacher_label)

# convert to list of tensors and pad
input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
@@ -319,14 +369,41 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)

return {
if self.teacher_tokenizer is not None:
teacher_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in teacher_input_ids]
teacher_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in teacher_attention_mask]
teacher_labels = [torch.tensor(label, dtype=torch.long) for label in teacher_labels]
teacher_input_ids = pad(
teacher_input_ids, padding_side="left", padding_value=self.teacher_tokenizer.pad_token_id
)
teacher_attention_mask = pad(teacher_attention_mask, padding_side="left", padding_value=0)
teacher_labels = pad(teacher_labels, padding_side="left", padding_value=self.ignore_index)
teacher_prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in teacher_prompts_input_ids]
teacher_prompt_attention_mask = [
torch.tensor(mask, dtype=torch.long) for mask in teacher_prompt_attention_mask
]
teacher_prompts_input_ids = pad(
teacher_prompts_input_ids, padding_side="left", padding_value=self.teacher_tokenizer.pad_token_id
)
teacher_prompt_attention_mask = pad(teacher_prompt_attention_mask, padding_side="left", padding_value=0)

batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"prompts": prompts_input_ids,
"prompt_attention_mask": prompt_attention_mask,
}

if self.teacher_tokenizer is not None:
batch["teacher_input_ids"] = teacher_input_ids
batch["teacher_attention_mask"] = teacher_attention_mask
batch["teacher_labels"] = teacher_labels
batch["teacher_prompts"] = teacher_prompts_input_ids
batch["teacher_prompt_attention_mask"] = teacher_prompt_attention_mask

return batch


@dataclass
class RewardDataCollatorWithPadding: