# train-vicuna.yaml (forked from lm-sys/FastChat)
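# SkyPilot task for fine-tuning a Vicuna chatbot from LLaMA weights on ShareGPT data.
# Example launch (cluster name is illustrative; assumes SkyPilot is configured for GCP):
#   sky launch -c vicuna train-vicuna.yaml
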
resources:
  accelerators: A100-80GB:8
  disk_size: 1000
  use_spot: true

num_nodes: 1
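
# Cloud storage mounts:
#   /artifacts holds the converted Hugging Face weights and training checkpoints.
#   /data holds the ShareGPT training data.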
file_mounts:
  /artifacts:
    name: skypilot-chatbot # Change to your own bucket
    store: gcs
    mode: MOUNT
  /data:
    name: model-weights # Change to your own bucket
    store: gcs
    mode: MOUNT
  # /llama:
  #   name: llama-ckpts # Change to the bucket that contains the LLaMA weights
  #   store: gcs
  #   mode: MOUNT

workdir: .
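
# `workdir: .` (this FastChat repo checkout) is synced to ~/sky_workdir on the
# remote VM, which is where setup installs FastChat with `pip install -e .`.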
setup: |
  # Set up the conda environment
  conda create -n chatbot python=3.10 -y
  conda activate chatbot
  # Install pytorch
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
  # Install huggingface transformers at the commit with LLaMA support
  cd ~
  git clone https://github.com/huggingface/transformers.git
  cd transformers
  git checkout cae78c46 # pin to the LLaMA-support commit
  pip install .

  cd ~/sky_workdir
  # Install fastchat
  pip install -e .
  pip install flash-attn
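
  # One-time conversion of the raw LLaMA checkpoints to Hugging Face format.
  # The raw weights are read from the /llama mount (see the commented-out
  # file_mounts entry above); the converted copy is cached in /artifacts/llama-hf
  # so later runs skip the conversion.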
  mkdir -p /artifacts/llama-hf/llama-${MODEL_SIZE}B
  if [ ! -f /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete ]; then
    mkdir -p ~/llama-${MODEL_SIZE}b
    gsutil -m rsync -r /llama/${MODEL_SIZE}b/ ~/llama-${MODEL_SIZE}b
    cd ~/transformers
    python src/transformers/models/llama/convert_llama_weights_to_hf.py \
      --input_dir $HOME/llama-${MODEL_SIZE}b \
      --model_size ${MODEL_SIZE}B \
      --output_dir ~/hf-output || exit 1
    mv ~/hf-output/tokenizer/* ~/hf-output/llama-${MODEL_SIZE}b
    gsutil -m rsync -r ~/hf-output/llama-${MODEL_SIZE}b/ /artifacts/llama-hf/llama-${MODEL_SIZE}B
    touch /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete
  else
    mkdir -p ~/hf-output/llama-${MODEL_SIZE}b
    gsutil -m cp -r /artifacts/llama-hf/llama-${MODEL_SIZE}B/* ~/hf-output/llama-${MODEL_SIZE}b
  fi

run: |
  conda activate chatbot
  SEQ_LEN=${SEQ_LEN:-512}
  GC_SCALE=${GC_SCALE:-1}
  DATE=${DATE:-20230303}
  USE_FLASH_ATTN=${USE_FLASH_ATTN:-0}
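  # USE_FLASH_ATTN=1 selects the flash-attention trainer (fastchat/train/train_mem.py);
  # otherwise the standard trainer (fastchat/train/train.py) is used.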
  if [ $USE_FLASH_ATTN -eq 1 ]; then
    TRAIN_SCRIPT=fastchat/train/train_mem.py
    USE_FLASH_SUFFIX="-flash"
  else
    TRAIN_SCRIPT=fastchat/train/train.py
    USE_FLASH_SUFFIX=""
  fi
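  # Pick the per-device micro-batch so that PER_DEVICE_BATCH_SIZE * SEQ_LEN stays at
  # roughly 2048 * GC_SCALE tokens per device per forward pass.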
echo "Training with seq_len=${SEQ_LEN} and gc_scale=${GC_SCALE}"
PER_DEVICE_BATCH_SIZE=$((2048 * $GC_SCALE / $SEQ_LEN))
NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
  # Hack: copy the latest checkpoint from the bucket once, so that later reads are faster
  mkdir -p ~/.checkpoints
  CKPT_PATH=/artifacts/chatbot/${MODEL_SIZE}b/sharegpt-${DATE}-seq-${SEQ_LEN}-${USE_FLASH_SUFFIX}
  last_ckpt=$(ls ${CKPT_PATH} | grep -E '[0-9]+' | sort -n | tail -1)
  gsutil -m rsync -r ${CKPT_PATH}/${last_ckpt}/ ~/.checkpoints
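  # The --gradient_accumulation_steps expression below keeps the effective global batch
  # at about 128 * 512 / SEQ_LEN sequences (~64K tokens) per optimizer step, independent
  # of GC_SCALE, the number of nodes, and the GPUs per node.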
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path ~/hf-output/llama-${MODEL_SIZE}b \
    --data_path /data/sharegpt/sharegpt_20230322_clean_lang_split.json \
    --bf16 True \
    --output_dir $CKPT_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 100 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --gradient_checkpointing True \
    --lazy_preprocess True
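
# Default task environment variables used in setup/run above; they can be
# overridden at launch time with `--env` (e.g. `sky launch train-vicuna.yaml --env GC_SCALE=2`).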
envs:
  MODEL_SIZE: 13
  SEQ_LEN: 2048
  GC_SCALE: 4
  DATE: 20230322
  USE_FLASH_ATTN: 1