From 824956336ae47182206311c3df0a39629e8c62b8 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Mon, 23 Sep 2024 13:02:48 +0000
Subject: [PATCH 1/4] Online DPO optimization

---
 open_instruct/online_dpo_vllm_thread.py | 61 +++++++++++++++----------
 1 file changed, 37 insertions(+), 24 deletions(-)

diff --git a/open_instruct/online_dpo_vllm_thread.py b/open_instruct/online_dpo_vllm_thread.py
index 162c704ea..9c8542e88 100644
--- a/open_instruct/online_dpo_vllm_thread.py
+++ b/open_instruct/online_dpo_vllm_thread.py
@@ -387,6 +387,9 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
     else:
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # NOTE: we do not resize the embedding
     tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
+    reward_model_tokenizer = AutoTokenizer.from_pretrained(
+        args.reward_model_path, revision=args.reward_model_revision, padding_side="right"
+    )
 
     # create the dataset
     dataset_dict = DatasetDict()
@@ -451,12 +454,11 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
         attn_implementation="flash_attention_2",
         use_cache=False,
     )
+    different_rm_vocab = reward_model_tokenizer.vocab_size != tokenizer.vocab_size
     if policy.config.vocab_size != reward_model.config.vocab_size:
-        raise ValueError(
-            "Policy and reward model must have the same vocab size. "
+        print(
+            "Policy and reward model have different vocab sizes. "
             f"Policy: {policy.config.vocab_size}, Reward: {reward_model.config.vocab_size}. "
-            "If they don't have the same vocab size, the policy could generate tokens which "
-            "is going to cause index out of bound error in the reward model."
         )
     model = policy
     if model_config.gradient_checkpointing:
@@ -475,7 +477,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
         num_warmup_steps=args.warm_up_steps,
         num_training_steps=args.num_training_steps * args.num_train_epochs,
     )
-    data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id)
+    data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id, keep_messages=different_rm_vocab)
     dataloader = DataLoader(
         train_dataset,
         batch_size=args.local_dataloader_batch_size,
@@ -606,6 +608,8 @@ def repeat_generator():
     queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
     queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
     send_queries(accelerator, None, tokenizer, param_prompt_Q, queries_next)
+    if different_rm_vocab:
+        data["messages"]
 
     for _ in range(1, resume_training_step):  # we didn't store scheduler state
         scheduler.step()
@@ -614,6 +618,8 @@ def repeat_generator():
         episode += args.batch_size
         scheduler.step()
         queries = queries_next
+        if different_rm_vocab:
+            pass
 
         if ph.preemptied:
             break
@@ -643,6 +649,8 @@ def repeat_generator():
                     data = next(iter_dataloader)
                     queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
                     queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
+                    if different_rm_vocab:
+                        data["messages"]
                 send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next)
             else:
                 if training_step != 1:
@@ -653,13 +661,14 @@ def repeat_generator():
                     queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
                     send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next)
                 queries = queries_next
+                if different_rm_vocab:
+                    data["messages"]
 
         training_time_start = time.time()
         with torch.no_grad():
             context_length = queries.shape[1]
             responses = []
             postprocessed_responses = []
-            logprobs = []
             ref_logprobs = []
             scores = []
             sequence_lengths = []
@@ -688,13 +697,6 @@ def repeat_generator():
                 query = queries[i : i + args.local_rollout_forward_batch_size]
                 query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                 response = query_response[:, context_length:]
-                output = forward(generation_model, query_response, tokenizer.pad_token_id)
-                logits = output.logits[:, context_length - 1 : -1]
-                logits /= args.temperature + 1e-7
-                all_logprob = F.log_softmax(logits, dim=-1)
-                logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
-                del output, logits, all_logprob
-                torch.cuda.empty_cache()
 
                 ref_output = forward(ref_model, query_response, tokenizer.pad_token_id)
                 ref_logits = ref_output.logits[:, context_length - 1 : -1]
@@ -720,19 +722,17 @@ def repeat_generator():
                 responses.append(response)
                 postprocessed_responses.append(postprocessed_response)
-                logprobs.append(logprob)
                 ref_logprobs.append(ref_logprob)
                 sequence_lengths.append(sequence_length)
                 scores.append(score)
 
             responses = torch.cat(responses, 0)
             postprocessed_responses = torch.cat(postprocessed_responses, 0)
-            logprobs = torch.cat(logprobs, 0)
             ref_logprobs = torch.cat(ref_logprobs, 0)
             sequence_lengths = torch.cat(sequence_lengths, 0)
             scores = torch.cat(scores, 0)
             global_scores = accelerator.gather(scores)
             accelerator.print(f"global_scores: {global_scores}, {global_scores.mean()}")
-            del (logprob, ref_logprob, score)
+            del (ref_logprob, score)
             gc.collect()
             torch.cuda.empty_cache()
 
@@ -751,16 +751,8 @@ def repeat_generator():
             # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
             response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
             padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
-            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
             ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
-
-            # 4. compute rewards
-            kl = logprobs - ref_logprobs
-            print(f"{accelerator.local_process_index=}, {kl.sum(1)=}")
-            non_score_reward = -args.beta * kl
-            non_score_reward_sum = non_score_reward.sum(1)
-            rlhf_reward = scores + non_score_reward_sum
 
             # num_examples should be same as args.local_batch_size divided by 2
             num_examples = scores.size(0) // 2
             first_half = scores[:num_examples]
@@ -775,6 +767,8 @@ def repeat_generator():
             )
             scores_margin = scores[chosen_indices] - scores[rejected_indices]
+            logprobs = []
+            concat_indices = []
 
         # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
         for epoch_idx in range(args.num_epochs):
             b_inds = np.random.permutation(args.local_batch_size // args.num_generation_per_prompt)
@@ -801,6 +795,7 @@ def repeat_generator():
                         rejected_responses = responses[rejected_mb_inds]
 
                         concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0)
+                        concat_indices.append(concat_mb_inds)
                         concat_query_responses = query_responses[concat_mb_inds]
                         concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id)
                         num_examples = chosen_mb_inds.shape[0]
@@ -852,6 +847,15 @@ def repeat_generator():
                         optimizer.step()
                         optimizer.zero_grad()
                         with torch.no_grad():
+                            if epoch_idx == 0:
+                                response = concat_query_responses[:, context_length:]
+                                logits = concat_output.logits[:, context_length - 1 : -1]
+                                logits /= args.temperature + 1e-7
+                                all_logprob = F.log_softmax(logits, dim=-1)
+                                logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
+                                logprob = torch.masked_fill(logprob, padding_mask[concat_mb_inds], INVALID_LOGPROB)
+                                logprobs.append(logprob)
+                                del all_logprob
                             chosen_rewards = args.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
                             rejected_rewards = args.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
                             loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss
@@ -879,6 +883,15 @@ def repeat_generator():
         # del everything and empty cache
         torch.cuda.empty_cache()
         with torch.no_grad():
+            logprobs = torch.cat(logprobs, 0)
+            concat_indices = torch.cat(concat_indices, 0)
+            restore_logprobs = torch.zeros_like(logprobs)
+            restore_logprobs[concat_indices] = logprobs
+            kl = restore_logprobs - ref_logprobs
+            print(f"{accelerator.local_process_index=}, {kl.sum(1)=}")
+            non_score_reward = -args.beta * kl
+            non_score_reward_sum = non_score_reward.sum(1)
+            rlhf_reward = scores + non_score_reward_sum
             local_metrics[0] = sequence_lengths.float().mean()
             local_metrics[1] = (responses == args.stop_token_id).sum().float().mean()
             local_metrics[2] = kl.sum(1).mean()
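In this patch, the first epoch's logprobs are collected per minibatch in shuffled chosen/rejected order and only later scattered back into rollout order through `concat_indices` before the KL term is formed. A small toy sketch of that scatter-by-index step, with illustrative shapes and values that are not taken from the script:

```python
import torch

# Pretend rollout batch: 6 responses, 4 logprob positions each.
ref_logprobs = torch.randn(6, 4)

# Training visits the rollout in shuffled minibatches of chosen + rejected indices.
minibatch_indices = [torch.tensor([4, 1, 0]), torch.tensor([5, 2, 3])]
minibatch_logprobs = [ref_logprobs[idx] + 0.1 for idx in minibatch_indices]  # stand-in values

# Concatenate in visit order, then scatter back so rows line up with the rollout again.
logprobs = torch.cat(minibatch_logprobs, 0)
concat_indices = torch.cat(minibatch_indices, 0)
restore_logprobs = torch.zeros_like(logprobs)
restore_logprobs[concat_indices] = logprobs

# Rows are now aligned with ref_logprobs, so the per-token KL estimate is valid.
kl = restore_logprobs - ref_logprobs
print(kl.sum(1))  # ~0.4 for every row in this toy setup
```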
From 361e95f17437e56fc6bf2793d5e977f534db94d7 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Mon, 23 Sep 2024 13:10:45 +0000
Subject: [PATCH 2/4] Add a note in the docs

---
 docs/algorithms/online_dpo.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/algorithms/online_dpo.md b/docs/algorithms/online_dpo.md
index 40e16e89b..354fe9fe9 100644
--- a/docs/algorithms/online_dpo.md
+++ b/docs/algorithms/online_dpo.md
@@ -348,6 +348,8 @@ These are relevant implementation details on reward modeling:
 1. Truncate responses at the stop token: we truncate the responses at the `--stop_token eos` to ensure the generation is stopped at the stop token.
 1. Non-stop penalty: we use a non-stop penalty to the reward model to penalize the model for not stopping at the stop token. For example, if the model does not end at the stop token, we penalize the model by `-10.0` (see `--penalty_reward_value -10.0`).
 1. Async training and generation: we follow the architecture in https://arxiv.org/abs/2310.00036 to do rollout and training asynchronously. This is to ensure that the training is not bottlenecked by the generation.
+1. Re-use of training logprobs: we also optimize the online DPO runtime by re-using the logprobs computed during the training forward pass, which saves an additional forward pass; note that this does impact the KL calculation and causes some numerical issues. See https://github.com/allenai/open-instruct/pull/364 for more details.
+
 
 ```python
 import queue
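The note above refers to taking the response logprobs from logits that the training forward pass already produced, instead of running a separate forward pass right after generation. A minimal sketch of that computation, using the same one-position shift and temperature scaling as the script; the function name and signature here are illustrative:

```python
import torch
import torch.nn.functional as F

def response_logprobs(logits: torch.Tensor, query_response: torch.Tensor,
                      context_length: int, temperature: float) -> torch.Tensor:
    """Per-token logprobs of the generated response, taken from logits that a
    forward pass has already produced for the DPO loss (no extra pass needed)."""
    response = query_response[:, context_length:]
    # Logits at position t predict token t + 1, hence the shift by one.
    shifted = logits[:, context_length - 1 : -1] / (temperature + 1e-7)
    all_logprobs = F.log_softmax(shifted, dim=-1)
    return torch.gather(all_logprobs, 2, response.unsqueeze(-1)).squeeze(-1)
```

Because these logprobs come from the policy as it is being updated rather than from a dedicated pass right after generation, the resulting KL estimate is not exactly what the removed rollout-time pass would have produced, which is the numerical caveat mentioned in the note.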
From 0c0171a746de62fd791d41f3db6d48c34adfae73 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Wed, 25 Sep 2024 17:11:54 +0000
Subject: [PATCH 3/4] online DPO with RM of different vocab

---
 open_instruct/online_dpo_vllm_thread.py | 51 ++++++++++++++++++++-----
 1 file changed, 42 insertions(+), 9 deletions(-)

diff --git a/open_instruct/online_dpo_vllm_thread.py b/open_instruct/online_dpo_vllm_thread.py
index 9c8542e88..3f25a90f2 100644
--- a/open_instruct/online_dpo_vllm_thread.py
+++ b/open_instruct/online_dpo_vllm_thread.py
@@ -387,9 +387,20 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
     else:
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # NOTE: we do not resize the embedding
     tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
+    reward_model_config = AutoConfig.from_pretrained(
+        args.reward_model_path, revision=args.reward_model_revision, num_labels=1
+    )
     reward_model_tokenizer = AutoTokenizer.from_pretrained(
         args.reward_model_path, revision=args.reward_model_revision, padding_side="right"
     )
+    if reward_model_config.architectures == "LlamaForCausalLM" and reward_model_config.bos_token_id == 128000:
+        reward_model_tokenizer.pad_token_id = 128002
+    else:
+        reward_model_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    # if the reward model does not have a chat template, use the same one as the policy model
+    if reward_model_tokenizer.chat_template is None:
+        reward_model_tokenizer.chat_template = tokenizer.chat_template
 
     # create the dataset
     dataset_dict = DatasetDict()
@@ -609,7 +620,7 @@ def repeat_generator():
     queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
     send_queries(accelerator, None, tokenizer, param_prompt_Q, queries_next)
     if different_rm_vocab:
-        data["messages"]
+        next_messages = data["messages"]
 
     for _ in range(1, resume_training_step):  # we didn't store scheduler state
         scheduler.step()
@@ -619,7 +630,8 @@ def repeat_generator():
         scheduler.step()
         queries = queries_next
         if different_rm_vocab:
-            pass
+            messages = next_messages
+            messages = messages + messages
 
         if ph.preemptied:
             break
@@ -650,7 +662,7 @@ def repeat_generator():
                     queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
                     queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
                     if different_rm_vocab:
-                        data["messages"]
+                        next_messages = data["messages"]
                 send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next)
             else:
                 if training_step != 1:
@@ -662,7 +674,8 @@ def repeat_generator():
                     send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next)
                 queries = queries_next
                 if different_rm_vocab:
-                    data["messages"]
+                    messages = data["messages"]
+                    messages = messages + messages
 
         training_time_start = time.time()
         with torch.no_grad():
@@ -716,9 +729,30 @@ def repeat_generator():
                 # Response Processing 2. run reward model on the truncated responses
                 postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                 sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1
-                _, score, _ = get_reward(
-                    reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length
-                )
+
+                if different_rm_vocab:
+                    # we need to decode the response text and re-tokenize it with the reward model's chat template,
+                    # otherwise we get `<|endoftext|>[PAD][PAD][PAD][PAD]<|endoftext|>`
+                    response_txts = tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)
+                    reward_model_tokens = []
+                    for j in range(i, i + args.local_rollout_forward_batch_size):
+                        messages[j][-1]["content"] = response_txts[j - i]
+                        reward_model_tokens.append(reward_model_tokenizer.apply_chat_template(messages[j]))
+
+                    # right pad the reward model tokens
+                    max_reward_model_len = max(len(item) for item in reward_model_tokens)
+                    reward_model_tokens = [
+                        item + [reward_model_tokenizer.pad_token_id] * (max_reward_model_len - len(item))
+                        for item in reward_model_tokens
+                    ]
+                    reward_model_tokens = torch.tensor(reward_model_tokens, device=device)
+                    _, score, _ = get_reward(
+                        reward_model, reward_model_tokens, reward_model_tokenizer.pad_token_id, 0
+                    )
+                else:
+                    _, score, _ = get_reward(
+                        reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length
+                    )
 
                 responses.append(response)
                 postprocessed_responses.append(postprocessed_response)
@@ -795,7 +829,6 @@ def repeat_generator():
                         rejected_responses = responses[rejected_mb_inds]
 
                         concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0)
-                        concat_indices.append(concat_mb_inds)
                         concat_query_responses = query_responses[concat_mb_inds]
                         concat_output = forward(model, concat_query_responses, tokenizer.pad_token_id)
                         num_examples = chosen_mb_inds.shape[0]
@@ -848,6 +881,7 @@ def repeat_generator():
                         optimizer.zero_grad()
                         with torch.no_grad():
                             if epoch_idx == 0:
+                                concat_indices.append(concat_mb_inds)
                                 response = concat_query_responses[:, context_length:]
                                 logits = concat_output.logits[:, context_length - 1 : -1]
                                 logits /= args.temperature + 1e-7
@@ -888,7 +922,6 @@ def repeat_generator():
             restore_logprobs = torch.zeros_like(logprobs)
             restore_logprobs[concat_indices] = logprobs
             kl = restore_logprobs - ref_logprobs
-            print(f"{accelerator.local_process_index=}, {kl.sum(1)=}")
             non_score_reward = -args.beta * kl
             non_score_reward_sum = non_score_reward.sum(1)
             rlhf_reward = scores + non_score_reward_sum
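When the reward model's vocabulary differs from the policy's, the patch above decodes the generated text, writes it back into the stored chat messages, re-applies the reward model's chat template, and right-pads the result before scoring. A condensed, self-contained sketch of that path; the checkpoint names and the helper function are placeholders rather than the script's actual arguments:

```python
import torch
from transformers import AutoTokenizer

policy_tok = AutoTokenizer.from_pretrained("policy-model")  # placeholder checkpoint names
rm_tok = AutoTokenizer.from_pretrained("reward-model")

def build_reward_model_inputs(postprocessed_response: torch.Tensor, messages: list, device: str):
    """Decode policy-token responses and rebuild inputs in the reward model's vocabulary."""
    texts = policy_tok.batch_decode(postprocessed_response, skip_special_tokens=True)
    token_lists = []
    for msg, text in zip(messages, texts):
        msg[-1]["content"] = text  # overwrite the assistant turn with the fresh generation
        token_lists.append(rm_tok.apply_chat_template(msg))
    # Right-pad to a rectangular batch so the reward model can score it in one call.
    max_len = max(len(t) for t in token_lists)
    padded = [t + [rm_tok.pad_token_id] * (max_len - len(t)) for t in token_lists]
    return torch.tensor(padded, device=device)
```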
From 53422487164522015853dfbe924ffb5269551a0c Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Wed, 25 Sep 2024 17:17:46 +0000
Subject: [PATCH 4/4] deal with dataset processor

---
 open_instruct/dataset_processor.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py
index 692e5358e..04aef19d4 100644
--- a/open_instruct/dataset_processor.py
+++ b/open_instruct/dataset_processor.py
@@ -471,8 +471,9 @@ def __call__(self, batch: List[Dict[str, int]]):
 class SimpleGenerateCollator:
     """Simple collator for generation task (always pad from the LEFT)"""
 
-    def __init__(self, pad_token_id: int):
+    def __init__(self, pad_token_id: int, keep_messages: bool = False):
         self.pad_token_id = pad_token_id
+        self.keep_messages = keep_messages
 
     def __call__(self, batch: list[dict]):
         """the input will have input_ids_prompt"""
@@ -484,7 +485,8 @@ def __call__(self, batch: list[dict]):
 
         # Initialize lists to store padded sequences and attention masks
         padded_sequences = []
-
+        if self.keep_messages:
+            messages = []
         for i in range(len(batch)):
             # Calculate padding length
             pad_length = max_length - len(batch[i][INPUT_IDS_PROMPT_KEY])
@@ -493,13 +495,15 @@ def __call__(self, batch: list[dict]):
             padding = [self.pad_token_id] * pad_length
             padded_sequence = padding + batch[i][INPUT_IDS_PROMPT_KEY]
             padded_sequences.append(padded_sequence)
+            if self.keep_messages:
+                messages.append(batch[i]["messages"])
 
         # Convert to tensors
         padded_sequences = torch.tensor(padded_sequences)
-
-        return {
-            INPUT_IDS_PROMPT_KEY: padded_sequences,
-        }
+        res = {INPUT_IDS_PROMPT_KEY: padded_sequences}
+        if self.keep_messages:
+            res["messages"] = messages
+        return res
 
 
 if __name__ == "__main__":
@@ -511,3 +515,16 @@ def __call__(self, batch: list[dict]):
 
     # too much data; it should use all available CPUs
     assert get_num_proc(1000000, 120, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU) == 120
+
+    collator = SimpleGenerateCollator(pad_token_id=0)
+    batch = [
+        {
+            INPUT_IDS_PROMPT_KEY: [1, 2, 3],
+            "messages": ["hello", "world", "ixxi"],
+        },
+        {
+            INPUT_IDS_PROMPT_KEY: [1, 2, 3, 4],
+            "messages": ["hello", "world2"],
+        },
+    ]
+    collated_batch = collator(batch)
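For reference, a small usage sketch of the updated collator with `keep_messages=True`, in the spirit of the `__main__` example above. The expected values in the comments are inferred from the collator code rather than from a recorded run, and the import assumes both names are exposed by `open_instruct.dataset_processor`:

```python
from open_instruct.dataset_processor import INPUT_IDS_PROMPT_KEY, SimpleGenerateCollator

collator = SimpleGenerateCollator(pad_token_id=0, keep_messages=True)
batch = [
    {INPUT_IDS_PROMPT_KEY: [1, 2, 3], "messages": [{"role": "user", "content": "hello"}]},
    {INPUT_IDS_PROMPT_KEY: [1, 2, 3, 4], "messages": [{"role": "user", "content": "hello2"}]},
]
out = collator(batch)

# Prompts are left-padded to the longest prompt in the batch:
#   out[INPUT_IDS_PROMPT_KEY] -> tensor([[0, 1, 2, 3],
#                                        [1, 2, 3, 4]])
# and the raw chat messages are carried through unchanged:
#   out["messages"] -> [[{"role": "user", "content": "hello"}],
#                       [{"role": "user", "content": "hello2"}]]
```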