Add moe pretrain #49

Open · wants to merge 8 commits into main
46 changes: 46 additions & 0 deletions pretrain/installers/v3-moe-recipes-sakura/README.md
@@ -0,0 +1,46 @@
# moe-recipes installation script for Sakura/LLM-jp-3 models

This script installs a moe-recipes environment on the Sakura cluster for running MoE experiments with LLM-jp-3 models.
It builds a self-contained environment in the specified directory that does not depend on the system Python or pyenv.

## Usage

### Build

The installation uses one CPU node of the Sakura cluster.
It takes a while, so please be patient.

```shell
git clone https://github.com/llm-jp/scripts
cd pretrain/installers/v3-moe-recipes-sakura

# Install the environment into ~/myspace.
sbatch install.sh ~/myspace
```

### Check

After the installation finishes, the following directory structure will have been created.

```
~/myspace/
    example/              Sample scripts
    installer_envvar.log  List of environment variables recorded after installation started
    install.sh            The installation script that was used
    python/               Python runtime
    requirements.txt      List of libraries preinstalled into the venv
    scripts/              Scripts for configuring the environment
    src/                  Individually downloaded libraries
    venv/                 Python virtual environment (linked to python/)
```
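To confirm that the environment itself is usable, activate the venv and import the main libraries. A minimal sketch:

```shell
cd ~/myspace
source venv/bin/activate

# Both packages are installed into the venv by the installer;
# the versions are printed only as a quick sanity check.
python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"
```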

Check whether a pretraining job can be launched correctly from the installed environment.

```shell
cd ~/myspace

# By default this launches a job that occupies one node and uses all 8 GPUs.
sbatch example/sbatch.sh

# Once you confirm that train loss is being logged to W&B, stop the job.
```
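To stop the test job, cancel it with Slurm. A minimal sketch; you can target either the job ID printed by `sbatch` or the job name set in `example/sbatch.sh`:

```shell
# The example script sets --job-name=pretrain-test;
# alternatively use `scancel <jobid>`.
scancel --name=pretrain-test
```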
49 changes: 49 additions & 0 deletions pretrain/installers/v3-moe-recipes-sakura/example/checkpoint_init.py
@@ -0,0 +1,49 @@
import os
import random

import numpy as np
import torch
from transformers import AutoConfig, AutoModelForCausalLM


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def save_model_and_config(model_name, save_directory):

    set_seed(1234)

    # Check if the directory already exists
    if os.path.exists(save_directory):
        print(f"Directory {save_directory} already exists. Skipping model saving.")
        return

    # Load config
    config = AutoConfig.from_pretrained(model_name)
    print(f"Config loaded from {model_name}")

    # Create model from config
    model = AutoModelForCausalLM.from_config(config)
    print("Model created from config")

    # Create save directory
    os.makedirs(save_directory)

    # Save model and config
    model.save_pretrained(save_directory)
    config.save_pretrained(save_directory)

    print(f"Model and config have been saved to {save_directory}")


if __name__ == "__main__":
    model_name = "example/config.json"
    save_directory = "Mixtral-llm-jp-v3-8x1.8B-initial-checkpoint"

    save_model_and_config(model_name, save_directory)
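`example/sbatch.sh` runs this script automatically before launching training, but it can also be run on its own from the installation root to pre-build the initial checkpoint. A minimal sketch:

```shell
cd ~/myspace
source venv/bin/activate

# Creates Mixtral-llm-jp-v3-8x1.8B-initial-checkpoint/ from example/config.json;
# the script skips saving if the directory already exists.
python example/checkpoint_init.py
```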
31 changes: 31 additions & 0 deletions pretrain/installers/v3-moe-recipes-sakura/example/config.json
@@ -0,0 +1,31 @@
{
  "_name_or_path": "",
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 7168,
  "max_position_embeddings": 4096,
  "model_type": "mixtral",
  "num_attention_heads": 16,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 24,
  "num_key_value_heads": 16,
  "num_local_experts": 8,
  "output_router_logits": true,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "router_aux_loss_coef": 0.01,
  "router_jitter_noise": 0.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.0.dev0",
  "use_cache": false,
  "vocab_size": 99574
}
139 changes: 139 additions & 0 deletions pretrain/installers/v3-moe-recipes-sakura/example/sbatch.sh
@@ -0,0 +1,139 @@
#!/bin/bash
#
# Example sbatch launcher script for pretraining tasks.
#
# This script only constructs cluster-related environment variables and immediately
# calls mpirun with train.sh, which implements the actual invocation of the moe-recipes
# trainer script.
#
# This script is installed together with other tools so that you can check if the
# installed environment works as expected by launching the job using this script.
#
# Usage:
# 1. cd {root directory that you installed training scripts}
# 2. sbatch example/sbatch.sh

#SBATCH --job-name=pretrain-test
#SBATCH --partition=gpu-debug
#SBATCH --nodes=1
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8
#SBATCH --output=%x-%j.out
#SBATCH --error=%x-%j.err

set -eu -o pipefail

source scripts/environment.sh
source venv/bin/activate

# CUTLASS
CUTLASS_HOME=src/cutlass/build
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CUTLASS_HOME}/lib

export MASTER_ADDR="$(scontrol show hostname $SLURM_JOB_NODELIST | head -n1)"
export MASTER_PORT=12800

echo "MASTER_ADDR=${MASTER_ADDR}"

NUM_NODES=$SLURM_JOB_NUM_NODES
NUM_GPUS_PER_NODE=$(echo $SLURM_TASKS_PER_NODE | cut -d '(' -f 1)
NUM_GPUS=$((${NUM_NODES} * ${NUM_GPUS_PER_NODE}))

echo NUM_NODES=$NUM_NODES
echo NUM_GPUS_PER_NODE=$NUM_GPUS_PER_NODE
echo NUM_GPUS=$NUM_GPUS

MICRO_BATCH_SIZE=4
GLOBAL_BATCH_SIZE=1024
GRADIENTS_ACCUMULATION_STEPS=$((GLOBAL_BATCH_SIZE / MICRO_BATCH_SIZE / NUM_GPUS))
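# Example: with the defaults above on a single 8-GPU node, this evaluates to 1024 / 4 / 8 = 32.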

if [ $GRADIENTS_ACCUMULATION_STEPS -lt 1 ]; then
    echo "Global batch size is too small for the number of GPUs"
    exit 1
fi

WEIGHT_DECAY=0.1
GRAD_CLIP=1

# deepspeed config
DEEPSPEED_CONFIG="src/moe-recipes/deepspeed_config.json"

BF16_ENABLED=true
DEEPSPEED_ZERO_STAGE=3

OVERLAP_COMMUNICATION=true
CONTINUOUS_GRADIENTS=true

DEEPSPEED_SUB_GROUP_SIZE=1e12
DEEPSPEED_REDUCE_BUCKET_SIZE=1e9
DEEPSPEED_STAGE3_PREFETCH_BUCKET_SIZE=5e8
DEEPSPEED_STAGE3_PARAM_PERSISTENCE_THRESHOLD=1e6

DEEPSPEED_STAGE3_MAX_LIVE_PARAMETERS=1e9
DEEPSPEED_STAGE3_MAX_REUSE_DISTANCE=1e9

WALL_CLOCK_BREAKDOWN=false

DEEPSPEED_CONFIG_CONTENT=$(
cat <<EOF
{
  "bf16": {
    "enabled": $BF16_ENABLED
  },
  "data_types": {
    "grad_accum_dtype": "fp32"
  },
  "zero_optimization": {
    "stage": $DEEPSPEED_ZERO_STAGE,
    "overlap_comm": $OVERLAP_COMMUNICATION,
    "contiguous_gradients": $CONTINUOUS_GRADIENTS,
    "sub_group_size": $DEEPSPEED_SUB_GROUP_SIZE,
    "reduce_bucket_size": $DEEPSPEED_REDUCE_BUCKET_SIZE,
    "stage3_prefetch_bucket_size": $DEEPSPEED_STAGE3_PREFETCH_BUCKET_SIZE,
    "stage3_param_persistence_threshold": $DEEPSPEED_STAGE3_PARAM_PERSISTENCE_THRESHOLD,
    "stage3_max_live_parameters": $DEEPSPEED_STAGE3_MAX_LIVE_PARAMETERS,
    "stage3_max_reuse_distance": $DEEPSPEED_STAGE3_MAX_REUSE_DISTANCE
  },
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "gradient_accumulation_steps": $GRADIENTS_ACCUMULATION_STEPS,
  "gradient_clipping": $GRAD_CLIP,
  "wall_clock_breakdown": $WALL_CLOCK_BREAKDOWN
}
EOF
)

# write deepspeed config file
echo "$DEEPSPEED_CONGIG_CONTENT" >$DEEPSPEED_CONFIG

# Initialization
python example/checkpoint_init.py

PYTHONPATH=${PYTHONPATH:-}

if [ -z "$PYTHONPATH" ]; then
export PYTHONPATH="./src/moe-recipes:./src/moe-recipes/src"
else
export PYTHONPATH="./src/moe-recipes:./src/moe-recipes/src:${PYTHONPATH}"
fi

echo "PYTHONPATH is now: $PYTHONPATH"

mpirun \
-np $NUM_GPUS \
--npernode $NUM_GPUS_PER_NODE \
-bind-to none \
-map-by slot \
-x TORCH_NCCL_ASYNC_ERROR_HANDLING=1 \
-x LD_LIBRARY_PATH \
-x PATH \
-x MASTER_ADDR=$MASTER_ADDR \
-x MASTER_PORT=$MASTER_PORT \
-x NUM_NODES=$NUM_NODES \
-x NUM_GPUS_PER_NODE=$NUM_GPUS_PER_NODE \
-x GRADIENTS_ACCUMULATION_STEPS=$GRADIENTS_ACCUMULATION_STEPS \
-x MICRO_BATCH_SIZE=$MICRO_BATCH_SIZE \
-x GLOBAL_BATCH_SIZE=$GLOBAL_BATCH_SIZE \
-x DEEPSPEED_CONFIG=$DEEPSPEED_CONFIG \
-x DEEPSPEED_ZERO_STAGE=$DEEPSPEED_ZERO_STAGE \
-x PYTHONPATH=$PYTHONPATH \
bash example/train.sh
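The example above requests a single node on the gpu-debug partition. If the cluster allows it, the same launcher can be tried on more nodes by overriding the Slurm directives at submission time; a minimal sketch (partition names and node limits are cluster-specific, so adjust as needed):

```shell
# Command-line options to sbatch take precedence over the #SBATCH directives in the script.
sbatch --nodes=2 example/sbatch.sh
```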
108 changes: 108 additions & 0 deletions pretrain/installers/v3-moe-recipes-sakura/example/train.sh
@@ -0,0 +1,108 @@
#!/bin/bash
# Node-level moe-recipes launcher
#
# Environment variables that the script expects to be passed from mpirun:
# * TORCH_NCCL_ASYNC_ERROR_HANDLING: Enable/disable async error handling for NCCL (1 for enable, 0 for disable)
# * LD_LIBRARY_PATH: Library path for dynamic linking
# * PATH: System path for executable files
# * MASTER_ADDR: Address of the master node
# * MASTER_PORT: Port number of the master node
# * NUM_NODES: Number of nodes assigned for this task
# * NUM_GPUS_PER_NODE: Number of GPUs in the node assigned for this task
# * GRADIENTS_ACCUMULATION_STEPS: Number of gradient accumulation steps
# * MICRO_BATCH_SIZE: Micro batch size for training
# * GLOBAL_BATCH_SIZE: Global batch size for training
# * DEEPSPEED_CONFIG: Path to DeepSpeed configuration file
# * DEEPSPEED_ZERO_STAGE: DeepSpeed ZeRO stage
# * PYTHONPATH: Python module search path

set -eu -o pipefail

source scripts/environment.sh
source scripts/mpi_variables.sh
source venv/bin/activate

export LOGLEVEL=INFO
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN
export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=0
export CUDNN_LOGDEST_DBG=stderr
export CUDNN_LOGERR_DBG=1

NUM_GPUS=$((${NUM_NODES} * ${NUM_GPUS_PER_NODE}))

# training config
SEQ_LENGTH=4096
SLIDING_WINDOW_SIZE=4096
DATA_PARALLEL_SIZE=$NUM_GPUS

TRAIN_STEPS=5000

# optimizer config
LR=3e-4
MIN_LR=3e-5
LR_WARMUP_STEPS=1000
LR_DECAY_STEPS=5000
WEIGHT_DECAY=0.1
GRAD_CLIP=1

ADAMW_BETA1=0.9
ADAMW_BETA2=0.95
ADAMW_EPS=1E-8

# model config
TOKENIZER_MODEL=src/llm-jp-tokenizer/models/ver3.0/llm-jp-tokenizer-100k.ver3.0b1.model

CHECKPOINT_DIR=Mixtral-llm-jp-v3-8x1.8B-initial-checkpoint
CHECKPOINT_SAVE_DIR=checkpoints

mkdir -p ${CHECKPOINT_SAVE_DIR}

# data config
DATASET_DIR=/data/llm-jp-corpus/v3.0.0/training_resharded_tokenize_ver3.0
DATA_PATH=""
DATA_PATH="${DATA_PATH} 2563804308 ${DATASET_DIR}/train/ja/wiki_0000.jsonl_text_document"

# job name
JOB_NAME="test-$(whoami)"

python src/moe-recipes/examples/finetuning.py \
--seq-length ${SEQ_LENGTH} \
--sliding-window-size ${SLIDING_WINDOW_SIZE} \
--micro-batch-size ${MICRO_BATCH_SIZE} \
--global-batch-size ${GLOBAL_BATCH_SIZE} \
--train-iters ${TRAIN_STEPS} \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model ${TOKENIZER_MODEL} \
--data-path ${DATA_PATH} \
--split 99999,1,0 \
--lr ${LR} \
--min-lr ${MIN_LR} \
--lr-decay-style cosine \
--lr-warmup-iters ${LR_WARMUP_STEPS} \
--lr-decay-iters ${LR_DECAY_STEPS} \
--weight-decay ${WEIGHT_DECAY} \
--grad-clip-norm ${GRAD_CLIP} \
--optimizer adam \
--adam-beta1 $ADAMW_BETA1 \
--adam-beta2 $ADAMW_BETA2 \
--adam-eps $ADAMW_EPS \
--save-interval 10 \
--eval-interval 1000000000 \
--eval-iters 1 \
--bf16 \
--mixed-precision \
--base-model ${CHECKPOINT_DIR} \
--save ${CHECKPOINT_SAVE_DIR} \
--load ${CHECKPOINT_SAVE_DIR} \
--use-zero \
--zero-config ${DEEPSPEED_CONFIG} \
--zero-stage ${DEEPSPEED_ZERO_STAGE} \
--no-meta-device \
--output-router-logits \
--use-mpi \
--continual-pretraining \
--wandb-entity "llm-jp" \
--wandb-project "sakura-test-moe" \
--wandb-name "${JOB_NAME}"