Forked from NVIDIA/Megatron-LM

Commit df61e60 (parent c7549c4)

Showing 30 changed files with 2,461 additions and 14 deletions.
@@ -0,0 +1,4 @@
checkpoints/
data-cache/
tensorboard/
triton-cache/
@@ -0,0 +1,14 @@
FROM nvcr.io/nvidia/pytorch:23.12-py3

RUN pip uninstall -y causal-conv1d triton && \
    pip install causal-conv1d==1.2.2.post1 sentencepiece==0.1.99 triton==2.1.0 flask-restful

WORKDIR /tmp

RUN git clone https://github.com/state-spaces/mamba.git && \
    cd mamba && \
    git checkout v2.0.3 && \
    python setup.py install && \
    cd .. && \
    rm -rf mamba
@@ -0,0 +1,91 @@
# Mamba-based Language Models

## Introduction

This document is an entrypoint into the code used for
<em>[An Empirical Study of Mamba-based Language Models](https://arxiv.org/abs/2406.07887)</em>.

We are releasing the parameters for some of the models described in that
technical report via
[HuggingFace](https://huggingface.co/collections/nvidia/ssms-666a362c5c3bb7e4a6bcfb9c).

## Installation

Create and run a Docker container using the [Dockerfile](./Dockerfile).

```
docker build -t your_image_name:your_tag .
docker run --gpus all -it --rm \
  -v /path/to/megatron:/workspace/megatron \
  -v /path/to/dataset:/workspace/dataset \
  -v /path/to/checkpoints:/workspace/checkpoints \
  -w /workspace/megatron/examples/mamba \
  your_image_name:your_tag
```

## Train

[`train.sh`](./train.sh) is an example pretraining script, showing how to run on
a single node. Select between 800M-scale and 8B-scale models by setting the
`MODEL_SCALE` variable. The 8B-scale hybrid model architecture is the same as
the one described in the technical report.
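
For example (a sketch, with placeholder paths), a single-node run launched from
the `examples/mamba` directory looks like this; the script takes the data path
and the tokenizer model path as its two positional arguments:

```
# Pick the model size by editing MODEL_SCALE in train.sh ("800M" or "8B"),
# then launch from examples/mamba (the container's default working directory).
# The script starts torchrun with 8 GPUs on the local node and writes
# checkpoints, the data cache, and TensorBoard logs to ./checkpoints,
# ./data-cache, and ./tensorboard.
./train.sh <data-path> <tokenizer-path>
```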

## Text Generation

Use [`run_text_gen_server_8b.sh`](./run_text_gen_server_8b.sh) to start a text
generation server using an 8B hybrid checkpoint. This is configured to run the
8B hybrid model described in the technical report, with tensor model parallel
set to 1.

The arguments in the script will need to be changed if using a checkpoint with a
different model parallel configuration or other differences, such as model
architecture. For example, to run the 8B pure Mamba-2 model, change
`--hybrid-attention-ratio` and `--hybrid-mlp-ratio` to 0.0, or remove them.
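
A typical session, sketched with placeholder paths (the client command comes
from the comment at the top of the server script):

```
# Start the server with an 8B hybrid checkpoint and its tokenizer model:
./run_text_gen_server_8b.sh <checkpoint-path> <tokenizer-path>

# In a second shell, point the CLI client at the URL the server prints:
python ../../tools/text_generation_cli.py <URL-provided-by-server>
```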

Use [`run_text_gen_server_8b_gpt3.sh`](./run_text_gen_server_8b_gpt3.sh) to start
a text generation server using the 8B reference Transformer checkpoint.

## Checkpoint Formats

For inference, the model must be configured to match the checkpoint file used,
including the hybrid layer configuration and model parallel configuration.

If you need to convert a hybrid checkpoint file to a different tensor parallel
or pipeline parallel size, use
[the hybrid conversion script](../../tools/checkpoint/hybrid_conversion.py).
There is an example run command at the end of that file.

Before running that script, you will need to set `PYTHONPATH` to include the
root directory of your Megatron-LM repository clone.

```
export PYTHONPATH=<path-to-megatron>:$PYTHONPATH
```
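
With the container mounts shown in the Installation section, that typically
looks like the sketch below. The conversion arguments themselves are documented
by the example command at the end of `hybrid_conversion.py`; the `--help` call
is only an assumption that the script exposes a standard argparse-style CLI.

```
# Megatron-LM is mounted at /workspace/megatron in the example container.
export PYTHONPATH=/workspace/megatron:$PYTHONPATH

# The real arguments (such as the checkpoint locations and target parallel
# sizes) are given by the example command at the end of the script itself.
python ../../tools/checkpoint/hybrid_conversion.py --help
```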

## Hybrid Options

`--hybrid-attention-ratio ATT` specifies a target ratio of attention layers
to total layers. For example, 4 attention layers out of 48 total layers is
specified by `--hybrid-attention-ratio 0.08`.

`--hybrid-mlp-ratio MLP` specifies a target ratio of MLP layers to total
layers. For example, 24 MLP layers out of 48 total layers is specified by
`--hybrid-mlp-ratio 0.5`.

* (`ATT` + `MLP`) must be less than or equal to 1.0.
* (1.0 - `ATT` - `MLP`) is the hybrid Mamba ratio, the ratio of Mamba layers to
  total layers.
* `ATT` = `MLP` = 0 is a pure Mamba model.
* `ATT` = `MLP` = 0.5 is a transformer model.
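
To make the arithmetic concrete, here is a sketch of the flag combination
implied by the 48-layer example above (4 attention layers, 24 MLP layers, and
the remaining 20 layers Mamba):

```
#  4 attention layers / 48 total = 0.083..., rounded to 0.08 here
# 24 MLP layers       / 48 total = 0.5
# Mamba ratio: 1.0 - 0.08 - 0.5 = 0.42 (the remaining 20 of the 48 layers)
  --num-layers 48 \
  --hybrid-attention-ratio 0.08 \
  --hybrid-mlp-ratio 0.5
```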

If either `ATT` or `MLP` is greater than 0.0, or if `--hybrid-override-pattern`
is specified, the logfile will include information about the hybrid layer
pattern used. `--hybrid-override-pattern` can be used to specify a different
pattern than the default, algorithmically generated one.

## Mamba vs Mamba-2

This codebase currently only supports Mamba-2, and not the original version of
Mamba. However, the
[fixed snapshot of the code used for the technical report](https://github.com/NVIDIA/Megatron-LM/tree/ssm/examples/mamba)
can be configured to run the original version of Mamba.
@@ -0,0 +1,50 @@
#!/bin/bash

# Use: ./run_text_gen_server_8b.sh <checkpoint-path> <tokenizer-path>
# To launch the client: python ../../tools/text_generation_cli.py <URL-provided-by-server>

CHECKPOINT_PATH=$1
TOKENIZER_PATH=$2

DISTRIBUTED_ARGS="--nproc_per_node 1 \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_TIMEOUT=19
export NCCL_IB_QPS_PER_CONNECTION=4

export TRITON_CACHE_DIR="./triton-cache/"
export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager"

torchrun $DISTRIBUTED_ARGS ../../tools/run_mamba_text_generation_server.py \
       --tensor-model-parallel-size 1 \
       --pipeline-model-parallel-size 1 \
       --untie-embeddings-and-output-weights \
       --num-layers 56 \
       --hidden-size 4096 \
       --load ${CHECKPOINT_PATH} \
       --num-attention-heads 32 \
       --group-query-attention \
       --num-query-groups 8 \
       --hybrid-attention-ratio 0.08 \
       --hybrid-mlp-ratio 0.5 \
       --attention-dropout 0.0 \
       --hidden-dropout 0.0 \
       --disable-bias-linear \
       --normalization RMSNorm \
       --seq-length 4096 \
       --max-position-embeddings 4096 \
       --position-embedding-type none \
       --tokenizer-type GPTSentencePieceTokenizer \
       --tokenizer-model ${TOKENIZER_PATH} \
       --distributed-backend nccl \
       --distributed-timeout-minutes 1440 \
       --bf16 \
       --micro-batch-size 1 \
       --use-mcore-models \
       --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \
       --seed 42
@@ -0,0 +1,46 @@
#!/bin/bash

# Use: ./run_text_gen_server_8b_gpt3.sh <checkpoint-path> <tokenizer-path>
# To launch the client: python ../../tools/text_generation_cli.py <URL-provided-by-server>

CHECKPOINT_PATH=$1
TOKENIZER_PATH=$2

DISTRIBUTED_ARGS="--nproc_per_node 1 \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_TIMEOUT=19
export NCCL_IB_QPS_PER_CONNECTION=4

torchrun $DISTRIBUTED_ARGS ../../tools/run_text_generation_server.py \
       --tensor-model-parallel-size 1 \
       --pipeline-model-parallel-size 1 \
       --use-flash-attn \
       --apply-layernorm-1p \
       --untie-embeddings-and-output-weights \
       --num-layers 32 \
       --hidden-size 4096 \
       --load ${CHECKPOINT_PATH} \
       --num-attention-heads 32 \
       --attention-dropout 0.0 \
       --hidden-dropout 0.0 \
       --disable-bias-linear \
       --seq-length 4096 \
       --max-position-embeddings 4096 \
       --position-embedding-type rope \
       --rotary-percent 0.5 \
       --squared-relu \
       --tokenizer-type GPTSentencePieceTokenizer \
       --tokenizer-model ${TOKENIZER_PATH} \
       --distributed-backend nccl \
       --distributed-timeout-minutes 1440 \
       --bf16 \
       --micro-batch-size 1 \
       --use-mcore-models \
       --transformer-impl local \
       --seed 42
@@ -0,0 +1,105 @@
#!/bin/bash

# Use: ./train.sh <data-path> <tokenizer-path>

MODEL_SCALE="800M" # or "8B"

case "${MODEL_SCALE}" in
    "800M")
        TENSOR_MODEL_PARALLEL_SIZE=1
        NUM_LAYERS=48
        HIDDEN_SIZE=1024
        NUM_ATTENTION_HEADS=16
        GLOBAL_BATCH_SIZE=32
        ;;
    "8B")
        TENSOR_MODEL_PARALLEL_SIZE=4
        NUM_LAYERS=56
        HIDDEN_SIZE=4096
        NUM_ATTENTION_HEADS=32
        GLOBAL_BATCH_SIZE=8
        ;;
    *)
        echo "Invalid version specified"
        exit 1
        ;;
esac

DATA_PATH=$1
TOKENIZER_PATH=$2

export NCCL_IB_SL=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_TIMEOUT=19
export NCCL_IB_QPS_PER_CONNECTION=4

CHECKPOINT_DIR="./checkpoints"
DATACACHE_DIR="./data-cache"
TENSORBOARD_DIR="./tensorboard"

mkdir -p ${CHECKPOINT_DIR}
mkdir -p ${DATACACHE_DIR}
mkdir -p ${TENSORBOARD_DIR}

export TRITON_CACHE_DIR="./triton-cache/"
export TRITON_CACHE_MANAGER="megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager"

SEQ_LEN=4096
TRAIN_SAMPLES=73242188  # 300B tokens / 4096
LR_WARMUP_SAMPLES=50000
LR_DECAY_SAMPLES=73192188 # TRAIN_SAMPLES - LR_WARMUP_SAMPLES

options=" \
       --tensor-model-parallel-size ${TENSOR_MODEL_PARALLEL_SIZE} \
       --sequence-parallel \
       --pipeline-model-parallel-size 1 \
       --use-distributed-optimizer \
       --overlap-param-gather \
       --overlap-grad-reduce \
       --untie-embeddings-and-output-weights \
       --init-method-std 0.02 \
       --position-embedding-type none \
       --num-layers ${NUM_LAYERS} \
       --hidden-size ${HIDDEN_SIZE} \
       --num-attention-heads ${NUM_ATTENTION_HEADS} \
       --group-query-attention \
       --num-query-groups 8 \
       --hybrid-attention-ratio 0.08 \
       --hybrid-mlp-ratio 0.5 \
       --seq-length ${SEQ_LEN} \
       --max-position-embeddings ${SEQ_LEN} \
       --train-samples ${TRAIN_SAMPLES} \
       --lr-warmup-samples ${LR_WARMUP_SAMPLES} \
       --lr-decay-samples ${LR_DECAY_SAMPLES} \
       --save ${CHECKPOINT_DIR} \
       --load ${CHECKPOINT_DIR} \
       --data-path ${DATA_PATH} \
       --data-cache-path ${DATACACHE_DIR} \
       --split 99,1,0 \
       --tokenizer-type GPTSentencePieceTokenizer \
       --tokenizer-model ${TOKENIZER_PATH} \
       --distributed-backend nccl \
       --micro-batch-size 4 \
       --global-batch-size ${GLOBAL_BATCH_SIZE} \
       --lr 2.5e-4 \
       --min-lr 2.5e-5 \
       --lr-decay-style cosine \
       --weight-decay 0.1 \
       --clip-grad 1.0 \
       --attention-dropout 0.0 \
       --hidden-dropout 0.0 \
       --disable-bias-linear \
       --normalization RMSNorm \
       --adam-beta1 0.9 \
       --adam-beta2 0.95 \
       --log-interval 10 \
       --save-interval 2000 \
       --eval-interval 2000 \
       --eval-iters 32 \
       --bf16 \
       --use-mcore-models \
       --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \
       --no-create-attention-mask-in-dataloader \
       --tensorboard-dir ${TENSORBOARD_DIR}"

torchrun --nproc_per_node 8 ../../pretrain_mamba.py ${options}
@@ -0,0 +1 @@
from .mamba_model import MambaModel