conda create -n sllada python=3.10
pip install -r requirements_sllada_a100.txt

We use FSDP to train (finetune) our shortcut-LLaDA model. The shortcut-LLaDA model is essentially the original LLaDA (8B parameters) plus a condition embedding and adaptive layer normalization (an additional 0.13B parameters).
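For orientation, here is a minimal sketch of the two added components; the module names, hidden size, and condition shape are illustrative assumptions, not the actual shortcut-LLaDA code.

# Illustrative sketch only: condition embedding + adaptive layer normalization.
import torch
import torch.nn as nn

class AdaptiveLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a condition embedding."""

    def __init__(self, hidden_size: int, cond_size: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        # Small head mapping the condition embedding to per-channel scale and shift.
        self.to_scale_shift = nn.Sequential(nn.SiLU(), nn.Linear(cond_size, 2 * hidden_size))

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

# Condition embedding for a per-sample scalar condition (e.g. the shortcut step size).
cond_embed = nn.Sequential(nn.Linear(1, 256), nn.SiLU(), nn.Linear(256, 256))
ada_ln = AdaptiveLayerNorm(hidden_size=4096, cond_size=256)  # 4096 is an assumed hidden size

x = torch.randn(2, 16, 4096)   # (batch, seq, hidden) activations from a transformer block
d = torch.rand(2, 1)           # per-sample condition
y = ada_ln(x, cond_embed(d))   # conditioned activations, same shape as x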
export PYTHONPATH=$PWD:$PYTHONPATH
# Work around AutoDL issues: keep the HF cache and temp files on the data disk and use the HF mirror
export HF_HOME="/root/autodl-tmp/hf_cache" TMPDIR="/root/autodl-tmp/python_tmp" HF_ENDPOINT="https://hf-mirror.com"

Stage 1: train the partly frozen model on a small dataset so that the newly added layers learn to produce meaningful output. Only the diffusion loss is used, no self-consistency loss. A sketch of the parameter freezing follows, and the launch command comes after it.
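What --freeze_pretrained_param is assumed to do: keep the original LLaDA weights fixed and leave only the newly added parameters trainable. A minimal sketch, with illustrative module-name keywords:

# Illustrative sketch only; the keyword list is an assumption about how the new modules are named.
def freeze_pretrained_params(model, new_module_keywords=("cond_embed", "ada_ln")):
    for name, param in model.named_parameters():
        # A parameter stays trainable only if it belongs to one of the newly added modules.
        param.requires_grad = any(k in name for k in new_module_keywords)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable / 1e9:.2f}B / {total / 1e9:.2f}B")  # roughly 0.13B / 8.13B here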
accelerate launch \
--multi_gpu \
--num_processes=8 \
--num_machines=1 \
--mixed_precision "bf16" \
train/training.py \
--batch_size 4 \
--num_epochs 1 \
--sc_ratio 0.0 \
--learning_rate 1e-5 \
--dataset_name "wikitext" \
--dataset_config "wikitext-103-raw-v1" \
--freeze_pretrained_param \
--training_base_save_path "./ckpts" \
--init_from_llada

Stage 2: continue training the partly frozen model on another, larger dataset, splitting the data between the diffusion loss and the self-consistency loss (controlled by --sc_ratio). The split is sketched below, and the launch command comes after it.
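How the --sc_ratio split might look in practice, assuming it is applied per batch (the actual granularity in train/training.py may differ):

# Illustrative sketch only.
import torch

def split_batch(batch: dict, sc_ratio: float):
    """Split a batch of tensors into a self-consistency part and a diffusion part."""
    n_sc = int(round(sc_ratio * next(iter(batch.values())).size(0)))
    sc_part = {k: v[:n_sc] for k, v in batch.items()}
    diff_part = {k: v[n_sc:] for k, v in batch.items()}
    return sc_part, diff_part

batch = {"input_ids": torch.randint(0, 100, (8, 16))}
sc_part, diff_part = split_batch(batch, sc_ratio=0.75)
print(sc_part["input_ids"].shape, diff_part["input_ids"].shape)  # 6 samples for SC, 2 for diffusion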
accelerate launch \
--multi_gpu \
--num_processes=8 \
--num_machines=1 \
--mixed_precision "bf16" \
train/training.py \
--batch_size 4 \
--num_epochs 2 \
--sc_ratio 0.75 \
--learning_rate 1e-5 \
--dataset_name "c4" \
--dataset_config "en" \
--freeze_pretrained_param \
--training_base_save_path "./ckpts" \
--finetuned_model_path "path_to_sllada_after_stage1"

Stage 3: unfreeze and train the full model on another dataset, again splitting the data between the diffusion loss and the self-consistency loss. Note that --freeze_pretrained_param is omitted from the command below; a one-line sketch of the unfreezing comes first.
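A minimal sketch of what dropping --freeze_pretrained_param amounts to (illustrative, not the repo's code):

def unfreeze_all_params(model):
    # Every parameter, pretrained and newly added, receives gradients in stage 3.
    for param in model.parameters():
        param.requires_grad = True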
accelerate launch \
--multi_gpu \
--num_processes=8 \
--num_machines=1 \
--mixed_precision "bf16" \
train/training.py \
--batch_size 4 \
--num_epochs 2 \
--sc_ratio 0.75 \
--learning_rate 1e-5 \
--dataset_name "c4" \
--dataset_config "en" \
--training_base_save_path "./ckpts" \
--finetuned_model_path "path_to_sllada_after_stage2"

To resume an interrupted run, add --from_checkpoint. For stage 1 or stage 2 (pretrained weights still frozen), a sketch of the assumed checkpointing mechanism and the launch command follow:
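A minimal sketch of the save/resume cycle we assume --from_checkpoint relies on, using Hugging Face Accelerate's state checkpointing; the checkpoint directory layout under ./ckpts is an assumption.

# Illustrative sketch only.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)                                # stand-in for shortcut-LLaDA
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)

accelerator.save_state("./ckpts/example_checkpoint")         # during training: save model/optimizer/RNG state
accelerator.load_state("./ckpts/example_checkpoint")         # on resume: restore that state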
accelerate launch \
--multi_gpu \
--num_processes=8 \
--num_machines=1 \
--mixed_precision "bf16" \
train/training.py \
--batch_size 4 \
--num_epochs 1 \
--sc_ratio 0.0 \
--learning_rate 1e-5 \
--dataset_name "wikitext" \
--dataset_config "wikitext-103-raw-v1" \
--freeze_pretrained_param \
--from_checkpoint \
--training_base_save_path "./ckpts"

For stage 3:
accelerate launch \
--multi_gpu \
--num_processes=8 \
--num_machines=1 \
--mixed_precision "bf16" \
train/training.py \
--batch_size 4 \
--num_epochs 1 \
--sc_ratio 0.75 \
--learning_rate 1e-5 \
--dataset_name "wikitext" \
--dataset_config "wikitext-103-raw-v1" \
--from_checkpoint \
--training_base_save_path "./ckpts"