Online DPO: support RM with a different vocab size #368

Draft: wants to merge 4 commits into main
2 changes: 2 additions & 0 deletions docs/algorithms/online_dpo.md
Expand Up @@ -348,6 +348,8 @@ These are relevant implementation details on reward modeling:
1. Truncate responses at the stop token: we truncate the responses at the `--stop_token eos` to ensure the generation is stopped at the stop token.
1. Non-stop penalty: we apply a non-stop penalty through the reward to penalize the model for not stopping at the stop token. For example, if the generation does not end with the stop token, we penalize it by `-10.0` (see `--penalty_reward_value -10.0`).
1. Async training and generation: we follow the architecture in https://arxiv.org/abs/2310.00036 to run rollouts and training asynchronously, which ensures that training is not bottlenecked by generation.
1. Re-using training logprobs: we also optimize online DPO runtime by re-using the logprobs computed during the training forward pass, saving an additional forward pass at rollout time; note that this affects the KL calculation and introduces some numerical differences (a minimal sketch follows below). See https://github.com/allenai/open-instruct/pull/364 for more detail.
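
The following is a minimal, self-contained sketch of this logprob re-use, not the trainer's exact code: the function name, argument list, and the `beta` value are illustrative. The idea is to cache per-token logprobs while training on the shuffled minibatches, restore them to rollout order, and only then take the KL against the reference logprobs.

```python
import torch

beta = 0.05  # assumed KL coefficient; the trainer reads this from args.beta


def rlhf_reward_from_cached_logprobs(cached_logprobs, cached_indices, ref_logprobs, scores):
    """Re-use logprobs cached during the first training epoch instead of an extra forward pass.

    cached_logprobs / cached_indices: lists of per-minibatch tensors collected while training.
    ref_logprobs: [batch, response_len] reference-model logprobs from the rollout phase.
    scores: [batch] reward-model scores.
    """
    logprobs = torch.cat(cached_logprobs, dim=0)
    indices = torch.cat(cached_indices, dim=0)
    restored = torch.zeros_like(logprobs)
    restored[indices] = logprobs            # undo the minibatch shuffle
    kl = restored - ref_logprobs            # per-token KL estimate (slightly off-policy vs. a dedicated pass)
    non_score_reward = -beta * kl
    return scores + non_score_reward.sum(1)
```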


```python
import queue
Expand Down
29 changes: 23 additions & 6 deletions open_instruct/dataset_processor.py
Expand Up @@ -471,8 +471,9 @@ def __call__(self, batch: List[Dict[str, int]]):
class SimpleGenerateCollator:
"""Simple collator for generation task (always pad from the LEFT)"""

def __init__(self, pad_token_id: int):
def __init__(self, pad_token_id: int, keep_messages: bool = False):
self.pad_token_id = pad_token_id
self.keep_messages = keep_messages

def __call__(self, batch: list[dict]):
"""the input will have input_ids_prompt"""
Expand All @@ -484,7 +485,8 @@ def __call__(self, batch: list[dict]):

# Initialize lists to store padded sequences and attention masks
padded_sequences = []

if self.keep_messages:
messages = []
for i in range(len(batch)):
# Calculate padding length
pad_length = max_length - len(batch[i][INPUT_IDS_PROMPT_KEY])
Expand All @@ -493,13 +495,15 @@ def __call__(self, batch: list[dict]):
padding = [self.pad_token_id] * pad_length
padded_sequence = padding + batch[i][INPUT_IDS_PROMPT_KEY]
padded_sequences.append(padded_sequence)
if self.keep_messages:
messages.append(batch[i]["messages"])

# Convert to tensors
padded_sequences = torch.tensor(padded_sequences)

return {
INPUT_IDS_PROMPT_KEY: padded_sequences,
}
res = {INPUT_IDS_PROMPT_KEY: padded_sequences}
if self.keep_messages:
res["messages"] = messages
return res


if __name__ == "__main__":
Expand All @@ -511,3 +515,16 @@ def __call__(self, batch: list[dict]):

# too much data; it should use all available CPUs
assert get_num_proc(1000000, 120, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU) == 120

collator = SimpleGenerateCollator(pad_token_id=0)
batch = [
{
INPUT_IDS_PROMPT_KEY: [1, 2, 3],
"messages": ["hello", "world", "ixxi"],
},
{
INPUT_IDS_PROMPT_KEY: [1, 2, 3, 4],
"messages": ["hello", "world2"],
},
]
collated_batch = collator(batch)
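
As a hedged illustration of the new flag (this snippet is not part of the PR; it simply extends the test above, reusing its `batch` and `INPUT_IDS_PROMPT_KEY`), `keep_messages=True` makes the collator also return the untouched chat messages so they can later be re-templated with another tokenizer:

```python
collator_with_msgs = SimpleGenerateCollator(pad_token_id=0, keep_messages=True)
collated = collator_with_msgs(batch)
# prompts are left-padded to the longest prompt in the batch
assert collated[INPUT_IDS_PROMPT_KEY].tolist() == [[0, 1, 2, 3], [1, 2, 3, 4]]
# the raw messages ride along, unchanged and in batch order
assert collated["messages"] == [batch[0]["messages"], batch[1]["messages"]]
```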
100 changes: 73 additions & 27 deletions open_instruct/online_dpo_vllm_thread.py
Expand Up @@ -387,6 +387,20 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding
tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template]
reward_model_config = AutoConfig.from_pretrained(
args.reward_model_path, revision=args.reward_model_revision, num_labels=1
)
reward_model_tokenizer = AutoTokenizer.from_pretrained(
args.reward_model_path, revision=args.reward_model_revision, padding_side="right"
)
if reward_model_config.architectures is not None and "LlamaForCausalLM" in reward_model_config.architectures and reward_model_config.bos_token_id == 128000:
reward_model_tokenizer.pad_token_id = 128002  # <|reserved_special_token_0|>; config.architectures is a list, so check membership
else:
reward_model_tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# if reward model does not have a chat template, use the same as the policy model
if reward_model_tokenizer.chat_template is None:
reward_model_tokenizer.chat_template = tokenizer.chat_template

# create the dataset
dataset_dict = DatasetDict()
Expand Down Expand Up @@ -451,12 +465,11 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
attn_implementation="flash_attention_2",
use_cache=False,
)
different_rm_vocab = reward_model_tokenizer.vocab_size != tokenizer.vocab_size
if policy.config.vocab_size != reward_model.config.vocab_size:
raise ValueError(
"Policy and reward model must have the same vocab size. "
print(
"Policy and reward model have different vocab size. "
f"Policy: {policy.config.vocab_size}, Reward: {reward_model.config.vocab_size}. "
"If they don't have the same vocab size, the policy could generate tokens which "
"is going to cause index out of bound error in the reward model."
)
model = policy
if model_config.gradient_checkpointing:
Expand All @@ -475,7 +488,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
num_warmup_steps=args.warm_up_steps,
num_training_steps=args.num_training_steps * args.num_train_epochs,
)
data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id)
data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id, keep_messages=different_rm_vocab)
dataloader = DataLoader(
train_dataset,
batch_size=args.local_dataloader_batch_size,
Expand Down Expand Up @@ -606,6 +619,8 @@ def repeat_generator():
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
send_queries(accelerator, None, tokenizer, param_prompt_Q, queries_next)
if different_rm_vocab:
next_messages = data["messages"]

for _ in range(1, resume_training_step): # we didn't store scheduler state
scheduler.step()
Expand All @@ -614,6 +629,9 @@ def repeat_generator():
episode += args.batch_size
scheduler.step()
queries = queries_next
if different_rm_vocab:
messages = next_messages
# double the messages to match queries_next.repeat(args.num_generation_per_prompt, 1) (2 generations per prompt)
messages = messages + messages
if ph.preemptied:
break

Expand Down Expand Up @@ -643,6 +661,8 @@ def repeat_generator():
data = next(iter_dataloader)
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
if different_rm_vocab:
next_messages = data["messages"]
send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next)
else:
if training_step != 1:
Expand All @@ -653,13 +673,15 @@ def repeat_generator():
queries_next = queries_next.repeat(args.num_generation_per_prompt, 1)
send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next)
queries = queries_next
if different_rm_vocab:
messages = data["messages"]
messages = messages + messages

training_time_start = time.time()
with torch.no_grad():
context_length = queries.shape[1]
responses = []
postprocessed_responses = []
logprobs = []
ref_logprobs = []
scores = []
sequence_lengths = []
Expand Down Expand Up @@ -688,13 +710,6 @@ def repeat_generator():
query = queries[i : i + args.local_rollout_forward_batch_size]
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
output = forward(generation_model, query_response, tokenizer.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del output, logits, all_logprob
torch.cuda.empty_cache()

ref_output = forward(ref_model, query_response, tokenizer.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
Expand All @@ -714,25 +729,44 @@ def repeat_generator():
# Response Processing 2. run reward model on the truncated responses
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1
_, score, _ = get_reward(
reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length
)

if different_rm_vocab:
# we need to decode the response and re-tokenize it with the reward model's own tokenizer;
# otherwise we get `<|endoftext|>[PAD][PAD][PAD][PAD]<|endoftext|>` (policy special/pad tokens) in the reward model input
response_txts = tokenizer.batch_decode(postprocessed_response, skip_special_tokens=True)
reward_model_tokens = []
for j in range(i, i + args.local_rollout_forward_batch_size):
messages[j][-1]["content"] = response_txts[j - i]
reward_model_tokens.append(reward_model_tokenizer.apply_chat_template(messages[j]))

# right pad the reward model tokens
max_reward_model_len = max(len(item) for item in reward_model_tokens)
reward_model_tokens = [
item + [reward_model_tokenizer.pad_token_id] * (max_reward_model_len - len(item))
for item in reward_model_tokens
]
reward_model_tokens = torch.tensor(reward_model_tokens, device=device)
_, score, _ = get_reward(
reward_model, reward_model_tokens, reward_model_tokenizer.pad_token_id, 0
)
else:
_, score, _ = get_reward(
reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length
)

responses.append(response)
postprocessed_responses.append(postprocessed_response)
logprobs.append(logprob)
ref_logprobs.append(ref_logprob)
sequence_lengths.append(sequence_length)
scores.append(score)
responses = torch.cat(responses, 0)
postprocessed_responses = torch.cat(postprocessed_responses, 0)
logprobs = torch.cat(logprobs, 0)
ref_logprobs = torch.cat(ref_logprobs, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
global_scores = accelerator.gather(scores)
accelerator.print(f"global_scores: {global_scores}, {global_scores.mean()}")
del (logprob, ref_logprob, score)
del (ref_logprob, score)
gc.collect()
torch.cuda.empty_cache()

Expand All @@ -751,16 +785,8 @@ def repeat_generator():
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)

# 4. compute rewards
kl = logprobs - ref_logprobs
print(f"{accelerator.local_process_index=}, {kl.sum(1)=}")
non_score_reward = -args.beta * kl
non_score_reward_sum = non_score_reward.sum(1)
rlhf_reward = scores + non_score_reward_sum

# num_examples should be same as args.local_batch_size divided by 2
num_examples = scores.size(0) // 2
first_half = scores[:num_examples]
Expand All @@ -775,6 +801,8 @@ def repeat_generator():
)
scores_margin = scores[chosen_indices] - scores[rejected_indices]

logprobs = []
concat_indices = []
# Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
for epoch_idx in range(args.num_epochs):
b_inds = np.random.permutation(args.local_batch_size // args.num_generation_per_prompt)
Expand Down Expand Up @@ -852,6 +880,16 @@ def repeat_generator():
optimizer.step()
optimizer.zero_grad()
with torch.no_grad():
if epoch_idx == 0:
concat_indices.append(concat_mb_inds)
response = concat_query_responses[:, context_length:]
logits = concat_output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
logprob = torch.masked_fill(logprob, padding_mask[concat_mb_inds], INVALID_LOGPROB)
logprobs.append(logprob)
del all_logprob
chosen_rewards = args.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
rejected_rewards = args.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss
Expand Down Expand Up @@ -879,6 +917,14 @@ def repeat_generator():
# del everything and empty cache
torch.cuda.empty_cache()
with torch.no_grad():
logprobs = torch.cat(logprobs, 0)
concat_indices = torch.cat(concat_indices, 0)
restore_logprobs = torch.zeros_like(logprobs)
restore_logprobs[concat_indices] = logprobs
kl = restore_logprobs - ref_logprobs
non_score_reward = -args.beta * kl
non_score_reward_sum = non_score_reward.sum(1)
rlhf_reward = scores + non_score_reward_sum
local_metrics[0] = sequence_lengths.float().mean()
local_metrics[1] = (responses == args.stop_token_id).sum().float().mean()
local_metrics[2] = kl.sum(1).mean()
Expand Down
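
For readers skimming the diff, here is a standalone sketch of the different-vocab reward path added above. The helper name and argument list are illustrative, not the PR's API: decode the policy's truncated responses, splice them back into the original chat messages, re-apply the reward model's chat template, and right-pad with the reward model's pad token before scoring.

```python
import torch


def build_reward_model_inputs(messages, postprocessed_responses, policy_tokenizer, reward_model_tokenizer, device):
    # decode with skip_special_tokens so policy pad/eos tokens don't leak into the reward model prompt
    response_txts = policy_tokenizer.batch_decode(postprocessed_responses, skip_special_tokens=True)
    token_lists = []
    for msgs, text in zip(messages, response_txts):
        msgs[-1]["content"] = text  # overwrite the assistant turn with the generated text
        token_lists.append(reward_model_tokenizer.apply_chat_template(msgs))
    # right-pad to the longest sequence so the reward model sees a rectangular batch
    max_len = max(len(t) for t in token_lists)
    padded = [t + [reward_model_tokenizer.pad_token_id] * (max_len - len(t)) for t in token_lists]
    return torch.tensor(padded, device=device)
```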