
Commit

Fix dataset mixing logic
vwxyzjn committed Nov 5, 2024
1 parent 170972b commit f65501b
Showing 9 changed files with 63 additions and 44 deletions.
16 changes: 8 additions & 8 deletions docs/algorithms/online_dpo.md
@@ -31,7 +31,7 @@ python open_instruct/online_dpo_vllm_thread.py \
--dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \
--dataset_eval_splits validation \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_stop_penalty \
@@ -43,7 +43,7 @@ python open_instruct/online_dpo_vllm_thread.py \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 64 \
--max_token_length 2048 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--num_train_epochs 1 \
--beta 0.1 \
--output_dir models/rm/rm_sentiment_1b \
@@ -64,7 +64,7 @@ python open_instruct/online_dpo_vllm_thread.py \
--dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \
--dataset_eval_splits validation \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_stop_penalty \
@@ -76,7 +76,7 @@ python open_instruct/online_dpo_vllm_thread.py \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 64 \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--num_train_epochs 1 \
--beta 0.1 \
--output_dir models/rm/rm_sentiment_1b \
@@ -112,7 +112,7 @@ python mason.py \
--dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \
--dataset_eval_splits validation \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--learning_rate 3e-6 \
--output_dir models/minimal/online_dpo_vllm_thread_tldr \
--per_device_train_batch_size 16 \
@@ -158,7 +158,7 @@ python mason.py \
--dataset_eval_mixer '{"HuggingFaceH4/no_robots": 1.0}' \
--dataset_eval_splits test \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--learning_rate 8e-7 \
--output_dir /output/ \
--chat_template tulu \
@@ -211,7 +211,7 @@ python mason.py \
--dataset_eval_mixer '{"allenai/ultrafeedback_binarized_cleaned": 1.0}' \
--dataset_eval_splits test_prefs \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--learning_rate 8e-7 \
--output_dir /output/ \
--chat_template tulu \
@@ -265,7 +265,7 @@ python mason.py \
--dataset_eval_mixer '{"allenai/ultrafeedback_binarized_cleaned": 1.0}' \
--dataset_eval_splits test_prefs \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--learning_rate 8e-7 \
--output_dir /output/ \
--chat_template tulu \
14 changes: 7 additions & 7 deletions docs/algorithms/ppo.md
@@ -31,7 +31,7 @@ python open_instruct/ppo_vllm_thread.py \
--dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \
--dataset_eval_splits validation \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_stop_penalty \
@@ -43,7 +43,7 @@ python open_instruct/ppo_vllm_thread.py \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 64 \
--max_token_length 2048 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--num_train_epochs 1 \
--beta 0.1 \
--output_dir models/rm/rm_sentiment_1b \
@@ -64,7 +64,7 @@ python open_instruct/ppo_vllm_thread.py \
--dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \
--dataset_eval_splits validation \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_stop_penalty \
@@ -76,7 +76,7 @@ python open_instruct/ppo_vllm_thread.py \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 64 \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--num_train_epochs 1 \
--beta 0.1 \
--output_dir models/rm/rm_sentiment_1b \
@@ -112,7 +112,7 @@ python mason.py \
--dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \
--dataset_eval_splits validation \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo_vllm_thread_tldr \
--per_device_train_batch_size 16 \
@@ -158,7 +158,7 @@ python mason.py \
--dataset_eval_mixer '{"HuggingFaceH4/no_robots": 1.0}' \
--dataset_eval_splits test \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--learning_rate 8e-7 \
--output_dir /output/ \
--chat_template tulu \
@@ -211,7 +211,7 @@ python mason.py \
--dataset_eval_mixer '{"allenai/ultrafeedback_binarized_cleaned": 1.0}' \
--dataset_eval_splits test_prefs \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--learning_rate 8e-7 \
--output_dir /output/ \
--chat_template tulu \
14 changes: 7 additions & 7 deletions docs/algorithms/reward_modeling.md
@@ -40,7 +40,7 @@ python -i open_instruct/reward_modeling.py \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 32 \
--max_token_length 1024 \
- --max_prompt_token_lenth 1024 \
+ --max_prompt_token_length 1024 \
--num_train_epochs 1 \
--output_dir models/rm/rm \
--sanity_check \
@@ -71,7 +71,7 @@ python mason.py \
--per_device_eval_batch_size 16 \
--gradient_accumulation_steps 4 \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--num_train_epochs 1 \
--output_dir models/rm/rm_sentiment_1b \
--with_tracking \
@@ -103,7 +103,7 @@ python mason.py \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 8 \
--max_token_length 1024 \
- --max_prompt_token_lenth 512 \
+ --max_prompt_token_length 512 \
--num_train_epochs 1 \
--output_dir models/rm/rm_sentiment_1b \
--with_tracking \
@@ -134,7 +134,7 @@ python mason.py \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 4 \
--max_token_length 2048 \
- --max_prompt_token_lenth 1024 \
+ --max_prompt_token_length 1024 \
--num_train_epochs 1 \
--output_dir models/rm/rm_hh_1b \
--with_tracking \
@@ -165,7 +165,7 @@ python mason.py \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 4 \
--max_token_length 2048 \
- --max_prompt_token_lenth 1024 \
+ --max_prompt_token_length 1024 \
--num_train_epochs 1 \
--output_dir models/rm/rm_hh_1b \
--with_tracking \
@@ -198,7 +198,7 @@ python mason.py \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 32 \
--max_token_length 1024 \
- --max_prompt_token_lenth 1024 \
+ --max_prompt_token_length 1024 \
--num_train_epochs 1 \
--output_dir models/rm/rm_tulu_8b \
--gradient_checkpointing \
@@ -391,7 +391,7 @@ dataset_config = DatasetConfig(
dataset_name="trl-internal-testing/sentiment-trl-style",
chat_template="simple_chat",
max_token_length=1024,
- max_prompt_token_lenth=1024,
+ max_prompt_token_length=1024,
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = CHAT_TEMPLATES["simple_chat"]
10 changes: 5 additions & 5 deletions open_instruct/dataset_processor.py
@@ -174,7 +174,7 @@ class DatasetConfig:

# filter config
max_token_length: Optional[int] = None
- max_prompt_token_lenth: Optional[int] = None
+ max_prompt_token_length: Optional[int] = None

# dataset.map config
sanity_check: bool = False
@@ -314,8 +314,8 @@ def tokenize_fn(row):
def filter(self, dataset: Union[Dataset, DatasetDict]):
def filter_fn(row):
return (
- len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_lenth
- if self.config.max_prompt_token_lenth is not None
+ len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_length
+ if self.config.max_prompt_token_length is not None
else (
True and len(row[INPUT_IDS_CHOSEN_KEY]) <= self.config.max_token_length
if self.config.max_token_length is not None
@@ -388,8 +388,8 @@ def tokenize_fn(row):
def filter(self, dataset: Dataset):
def filter_fn(row):
max_prompt_token_length_ok = True
- if self.config.max_prompt_token_lenth is not None:
- max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_lenth
+ if self.config.max_prompt_token_length is not None:
+ max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_length

max_token_length_ok = True
if self.config.max_token_length is not None:
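Aside, not part of the diff: a minimal sketch of what the corrected filter now checks. The row keys, lengths, and limits below are invented stand-ins for INPUT_IDS_PROMPT_KEY / INPUT_IDS_CHOSEN_KEY and the DatasetConfig fields, not values taken from the repository.

```python
# Hedged sketch of the filter predicate with the renamed field; all values here
# are illustrative stand-ins, not the repository's constants.
from typing import Optional

max_prompt_token_length: Optional[int] = 512   # stands in for config.max_prompt_token_length
max_token_length: Optional[int] = 1024         # stands in for config.max_token_length

row = {"prompt_ids": list(range(600)), "chosen_ids": list(range(900))}

prompt_ok = max_prompt_token_length is None or len(row["prompt_ids"]) <= max_prompt_token_length
chosen_ok = max_token_length is None or len(row["chosen_ids"]) <= max_token_length

keep_row = prompt_ok and chosen_ok
print(keep_row)  # False: the 600-token prompt exceeds the 512-token limit, so the row is dropped
```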
8 changes: 7 additions & 1 deletion open_instruct/online_dpo_vllm_thread.py
@@ -391,6 +391,12 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
# create the dataset
dataset_dict = DatasetDict()
dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config)
+ if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1:
+ args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict)
+ print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets.")
+ if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1:
+ args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict)
+ print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets.")
train_dataset = combine_dataset(
args.dataset_mixer_dict,
splits=args.dataset_train_splits,
@@ -571,7 +577,7 @@ def repeat_generator():
args=(
model_config.model_name_or_path,
model_config.model_revision,
- dataset_config.max_prompt_token_lenth + args.response_length,
+ dataset_config.max_prompt_token_length + args.response_length,
args.vllm_device,
args.vllm_gpu_memory_utilization,
generation_config,
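Aside, not part of the diff: a minimal sketch of the split-broadcasting behavior added above, assuming two datasets in the mixer but only one train split on the command line. Dataset names are illustrative, not prescriptions.

```python
# Sketch of the broadcast added in main(): a single split is reused for every
# dataset in the mixer so the lengths line up. Dataset names are examples only.
dataset_mixer_dict = {
    "trl-internal-testing/tldr-preference-sft-trl-style": 1.0,
    "HuggingFaceH4/no_robots": 1.0,
}
dataset_train_splits = ["train"]  # only one split supplied

if len(dataset_train_splits) != len(dataset_mixer_dict) and len(dataset_train_splits) == 1:
    dataset_train_splits = [dataset_train_splits[0]] * len(dataset_mixer_dict)
    print(f"Dataset splits not provided for all datasets. "
          f"Using the same {dataset_train_splits[0]} split for all datasets.")

print(dataset_train_splits)  # ['train', 'train'], so the length check in combine_dataset passes
```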
8 changes: 7 additions & 1 deletion open_instruct/ppo_vllm_thread.py
@@ -457,6 +457,12 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
# create the dataset
dataset_dict = DatasetDict()
dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config)
+ if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1:
+ args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict)
+ print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets.")
+ if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1:
+ args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict)
+ print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets.")
train_dataset = combine_dataset(
args.dataset_mixer_dict,
splits=args.dataset_train_splits,
@@ -645,7 +651,7 @@ def repeat_generator():
args=(
model_config.model_name_or_path,
model_config.model_revision,
- dataset_config.max_prompt_token_lenth + args.response_length,
+ dataset_config.max_prompt_token_length + args.response_length,
args.vllm_device,
args.vllm_gpu_memory_utilization,
generation_config,
6 changes: 6 additions & 0 deletions open_instruct/reward_modeling.py
@@ -246,6 +246,12 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig):
# create the dataset
dataset_dict = DatasetDict()
dataset_processor = PreferenceDatasetProcessor(tokenizer=tokenizer, config=dataset_config)
+ if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1:
+ args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict)
+ print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets.")
+ if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1:
+ args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict)
+ print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets.")
train_dataset = combine_dataset(
args.dataset_mixer_dict,
splits=args.dataset_train_splits,
1 change: 1 addition & 0 deletions open_instruct/utils.py
@@ -418,6 +418,7 @@ def combine_dataset(
Whether to keep ids for training that are added during mixing.
Used primarily in mix_data.py for saving, or the saved dataset has IDs already.
"""
+ assert len(splits) == len(dataset_mixer), "Number of splits must match the number of datasets."
if isinstance(dataset_mixer, list):
assert len(dataset_mixer) % 2 == 0, f"Data mixer list length is not even: {dataset_mixer}"
mixer_dict = {}
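Aside, not part of the diff: the new assertion makes a mixer/splits length mismatch fail fast inside combine_dataset. A small sketch of the invariant callers now have to satisfy; the dataset names and splits are examples, and combine_dataset's remaining keyword arguments are elided.

```python
# Invariant enforced by the new assertion in combine_dataset; the real call
# (commented out) also takes keyword arguments not shown here.
dataset_mixer = {
    "trl-internal-testing/tldr-preference-sft-trl-style": 1.0,  # example datasets
    "HuggingFaceH4/no_robots": 1.0,
}
splits = ["validation", "test"]  # one split per dataset; a single-element list would now fail fast

assert len(splits) == len(dataset_mixer), "Number of splits must match the number of datasets."
# train_dataset = combine_dataset(dataset_mixer, splits=splits, ...)  # other kwargs elided
```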