@rakkit commented Dec 12, 2025

TL;DR

Adds SFT training to Torchtitan plus a small greedy_packing addition.

Most of the code is borrowed from Verl and OpenRLHF.

Changes

  • Added an SFT dataset config to the job config
  • Updated attention and modified get_attention_masks to support SFT masks (only landed on Llama3 for now)
  • Temporarily uses HFTokenizer; this needs to be fixed later
  • Added an SFT dataset/dataloader that returns input_ids, labels (user tokens masked), attention_masks, and position_ids (see the sketch after this list)
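
For reference, here is a minimal sketch of how such a collate step could assemble one sample. This is not the actual implementation in this PR; the helper name, the IGNORE_INDEX value, and the right-padding behavior are assumptions for illustration only.

```python
import torch

IGNORE_INDEX = -100  # assumed loss-ignore value for user/prompt and padding tokens

def collate_sft_example(prompt_ids, response_ids, pad_id, seq_len):
    """Build input_ids, labels, attention_mask, position_ids for one sample.

    labels mirror input_ids, but prompt (user) tokens and padding are set to
    IGNORE_INDEX so the loss is computed only on the assistant response.
    """
    ids = (prompt_ids + response_ids)[:seq_len]
    n_prompt = min(len(prompt_ids), seq_len)
    labels = [IGNORE_INDEX] * n_prompt + ids[n_prompt:]
    pad = seq_len - len(ids)
    return {
        "input_ids": torch.tensor(ids + [pad_id] * pad),
        "labels": torch.tensor(labels + [IGNORE_INDEX] * pad),
        "attention_mask": torch.tensor([1] * len(ids) + [0] * pad),
        "position_ids": torch.arange(seq_len),
    }
```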

TODO

Run

CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh \
  --training.running_sft_training \
  --model.flavor=debugmodel_varlen_attn \
  --training.dataset_path=openai/gsm8k \
  --sft_data_config.dataset_subset=main

More tests

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_varlen \
 --training.dataset_path=openai/gsm8k \
 --sft_data_config.dataset_subset=main \
 --model.hf_assets_path="$HOME/Meta-Llama-3.1-8B-Instruct/" \
 --training.local_batch_size=8 \
 --activation_checkpoint.mode=full \
 --training.seq_len=2048 \
 --training.steps=100 \
 --lr_scheduler.warmup_steps=10 \
 --debug.seed 10 \
 --sft_data_config.pad_mode=right \
 --metrics.enable_wandb

(torch 2.10.0.dev20251124+cu129; I am using cuDNN attention.)

[W&B chart, 12/12/2025 8:32 PM]

torch.compile does not work in the no-padding case because the sequence length changes at every training step. We could pad the packed buffer up to seq_len when greedy_packing is turned on (i.e., packing plus padding) to keep shapes static and make compile happy; a rough sketch of that idea follows.
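
The sketch below is only an illustration of the pad-to-seq_len idea, not the code in this PR; the function name, the first-fit packing strategy, and the per-document position_ids reset are assumptions.

```python
import torch

def greedy_pack(samples, seq_len, pad_id, ignore_index=-100):
    """Greedily pack variable-length (input_ids, labels) pairs into rows of
    exactly seq_len tokens, padding the tail of each row.

    Padding every row to the same length keeps tensor shapes static across
    training steps, so torch.compile does not need to recompile.
    """
    rows, cur, cur_len = [], [], 0
    for ids, labels in samples:
        ids, labels = ids[:seq_len], labels[:seq_len]  # truncate oversized samples
        if cur and cur_len + len(ids) > seq_len:
            rows.append(cur)
            cur, cur_len = [], 0
        cur.append((ids, labels))
        cur_len += len(ids)
    if cur:
        rows.append(cur)

    packed = []
    for row in rows:
        input_ids, lab, pos = [], [], []
        for ids, labels in row:
            input_ids += ids
            lab += labels
            pos += list(range(len(ids)))  # position_ids restart for each packed document
        pad = seq_len - len(input_ids)
        packed.append({
            "input_ids": torch.tensor(input_ids + [pad_id] * pad),
            "labels": torch.tensor(lab + [ignore_index] * pad),
            "position_ids": torch.tensor(pos + [0] * pad),
        })
    return packed
```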


rakkit commented Dec 12, 2025

Just confirming that it works on SFT datasets in multiturn format, with and without tools, and on reasoning data.

(Be aware that when apply_chat_template is turned on, we are supposed to provide chat_template.jinja in the tokenizer's folder. There is no such file for Meta-Llama-3.1-8B-Instruct; for testing purposes you can use, e.g., the tokenizer from Olmo-3-7B-Instruct.)
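
As a quick sanity check that a tokenizer folder actually ships a usable chat template, something like the following works with the Hugging Face tokenizer API; the local path is a placeholder:

```python
from transformers import AutoTokenizer

# Placeholder path: the folder must contain a chat template, either as
# chat_template.jinja or as a chat_template field in tokenizer_config.json.
tok = AutoTokenizer.from_pretrained("/path/to/Olmo-3-7B-Instruct")

messages = [
    {"role": "user", "content": "What is 2 + 3?"},
    {"role": "assistant", "content": "2 + 3 = 5."},
]

# Renders the conversation with the template; raises if no chat template is found.
print(tok.apply_chat_template(messages, tokenize=False))
```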


To reproduce it on the multiturn-format dataset Dolci-Instruct-SFT, without apply_chat_template:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/torchtitan_assets/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/sft/Dolci-Instruct-SFT/ \
 --sft_data_config.is_multiturn \
 --training.seq_len=2048 \
 --training.local_batch_size=2

To reproduce it on the multiturn-format dataset Dolci-Instruct-SFT, with apply_chat_template:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/torchtitan_assets/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/sft/Dolci-Instruct-SFT/ \
 --sft_data_config.is_multiturn \
 --training.seq_len=2048 \
 --training.local_batch_size=2 \
 --sft_data_config.apply_chat_template \
 --sft_data_config.ignore_input_ids_mismatch


To reproduce it on a tool-use dataset, e.g., the ReTool-SFT-multi-turn dataset:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/ReTool-SFT-multi-turn/ \
 --sft_data_config.is_multiturn --sft_data_config.apply_chat_template \
 --training.seq_len=2048 \
 --training.local_batch_size=2 \
 --sft_data_config.ignore_input_ids_mismatch \
 --sft_data_config.tools_key=tools
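
For context, this is roughly how a tools column can be threaded through a Hugging Face chat template; the path, the field names, and the function schema below are made up for illustration and are not taken from the dataset or this PR:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/path/to/tokenizer")  # placeholder path

# Example record shaped like a tool-use SFT sample; field names are illustrative.
sample = {
    "messages": [
        {"role": "user", "content": "What is 127 * 49?"},
        {"role": "assistant", "content": "127 * 49 = 6223."},
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "calculator",
                "description": "Evaluate an arithmetic expression.",
                "parameters": {
                    "type": "object",
                    "properties": {"expression": {"type": "string"}},
                    "required": ["expression"],
                },
            },
        }
    ],
}

# Recent transformers versions accept a tools= argument so the chat template
# can render the tool definitions into the prompt.
print(tok.apply_chat_template(sample["messages"], tools=sample["tools"], tokenize=False))
```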

To reproduce it on a reasoning dataset, e.g., nvidia/Puzzle-KD-Nemotron-Post-Training-Dataset-v2:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/Puzzle-KD-Nemotron-Post-Training-Dataset-v2 \
 --sft_data_config.is_multiturn --sft_data_config.apply_chat_template \
 --sft_data_config.split=validation \
 --training.seq_len=2048 \
 --training.local_batch_size=2 \
 --sft_data_config.thinking_key=reasoning
