sLLaDA

Prerequisite

conda create -n sllada python=3.10
conda activate sllada
pip install -r requirements_sllada_a100.txt
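
The training commands below assume A100-class GPUs with bf16 support. An optional sanity check (a small sketch, not part of the repo):

# Optional: verify that CUDA devices and bf16 support are available
# before launching the bf16 mixed-precision training runs below.
import torch

assert torch.cuda.is_available(), "No CUDA device visible"
print("GPUs:", torch.cuda.device_count())
print("bf16 supported:", torch.cuda.is_bf16_supported())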

Training

We use FSDP to train (finetune) our shortcut-LLaDA model. Shortcut-LLaDA is essentially the original LLaDA (8B parameters) plus a condition embedding and adaptive layer normalization (adaLN), which add about 0.13B parameters. A rough sketch of the added conditioning path is shown below.
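
A minimal sketch of what the condition embedding + adaLN path could look like; module and variable names are illustrative assumptions, not the repo's actual implementation:

# Sketch only: a scalar condition (e.g. the shortcut step size) is embedded
# and used to predict per-sample scale/shift for layer normalization.
import torch
import torch.nn as nn

class ConditionEmbedding(nn.Module):
    """Embed a scalar condition into a hidden-size vector."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, cond: torch.Tensor) -> torch.Tensor:
        # cond: (batch,) one scalar condition per sample
        return self.mlp(cond.unsqueeze(-1))

class AdaLN(nn.Module):
    """LayerNorm whose scale and shift are predicted from the condition embedding."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(hidden_size, 2 * hidden_size)

    def forward(self, x: torch.Tensor, cond_emb: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden); cond_emb: (batch, hidden)
        scale, shift = self.to_scale_shift(cond_emb).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)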

1 Pretraining

export PYTHONPATH=$PWD:$PYTHONPATH

# Work around AutoDL storage/network limits: put the Hugging Face cache and
# Python temp files on the data disk and use the HF mirror endpoint.
export HF_HOME="/root/autodl-tmp/hf_cache"
export TMPDIR="/root/autodl-tmp/python_tmp"
export HF_ENDPOINT="https://hf-mirror.com"

1.1 Stage 1

Train the partially frozen model (pretrained LLaDA weights frozen, only the new layers trainable) on a small dataset so that the new layers produce meaningful output. This stage uses only the diffusion loss, with no self-consistency loss (--sc_ratio 0.0); see the loss sketch after the command.

accelerate launch \
    --multi_gpu \
    --num_processes=8 \
    --num_machines=1 \
    --mixed_precision "bf16" \
    train/training.py \
    --batch_size 4 \
    --num_epochs 1 \
    --sc_ratio 0.0 \
    --learning_rate 1e-5 \
    --dataset_name "wikitext" \
    --dataset_config "wikitext-103-raw-v1" \
    --freeze_pretrained_param \
    --training_base_save_path "./ckpts" \
    --init_from_llada
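
For reference, the diffusion-only objective in this stage roughly follows LLaDA's masked-diffusion loss. A minimal sketch, assuming an HF-style model interface; names and the mask token id are assumptions, not the repo's code:

# Sketch of a masked-diffusion loss: sample a masking ratio t, mask tokens
# independently with probability t, and compute a reweighted cross entropy
# on the masked positions only.
import torch
import torch.nn.functional as F

MASK_TOKEN_ID = 126336  # LLaDA's [MASK] id; adjust if the tokenizer differs

def diffusion_loss(model, input_ids: torch.Tensor) -> torch.Tensor:
    b, l = input_ids.shape
    t = torch.rand(b, device=input_ids.device)          # masking ratio per sequence
    p_mask = t[:, None].expand(b, l)
    masked = torch.rand(b, l, device=input_ids.device) < p_mask
    noisy = torch.where(masked, MASK_TOKEN_ID, input_ids)

    logits = model(input_ids=noisy).logits              # assumes HF-style output
    ce = F.cross_entropy(logits[masked], input_ids[masked], reduction="none")
    return (ce / p_mask[masked]).sum() / (b * l)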

1.2 Stage 2

Train the partially frozen model on another, larger dataset, splitting training between the diffusion loss and the self-consistency loss (--sc_ratio 0.75); a sketch of how such a split could work follows the command.

accelerate launch \
    --multi_gpu \
    --num_processes=8 \
    --num_machines=1 \
    --mixed_precision "bf16" \
    train/training.py \
    --batch_size 4 \
    --num_epochs 2 \
    --sc_ratio 0.75 \
    --learning_rate 1e-5 \
    --dataset_name "c4" \
    --dataset_config "en" \
    --freeze_pretrained_param \
    --training_base_save_path "./ckpts" \
    --finetuned_model_path "path_to_sllada_after_stage1"
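
One way the --sc_ratio split could be applied is per batch, as sketched below; the repo may instead split at the dataset level, and the function name is a hypothetical placeholder:

# Sketch only: route a sc_ratio fraction of each batch to the self-consistency
# objective and the remainder to the plain diffusion loss.
import torch

def split_by_sc_ratio(input_ids: torch.Tensor, sc_ratio: float = 0.75):
    n_sc = int(input_ids.size(0) * sc_ratio)
    return input_ids[:n_sc], input_ids[n_sc:]

# e.g. with --batch_size 4 and --sc_ratio 0.75: 3 samples go to the
# self-consistency loss and 1 sample to the diffusion loss.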

1.3 Stage 3

Unfreeze the pretrained weights and train the full model on another dataset, again splitting training between the diffusion loss and the self-consistency loss. Note that --freeze_pretrained_param is omitted here.

accelerate launch \
    --multi_gpu \
    --num_processes=8 \
    --num_machines=1 \
    --mixed_precision "bf16" \
    train/training.py \
    --batch_size 4 \
    --num_epochs 2 \
    --sc_ratio 0.75 \
    --learning_rate 1e-5 \
    --dataset_name "c4" \
    --dataset_config "en" \
    --training_base_save_path "./ckpts" \
    --finetuned_model_path "path_to_sllada_after_stage2"

2 Continue training from checkpoint

For stage 1 or stage 2:

accelerate launch \
    --multi_gpu \
    --num_processes=8 \
    --num_machines=1 \
    --mixed_precision "bf16" \
    train/training.py \
    --batch_size 4 \
    --num_epochs 1 \
    --sc_ratio 0.0 \
    --learning_rate 1e-5 \
    --dataset_name "wikitext" \
    --dataset_config "wikitext-103-raw-v1" \
    --freeze_pretrained_param \
    --from_checkpoint \
    --training_base_save_path "./ckpts"

For stage 3:

accelerate launch \
    --multi_gpu \
    --num_processes=8 \
    --num_machines=1 \
    --mixed_precision "bf16" \
    train/training.py \
    --batch_size 4 \
    --num_epochs 1 \
    --sc_ratio 0.75 \
    --learning_rate 1e-5 \
    --dataset_name "wikitext" \
    --dataset_config "wikitext-103-raw-v1" \
    --from_checkpoint \
    --training_base_save_path "./ckpts"
