This document provides extensive examples demonstrating how to use oat 🌾 to (1) run various direct optimizers, (2) integrate different preference oracles, and (3) implement diverse active exploration algorithms. All examples were tested on a machine with 8 A100 GPUs, and the training logs are publicly available on wandb for reproducibility.
First of all, you can always check the full list of supported arguments by running:
python -m oat.experiment.main -h
oat currently supports DPO, IPO, SLiC, and SimPO, selected via --dap-algo. Remember to adjust the associated hyper-parameter --beta for each algorithm.
python -m oat.experiment.main \
+ --dap-algo IPO \
+ --beta 0.1 \
# other flags...
On the main page we showed how to use pairrm as the preference oracle, which runs in the same process as the actor. Next, we give an example of training online DPO with a preference oracle served via Mosec.
First, start the Mosec service locally; it serves 4 parallel Skywork/Skywork-Reward-Llama-3.1-8B workers as the preference oracle on the first 4 GPUs:
MOSEC_LOG_LEVEL=debug python -m oat.oracles.remote.server --cuda-devices 0,1,2,3
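The server may take a while to load the reward model. Optionally, you can probe it from another shell to confirm it is serving; the check below assumes Mosec's default Prometheus /metrics route and the port 8000 used throughout these examples:
curl -s http://0.0.0.0:8000/metrics | head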
After the service is up (you should see "http service is running" in the log), open a new shell and run:
python -m oat.experiment.main \
--flash-attn \
--gradient-checkpointing \
--rnd-seed \
--gpus 8 \
--dap-algo DPO \
--beta 0.1 \
+ --preference-oracle remote \
+ --remote-rm-url http://0.0.0.0:8000 \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--input-key prompt \
--output-key pythia-1b-reference \
--sync-params-every 1 \
--max-train 50000 \
--generate-max-length 53 \
--train-batch-size 128 \
--rollout-batch-size 128 \
--rollout-batch-size-per-device 32 \
--pi-buffer-maxlen-per-device 32 \
--train-batch-size-per-device 8 \
--eval-steps 20 \
--use-wb \
--wb-run-name 1b_skywork_dpo_online
Alternatively, we can query the OpenAI API to use GPT-as-a-judge as the preference oracle:
python -m oat.experiment.main \
--flash-attn \
--gradient-checkpointing \
--rnd-seed \
--gpus 8 \
+ --collocate \
--dap-algo DPO \
--beta 0.1 \
+ --preference-oracle gpt-4o-mini-2024-07-18 \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--input-key prompt \
--output-key pythia-1b-reference \
--sync-params-every 1 \
--max-train 50000 \
--generate-max-length 53 \
--train-batch-size 128 \
--rollout-batch-size 128 \
--rollout-batch-size-per-device 32 \
--pi-buffer-maxlen-per-device 32 \
--train-batch-size-per-device 8 \
--eval-steps 20 \
--use-wb \
--wb-run-name 1b_gpt_4o_mini_dpo_online
Here we enable collocation of learner and actor workers (--collocate) since GPU memory is abundant in this setup: the preference oracle (GPT) runs on OpenAI's side and consumes almost no resources on our machine.
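Note that querying GPT-as-a-judge requires OpenAI credentials. Assuming the standard OpenAI SDK convention of reading the key from the environment, export it before launching:
export OPENAI_API_KEY="<your-key>"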
Likewise, we can host our own remote server for any reward model on a separate machine, freeing up compute to train larger models. On a Kubernetes-managed cluster, you can follow these steps to serve a remote preference oracle at http://remote-rm. Otherwise, obtain the remote machine's IP address (e.g., 10.0.0.1) and set --remote-rm-url http://10.0.0.1:8000 accordingly.
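If the remote machine is not directly routable from your training nodes, one generic workaround (not specific to oat) is an SSH tunnel; the user and host below are placeholders:
# Forward local port 8000 to the oracle server running on the remote machine.
ssh -N -f -L 8000:localhost:8000 user@10.0.0.1
# Then point training at the local end of the tunnel:
#   --remote-rm-url http://0.0.0.0:8000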
python -m oat.experiment.main \
--flash-attn \
--gradient-checkpointing \
--rnd-seed \
--gpus 8 \
--dap-algo DPO \
--beta 0.1 \
--preference-oracle remote \
+ --remote-rm-url http://remote-rm \
+ --pretrain trl-lib/pythia-6.9b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--input-key prompt \
+ --output-key pythia-6.9b-reference \
--sync-params-every 1 \
--max-train 50000 \
--generate-max-length 53 \
--train-batch-size 128 \
--rollout-batch-size 128 \
--rollout-batch-size-per-device 32 \
--pi-buffer-maxlen-per-device 32 \
--train-batch-size-per-device 8 \
--eval-steps 20 \
--use-wb \
+ --wb-run-name 6.9b_skywork_dpo_online
All examples below assume a locally served preference oracle, as set up in the section above.
Note
Paper: https://arxiv.org/pdf/2411.01493.
You can find a thorough comparison of all the algorithms mentioned in this section in our paper.
Oat natively supports SEA via the oat.experiment.main entry script:
python -m oat.experiment.main \
--flash-attn \
--gradient-checkpointing \
--rnd-seed \
--gpus 8 \
--dap-algo DPO \
--beta 0.1 \
--preference-oracle remote \
--remote-rm-url http://0.0.0.0:8000 \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--input-key prompt \
--output-key pythia-1b-reference \
--sync-params-every 1 \
--max-train 50000 \
--generate-max-length 53 \
--train-batch-size 128 \
--rollout-batch-size 128 \
--rollout-batch-size-per-device 32 \
--pi-buffer-maxlen-per-device 32 \
--train-batch-size-per-device 8 \
--eval-steps 20 \
+ --num-prompt-epoch 2 \
+ --max-step-adjustment 0.75 \
+ --lr-warmup-ratio 0.02 \
+ --eval-query-interval 2560 \
+ --num-samples 20 \
+ --learn-rm \
+ --exp-method EnnBAITS \
+ --model-rollout \
+ --max-model-data-ratio 0.3 \
--use-wb \
--wb-run-name 1b_skywork_dpo_sea
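If you only want to sanity-check the SEA pipeline before committing to a full run, you could shrink the budget by overriding a few of the flags above; the values here are illustrative placeholders, not the tuned settings used in the paper:
python -m oat.experiment.main \
+ --max-train 2000 \
+ --eval-steps 5 \
+ --num-samples 8 \
# other flags same as the SEA command above...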
Note
Run EE4LLM by disabling policy learning and enabling best-of-n sampling for evaluation:
python -m oat.experiment.main \
--flash-attn \
--gradient-checkpointing \
--rnd-seed \
--gpus 8 \
--dap-algo DPO \
--beta 0.1 \
--preference-oracle remote \
--remote-rm-url http://0.0.0.0:8000 \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--input-key prompt \
--output-key pythia-1b-reference \
--sync-params-every 1 \
--max-train 50000 \
--generate-max-length 53 \
--train-batch-size 128 \
--rollout-batch-size 128 \
--rollout-batch-size-per-device 32 \
--pi-buffer-maxlen-per-device 32 \
--train-batch-size-per-device 8 \
--eval-steps 20 \
+ --num-samples 20 \
+ --learn-rm \
+ --learn_rm_only \
+ --exp-method EnnEETS \
+ --exp_rnd_sample \
+ --online_evaluation \
+ --best_of_n_eval \
+ --num_bon 10 \
--use-wb \
+ --wb-run-name 1b_skywork_dpo_ee4llm
Note
APL can be implemented by inheriting oat's learner and actor classes (code). Run it with its dedicated entry script:
+ python -m oat.experiment.run_apl \
--flash-attn \
--gradient-checkpointing \
--rnd-seed \
--gpus 8 \
--dap-algo DPO \
--beta 0.1 \
--preference-oracle remote \
--remote-rm-url http://0.0.0.0:8000 \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--input-key prompt \
--output-key pythia-1b-reference \
--sync-params-every 1 \
--max-train 50000 \
--generate-max-length 53 \
--train-batch-size 128 \
--rollout-batch-size 128 \
--rollout-batch-size-per-device 32 \
--pi-buffer-maxlen-per-device 32 \
--train-batch-size-per-device 8 \
--eval-steps 20 \
+ --num_prompt_epoch 4 \
+ --max_train 100000 \
+ --max_step_adjustment 0.125 \
+ --num_samples 8 \
+ --apl_pref_certainty_only \
--use-wb \
+ --wb-run-name 1b_skywork_apl
Note
XPO can be implemented by inheriting oat's learner and actor classes (code). Run it with its dedicated entry script:
+ python -m oat.experiment.run_xpo \
--flash-attn \
--gradient-checkpointing \
--rnd-seed \
--gpus 8 \
--dap-algo DPO \
--beta 0.1 \
--preference-oracle remote \
--remote-rm-url http://0.0.0.0:8000 \
--pretrain trl-lib/pythia-1b-deduped-tldr-sft \
--prompt-data lkevinzc/tldr-with-sft-reference \
--input-key prompt \
--output-key pythia-1b-reference \
--sync-params-every 1 \
--max-train 50000 \
--generate-max-length 53 \
--train-batch-size 128 \
--rollout-batch-size 128 \
--rollout-batch-size-per-device 32 \
--pi-buffer-maxlen-per-device 32 \
--train-batch-size-per-device 8 \
--eval-steps 20 \
--use-wb \
+ --wb-run-name 1b_skywork_xpo
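As with the main entry point, the dedicated APL and XPO entry scripts should expose their algorithm-specific arguments through the same help flag (assuming they share oat's argument parsing):
python -m oat.experiment.run_apl -h
python -m oat.experiment.run_xpo -h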