
Commit 61fe126

Merge pull request #10 from okoge-kaz/feature/gemma
Support Gemma-2, Llama-3-Instruct and Fix some features
2 parents f7da9a0 + f3a9b95 commit 61fe126

42 files changed: +1,577 −1,575 lines

.vscode/settings.json

+1
@@ -47,6 +47,7 @@
 "pbar",
 "peft",
 "plamo",
+"probs",
 "psutil",
 "pubmed",
 "samsum",

megatron_lm/megatron/core/datasets/Makefile

+1 −1
@@ -1,7 +1,7 @@
 CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
 CPPFLAGS += $(shell python3 -m pybind11 --includes)
 LIBNAME = helpers
-LIBEXT = $(shell ${PYENV_ROOT}/versions/3.10.12/bin/python3-config --extension-suffix)
+LIBEXT = $(shell ${PYENV_ROOT}/versions/3.11.9/bin/python3-config --extension-suffix)

 default: $(LIBNAME)$(LIBEXT)
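
The LIBEXT change only switches which pyenv interpreter names the compiled helpers extension. A quick sanity check (a sketch, not part of the commit; the 3.11.9 pyenv path is taken from the new Makefile and may differ on other machines):

# Show the extension suffix the Makefile will use for the helpers library.
${PYENV_ROOT}/versions/3.11.9/bin/python3-config --extension-suffix
# On Linux/CPython 3.11 this typically prints: .cpython-311-x86_64-linux-gnu.so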

requirements.txt

+1 −1
@@ -1,5 +1,5 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
-torch==2.2.2+cu121
+torch==2.3.1+cu121

 # huggingface
 transformers>=4.41.1
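
A quick way to confirm the new pin resolves (a sketch, assuming a CUDA 12.1 capable driver is installed; not part of the commit):

# Install from the updated requirements and check which torch build lands.
pip install -r requirements.txt
python -c "import torch; print(torch.__version__, torch.version.cuda)"
# Expected: 2.3.1+cu121 12.1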

scripts/abci/instruction/swallow-7b/swallow-7b-baseline.sh → scripts/abci/instruction/Llama-3-8B/Llama-3-8B-instruct-v0.2.sh (renamed)

+30 −29
@@ -1,16 +1,19 @@
 #!/bin/bash
 #$ -l rt_AF=2
-#$ -l h_rt=1:00:00:00
+#$ -l h_rt=0:01:00:00
 #$ -j y
-#$ -o outputs/instruction/swallow-7b/
+#$ -o outputs/instruction/Llama-3-8B/
 #$ -cwd

 # module load
 source /etc/profile.d/modules.sh
-module load cuda/11.8/11.8.0
-module load cudnn/8.9/8.9.2
-module load nccl/2.16/2.16.2-1
+module use /bb/llm/gaf51275/modules/modulefiles
+
+module load cuda/12.1/12.1.1
+module load cudnn/cuda-12.1/9.0.0
+module load nccl/2.20.5
 module load hpcx/2.12
+module load gcc/11.4.0

 # swich virtual env
 source .env/bin/activate
@@ -44,54 +47,52 @@ while read -r line; do
 done <"$SGE_JOB_HOSTLIST" >"$HOSTFILE_NAME"

 # training config
-SEQ_LENGTH=4096
+SEQ_LENGTH=8192
 DATA_PARALLEL_SIZE=$NUM_GPUS

-MICRO_BATCH_SIZE=4
-GLOBAL_BATCH_SIZE=64
+MICRO_BATCH_SIZE=1
+GLOBAL_BATCH_SIZE=128

 # optimizer config
-LR=2e-5
-MIN_LR=2e-6
+LR=1e-5
+MIN_LR=1e-6
 WEIGHT_DECAY=0.1
 GRAD_CLIP=1

-# checkpoint & tokenizer
-TOKENIZER_MODEL=/bb/llm/gaf51275/llama/huggingface-checkpoint/Swallow-7b-hf/tokenizer.model
-CHECKPOINT_DIR=/bb/llm/gaf51275/llama/huggingface-checkpoint/Swallow-7b-hf
-CHECKPOINT_SAVE_DIR="/bb/llm/gaf51275/llama/checkpoints/Swallow-7b-VE-chat/baseline-lr_${LR}-minlr_${MIN_LR}"
+# checkpoint
+TOKENIZER_DIR=/groups/gag51395/hf-checkpoints/Meta-Llama-3-8B-Instruct
+CHECKPOINT_DIR=/groups/gag51395/hf-checkpoints/Meta-Llama-3-8B-Instruct
+CHECKPOINT_SAVE_DIR="/bb/llm/gaf51275/2024/checkpoints/Llama-3-8B-Instruct-v0.2/LR_${LR}_MINLR_${MIN_LR}_WD_${WEIGHT_DECAY}_GC_${GRAD_CLIP}"

 mkdir -p ${CHECKPOINT_SAVE_DIR}

 # dataset
-DATASET_DIR=/bb/llm/gaf51275/llama/finetuning/datasets/training/baseline
+DATASET_DIR=/groups/gag51395/datasets/instruction/2023-swallow/training/baseline

 TRAIN_DATA_PATH=${DATASET_DIR}/train.jsonl
-VALID_DATA_PATH=${DATASET_DIR}/val.jsonl
+VALID_DATA_PATH=${DATASET_DIR}/train.jsonl

 # job name
-JOB_NAME="Swallow-7b-VE-baseline-BS=${GLOBAL_BATCH_SIZE}-LR=${LR}-MINLR=${MIN_LR}"
+JOB_NAME="Llama-3-8B-instruct-v0.2-BS=${GLOBAL_BATCH_SIZE}-LR=${LR}-MINLR=${MIN_LR}-WD=${WEIGHT_DECAY}-GC=${GRAD_CLIP}"

 # run
 mpirun -np $NUM_GPUS \
   --npernode $NUM_GPU_PER_NODE \
   -hostfile $HOSTFILE_NAME \
   -x MASTER_ADDR=$MASTER_ADDR \
   -x MASTER_PORT=$MASTER_PORT \
-  -bind-to none -map-by slot \
+  -bind-to none \
+  -x PATH \
+  -x LD_LIBRARY_PATH \
   -x PATH \
   python examples/finetuning.py \
   --seq-length ${SEQ_LENGTH} \
-  --sliding-window-size ${SEQ_LENGTH} \
   --micro-batch-size ${MICRO_BATCH_SIZE} \
   --global-batch-size ${GLOBAL_BATCH_SIZE} \
-  --hf-transformer-model-dir ${CHECKPOINT_DIR} \
-  --tokenizer-type Llama2Tokenizer \
-  --tokenizer-model ${TOKENIZER_MODEL} \
+  --hf-transformer-model-dir ${TOKENIZER_DIR} \
   --instruction-train-data-path ${TRAIN_DATA_PATH} \
   --instruction-valid-data-path ${VALID_DATA_PATH} \
-  --epoch 2 \
-  --train-iters 500000 \
+  --epoch 1 \
   --lr ${LR} \
   --min-lr ${MIN_LR} \
   --lr-decay-style cosine \
@@ -100,10 +101,10 @@ mpirun -np $NUM_GPUS \
   --optimizer adam \
   --adam-beta1 0.9 \
   --adam-beta2 0.95 \
-  --adam-eps 1e-6 \
+  --adam-eps 1e-8 \
   --save-interval 500 \
-  --eval-interval 100 \
-  --eval-iters 20 \
+  --eval-interval 500 \
+  --eval-iters 10 \
   --bf16 \
   --mixed-precision \
   --base-model ${CHECKPOINT_DIR} \
@@ -116,6 +117,6 @@ mpirun -np $NUM_GPUS \
   --instruction-tuning \
   --save-sampler-state \
   --use-mpi \
-  --wandb-entity "prj-jalm" \
-  --wandb-project "Llama-2-7b-instruct" \
+  --wandb-entity "okoge" \
+  --wandb-project "llm-recipes" \
   --wandb-name "${JOB_NAME}"
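
For reference, the new batch settings imply the gradient accumulation shown below (a sketch, not part of the diff; it assumes 2 ABCI rt_AF nodes with 8 GPUs each, i.e. DATA_PARALLEL_SIZE=16, whereas the script itself derives DATA_PARALLEL_SIZE from $NUM_GPUS at runtime):

# Accumulation steps = GLOBAL_BATCH_SIZE / (MICRO_BATCH_SIZE * DATA_PARALLEL_SIZE)
MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=128
DATA_PARALLEL_SIZE=16   # assumed: 2 nodes x 8 GPUs
echo $(( GLOBAL_BATCH_SIZE / (MICRO_BATCH_SIZE * DATA_PARALLEL_SIZE) ))   # -> 8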

scripts/abci/instruction/swallow-13b/swallow-13b-dolly-oasst2-top1-imitation-2-3.sh

-121
This file was deleted.

scripts/abci/instruction/swallow-13b/swallow-13b-oasst2-top1-imitation2-3.sh

-121
This file was deleted.
