Add installer for model pre-training with Megatron-LM for llm-jp cluster #57

Draft
wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions pretrain/installers/v3-megatron-llm-jp/.gitignore
@@ -0,0 +1,2 @@
*.err
*.out
43 changes: 43 additions & 0 deletions pretrain/installers/v3-megatron-llm-jp/README.md
@@ -0,0 +1,43 @@
# Megatron installation script for LLM-jp v3 models on the llm-jp cluster

This script installs the Megatron environment used to build LLM-jp v3 models onto the llm-jp cluster.
It creates a self-contained environment in the specified directory, with no dependency on the system Python or pyenv.

## Usage

### Build

The installation uses CPU resources and takes a while, so please be patient.

```shell
git clone https://github.com/llm-jp/scripts
cd scripts/pretrain/installers/v3-megatron-llm-jp

# Install the environment into ~/myspace.
bash install.sh ~/myspace
```
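
Because the build takes a while, you may prefer to run the installer detached from your terminal. The pattern below is optional and not part of the installer itself:

```shell
# Optional: run the installer in the background and keep its output in a log file.
nohup bash install.sh ~/myspace > install.log 2>&1 &
tail -f install.log
```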

### Check

After the installation finishes, the following directory structure will have been created:

```
~/myspace/
    example/             Sample scripts
    installer_envvar.log Environment variables recorded at the start of installation
    install.sh           The installation script that was used
    python/              Python runtime
    requirements.txt     List of libraries pre-installed into the venv
    scripts/             Scripts for configuring the environment
    src/                 Individually downloaded libraries
    venv/                Python virtual environment
```

Check that a pretraining job can be launched successfully from the installed environment:

```shell
cd ~/myspace
bash example/mpi_wrapper.sh
```
Once you have confirmed that the training loss is being logged to W&B, stop the job.
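
As a lighter-weight check before launching the full job, you can also activate the installed venv directly and confirm that PyTorch sees the node's GPUs. This is just a quick sanity check, not part of the installer:

```shell
cd ~/myspace
source venv/bin/activate
# Print the installed PyTorch version and whether CUDA devices are visible.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```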
44 changes: 44 additions & 0 deletions pretrain/installers/v3-megatron-llm-jp/example/mpi_wrapper.sh
@@ -0,0 +1,44 @@
#!/bin/bash
#
# Example launcher script of pretraining tasks.
#
# This script only constructs cluster-related environment variables, and immediately
# calls mpirun with train.sh, which implements an actual invocation of the Megatron-LM
# trainer script.
#
# This script is installed together with other tools so that you can check if the
# installed environment works as expected by launching the job using this script.
#
# Usage:
# 1. cd {root directory that you installed training scripts}
# 2. bash example/mpi_wrapper.sh

set -eu -o pipefail

source scripts/environment.sh
source venv/bin/activate

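# Use the address of the second interface reported by `ip -br addr` (the first is
# typically the loopback device) as the rendezvous address for torch.distributed.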
export MASTER_ADDR="$(ip -br addr | sed -n 2p | awk '{print $3}' | cut -d'/' -f1)"
export MASTER_PORT=12800

echo "MASTER_ADDR=${MASTER_ADDR}"

NUM_NODES=1
NUM_GPUS_PER_NODE=$(nvidia-smi -L | wc -l)
NUM_GPUS=$((${NUM_NODES} * ${NUM_GPUS_PER_NODE}))

echo NUM_NODES=$NUM_NODES
echo NUM_GPUS_PER_NODE=$NUM_GPUS_PER_NODE
echo NUM_GPUS=$NUM_GPUS

mpirun \
-np $NUM_GPUS \
--npernode $NUM_GPUS_PER_NODE \
-bind-to none \
-map-by slot \
-x MASTER_ADDR=$MASTER_ADDR \
-x MASTER_PORT=$MASTER_PORT \
-x NUM_NODES=$NUM_NODES \
-x NUM_GPUS_PER_NODE=$NUM_GPUS_PER_NODE \
bash example/train.sh

129 changes: 129 additions & 0 deletions pretrain/installers/v3-megatron-llm-jp/example/train.sh
@@ -0,0 +1,129 @@
#!/bin/bash
# Node-level Megatron-LM launcher
#
# Environment variables that the script expects to be passed from mpirun:
# * MASTER_ADDR: Address of the master node
# * MASTER_PORT: Port number of the master node
# * NUM_NODES: Number of nodes assigned for this task
# * NUM_GPUS_PER_NODE: Number of GPUs in the node assigned for this task

set -eu -o pipefail

source scripts/environment.sh
source scripts/mpi_variables.sh
source venv/bin/activate

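# Logging and debugging knobs: limit NCCL output to warnings, enable the Python
# fault handler, and route cuDNN error logs to stderr.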
export LOGLEVEL=INFO
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN
export PYTHONFAULTHANDLER=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export CUDA_LAUNCH_BLOCKING=0
export CUDNN_LOGDEST_DBG=stderr
export CUDNN_LOGERR_DBG=1

NUM_GPUS=$((${NUM_NODES} * ${NUM_GPUS_PER_NODE}))

# model config
HIDDEN_SIZE=1024
FFN_HIDDEN_SIZE=4096
NUM_LAYERS=12
NUM_HEADS=8
SEQ_LENGTH=2048

# distributed settings
TENSOR_PARALLEL_SIZE=1
PIPELINE_PARALLEL_SIZE=1
CONTEXT_PARALLEL_SIZE=1
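# Data parallelism uses whatever remains after tensor/pipeline parallelism:
# DP = NUM_GPUS / (TP * PP).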
DATA_PARALLEL_SIZE=$((${NUM_GPUS} / (${TENSOR_PARALLEL_SIZE} * ${PIPELINE_PARALLEL_SIZE})))

# training config
MICRO_BATCH_SIZE=8
GLOBAL_BATCH_SIZE=1024

LR=1e-4
MIN_LR=1e-5
WEIGHT_DECAY=0.1
GRAD_CLIP=1

LR_WARMUP_STEPS=1000
LR_DECAY_ITERS=5000
TRAIN_STEPS=$((${LR_WARMUP_STEPS} + ${LR_DECAY_ITERS}))

# tokenizer and checkpoint config
TOKENIZER_MODEL=src/llm-jp-tokenizer/models/ver3.0/tokenize/v3.0b1
CHECKPOINT_SAVE_DIR=checkpoints

mkdir -p ${CHECKPOINT_SAVE_DIR}

# data config
DATASET_DIR=/model/llm-jp-corpus/v3.1.0/training_resharded_tokenize_ver3.0
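# Megatron's --data-path takes alternating sampling weights and dataset path prefixes;
# the number below is the weight assigned to the dataset that follows it.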
DATA_PATH=""
DATA_PATH="${DATA_PATH} 15074855804 ${DATASET_DIR}/train2/ja/warp-pdf-e00_0000.jsonl_text_document"

# job name
JOB_NAME="test-$(whoami)"

# checkpoint load
# first training
CHECKPOINT_ARGS="" #"--load ${CHECKPOINT_SAVE_DIR} --no-load-rng --no-load-optim"
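# To resume from a previous run, set CHECKPOINT_ARGS to the commented-out --load options above.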

python src/Megatron-LM/pretrain_gpt.py \
--tensor-model-parallel-size ${TENSOR_PARALLEL_SIZE} \
--pipeline-model-parallel-size ${PIPELINE_PARALLEL_SIZE} \
--context-parallel-size ${CONTEXT_PARALLEL_SIZE} \
--sequence-parallel \
--use-distributed-optimizer \
--num-layers ${NUM_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
--ffn-hidden-size ${FFN_HIDDEN_SIZE} \
--num-attention-heads ${NUM_HEADS} \
--seq-length ${SEQ_LENGTH} \
--max-position-embeddings ${SEQ_LENGTH} \
--micro-batch-size ${MICRO_BATCH_SIZE} \
--global-batch-size ${GLOBAL_BATCH_SIZE} \
--train-iters ${TRAIN_STEPS} \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model ${TOKENIZER_MODEL} \
${CHECKPOINT_ARGS} \
--save ${CHECKPOINT_SAVE_DIR} \
--data-path ${DATA_PATH} \
--split 1,0,0 \
--distributed-backend nccl \
--init-method-std 0.02 \
--lr ${LR} \
--min-lr ${MIN_LR} \
--lr-decay-style cosine \
--lr-decay-iters ${LR_DECAY_ITERS} \
--weight-decay ${WEIGHT_DECAY} \
--clip-grad ${GRAD_CLIP} \
--lr-warmup-iters ${LR_WARMUP_STEPS} \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--adam-eps 1e-8 \
--log-interval 1 \
--eval-interval 1000000000 \
--eval-iters 1 \
--bf16 \
--untie-embeddings-and-output-weights \
--no-position-embedding \
--position-embedding-type rope \
--disable-bias-linear \
--use-mcore-models \
--normalization RMSNorm \
--norm-epsilon 1e-5 \
--no-masked-softmax-fusion \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--swiglu \
--use-flash-attn \
--recompute-activations \
--recompute-granularity "selective" \
--attention-softmax-in-fp32 \
--transformer-impl "transformer_engine" \
--use-mpi \
--use-z-loss \
--wandb-name ${JOB_NAME} \
--wandb-project "mdx-test" \
--wandb-entity "llm-jp"
109 changes: 109 additions & 0 deletions pretrain/installers/v3-megatron-llm-jp/install.sh
@@ -0,0 +1,109 @@
#!/bin/bash
#
# Megatron installation script for pretrain jobs on the llm-jp cluster
#
# Usage:
# 1. Set the working directory to the directory this file is located.
# 2. Run `bash install.sh TARGET_DIR` with setting TARGET_DIR to the actual path.
#
# CAUTION:
# DO NOT change the content of this file and any other materials in the installer
# directory while the installation is being processed.

set -eux -o pipefail

if [ $# -ne 1 ]; then
>&2 echo Usage: bash install.sh TARGET_DIR
exit 1
fi

INSTALLER_DIR=$(pwd)
TARGET_DIR=$1; shift

>&2 echo INSTALLER_DIR=$INSTALLER_DIR
>&2 echo TARGET_DIR=$TARGET_DIR

mkdir ${TARGET_DIR}
pushd ${TARGET_DIR}

# copy basic scripts
cp -a ${INSTALLER_DIR}/{install.sh,requirements.txt,scripts,example} .

source scripts/environment.sh

# record current environment variables
set > installer_envvar.log

# src is used to store all resources for from-scratch builds
mkdir src
pushd src

# install Python
git clone https://github.com/python/cpython -b v${PRETRAIN_PYTHON_VERSION}
pushd cpython
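# --enable-optimizations builds CPython with profile-guided optimization
# (slower build, faster interpreter).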
./configure --prefix="${TARGET_DIR}/python" --enable-optimizations
make -j 64
make install
popd

popd # src

# prepare venv
python/bin/python3 -m venv venv
source venv/bin/activate
python -m pip install --no-cache-dir -U pip

# install PyTorch
python -m pip install \
--no-cache-dir \
--find-links https://download.pytorch.org/whl/torch_stable.html \
torch==${PRETRAIN_TORCH_VERSION}+cu${PRETRAIN_CUDA_VERSION_SHORT} \
torchvision==${PRETRAIN_TORCHVISION_VERSION}+cu${PRETRAIN_CUDA_VERSION_SHORT}

# install other requirements
python -m pip install --no-cache-dir -U -r requirements.txt

pushd src

# install apex
git clone --recurse-submodules https://github.com/NVIDIA/apex -b ${PRETRAIN_APEX_VERSION}
pushd apex
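# --cpp_ext/--cuda_ext build apex's C++ and CUDA extensions (fused kernels) used by Megatron-LM.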
python -m pip install \
-v \
--no-cache-dir \
--no-build-isolation \
--config-settings "--build-option=--cpp_ext" \
--config-settings "--build-option=--cuda_ext" \
./
popd

# install transformer engine
# NOTE(odashi):
# This implicitly installs flash-attn with their recommended version.
# If the auto-installed flash-attn causes some problems, we need to re-install it.
python -m pip install \
-v \
--no-cache-dir \
--no-build-isolation \
git+https://github.com/NVIDIA/TransformerEngine.git@v${PRETRAIN_TRANSFORMER_ENGINE_VERSION}

# download our Megatron and build helper library
git clone https://github.com/llm-jp/Megatron-LM -b ${PRETRAIN_MEGATRON_TAG}
pushd Megatron-LM/megatron/core/datasets/
# NOTE(odashi):
# The original Makefile in this directory uses the system's (or pyenv's) python3-config,
# but we need to invoke the python3-config installed in our target directory.
MEGATRON_HELPER_CPPFLAGS=(
-O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
$(python -m pybind11 --includes)
)
MEGATRON_HELPER_EXT=$(${TARGET_DIR}/python/bin/python3-config --extension-suffix)
g++ ${MEGATRON_HELPER_CPPFLAGS[@]} helpers.cpp -o helpers${MEGATRON_HELPER_EXT}
popd

# download our tokenizer
git clone https://github.com/llm-jp/llm-jp-tokenizer -b ${PRETRAIN_TOKENIZER_TAG}

popd # src
popd # ${TARGET_DIR}
20 changes: 20 additions & 0 deletions pretrain/installers/v3-megatron-llm-jp/requirements.txt
@@ -0,0 +1,20 @@
accelerate==0.33.0
cmake==3.30.1
deepspeed==0.14.4
einops==0.8.0
mpi4py==4.0.0
ninja==1.11.1.1
nltk==3.8.1
numpy==1.26.4
packaging==24.1
pybind11==2.13.1
regex==2024.7.24
safetensors==0.4.3
sentencepiece==0.2.0
six==1.16.0
tensorboard==2.17.0
tensorstore==0.1.63
transformers==4.43.3
wandb==0.18.5
wheel==0.43.0
zarr==2.18.2
17 changes: 17 additions & 0 deletions pretrain/installers/v3-megatron-llm-jp/scripts/environment.sh
@@ -0,0 +1,17 @@
#!/bin/bash
# List of environment variables and module loads for pretrain tasks

export PRETRAIN_CUDA_VERSION_MAJOR=11
export PRETRAIN_CUDA_VERSION_MINOR=8
export PRETRAIN_CUDA_VERSION=${PRETRAIN_CUDA_VERSION_MAJOR}.${PRETRAIN_CUDA_VERSION_MINOR}
export PRETRAIN_CUDA_VERSION_SHORT=${PRETRAIN_CUDA_VERSION_MAJOR}${PRETRAIN_CUDA_VERSION_MINOR}
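# PRETRAIN_CUDA_VERSION_SHORT (e.g. 118) matches the +cuNNN suffix of the PyTorch wheels selected by install.sh.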

export PRETRAIN_PYTHON_VERSION=3.10.14
export PRETRAIN_TORCH_VERSION=2.3.1
export PRETRAIN_TORCHVISION_VERSION=0.18.1
export PRETRAIN_APEX_VERSION=24.04.01
export PRETRAIN_TRANSFORMER_ENGINE_VERSION=1.4
export PRETRAIN_MEGATRON_TAG=nii-geniac
# Ensure the appropriate Huggingface tokenizer is included
# https://github.com/llm-jp/scripts/pull/12#discussion_r1708415209
export PRETRAIN_TOKENIZER_TAG=v3.0b2
@@ -0,0 +1,5 @@
#!/bin/bash

# NCCL settings
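# Cluster-specific tuning: NCCL_IB_GID_INDEX selects the GID index used for
# InfiniBand/RoCE traffic and NCCL_IB_TC sets its traffic class.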
export NCCL_IB_GID_INDEX=3
export NCCL_IB_TC=106