Medusa Decoding

This document shows how to build and run a model using Medusa decoding (GitHub, blog) in TensorRT-LLM on a single GPU and on a single node with multiple GPUs.

Overview

Unlike other models, Medusa decoding requires both a base model and Medusa heads.

The TensorRT-LLM Medusa decoding implementation is in tensorrt_llm/models/medusa/model.py; it adds MedusaHeads on top of a base model.
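To make "Medusa heads" concrete: in the Medusa paper, each head is a small residual block followed by a projection to vocabulary logits. All heads read the same last hidden state of the base model, and head k proposes the token k+1 positions past the one the base model's own LM head predicts. The pure-Python sketch below assumes that design and uses toy sizes. MedusaHeadSketch, top_k_ids and head_candidates are hypothetical names, not the classes in tensorrt_llm/models/medusa/model.py.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def silu(x: float) -> float:
    # SiLU activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def matvec(w: Matrix, x: Vector) -> Vector:
    # Dense matrix-vector product, one output per row of w
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


class MedusaHeadSketch:
    """One Medusa head (assumed design): residual block h + SiLU(W1 h + b1), then a vocab projection."""

    def __init__(self, w1: Matrix, b1: Vector, w_vocab: Matrix):
        self.w1 = w1            # hidden x hidden weights of the residual block
        self.b1 = b1            # hidden-sized bias of the residual block
        self.w_vocab = w_vocab  # vocab x hidden projection to logits (no bias)

    def __call__(self, h: Vector) -> Vector:
        z = matvec(self.w1, h)
        residual = [hi + silu(zi + bi) for hi, zi, bi in zip(h, z, self.b1)]
        return matvec(self.w_vocab, residual)


def top_k_ids(logits: Vector, k: int) -> List[int]:
    # Token ids sorted by descending logit, truncated to k
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


def head_candidates(heads: List[MedusaHeadSketch], h: Vector, k: int) -> List[List[int]]:
    # Every head sees the same hidden state; entry i holds head i's top-k guesses
    return [top_k_ids(head(h), k) for head in heads]


# Toy example: zero residual weights make each head a plain projection of h
zero_block = ([[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0])
head_a = MedusaHeadSketch(*zero_block, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
head_b = MedusaHeadSketch(*zero_block, [[0.0, 3.0], [1.0, 0.0], [0.0, 0.0]])
print(head_candidates([head_a, head_b], [1.0, 2.0], k=2))  # [[2, 1], [0, 1]]
```

Because the heads only add a few small layers on top of the shared hidden state, they can propose several future tokens for the cost of roughly one base-model step.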

Support Matrix

  • GPU Compute Capability >= 8.0 (Ampere or newer)
  • FP16
  • BF16
  • PAGED_KV_CACHE
  • Tensor Parallel
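The compute-capability and precision rows of the support matrix can be expressed as a simple pre-flight check. This is a sketch only: medusa_supported is a hypothetical helper, not part of TensorRT-LLM, and it does not encode the paged KV cache or tensor parallel rows.

```python
# Support-matrix rows from this README, encoded as data (helper name is hypothetical)
MIN_COMPUTE_CAPABILITY = (8, 0)             # Ampere or newer
SUPPORTED_DTYPES = {"float16", "bfloat16"}  # FP16 and BF16


def medusa_supported(compute_capability: tuple, dtype: str) -> bool:
    """Return True if the (major, minor) capability and dtype fall inside the listed matrix."""
    return tuple(compute_capability) >= MIN_COMPUTE_CAPABILITY and dtype in SUPPORTED_DTYPES


print(medusa_supported((8, 6), "float16"))  # True: Ampere with FP16
print(medusa_supported((7, 5), "float16"))  # False: below compute capability 8.0
```

With PyTorch installed, torch.cuda.get_device_capability() returns such a (major, minor) tuple for the current device.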

Usage

The TensorRT-LLM Medusa example code is located in examples/medusa. The convert_checkpoint.py script converts the Hugging Face weights into a TensorRT-LLM checkpoint, from which trtllm-build builds the TensorRT engine(s) needed to run models with Medusa decoding support. In our example, we use the Hugging Face model FasterDecoding/medusa-vicuna-7b-v1.3, which is LLaMA-based.

Build TensorRT engine(s)

Get the weights by downloading the base model vicuna-7b-v1.3 and the Medusa heads medusa-vicuna-7b-v1.3 from Hugging Face.

pip install -r requirements.txt

git lfs install
git clone https://huggingface.co/lmsys/vicuna-7b-v1.3
git clone https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3

We use the convert_checkpoint.py script to convert the model for Medusa decoding into the TensorRT-LLM checkpoint format. We also pass --fixed_num_medusa_heads 4, because config.json of medusa-vicuna-7b-v1.3 sets medusa_num_heads to 2 while the checkpoint actually contains 4 heads.
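The override rule amounts to "an explicit head count wins over medusa_num_heads in config.json". The sketch below shows that rule and one way to spot the mismatch by counting heads in the weights. Both function names are hypothetical (this is not convert_checkpoint.py's code), and the parameter naming "<head_index>.…" is an assumption about the checkpoint layout, used here with toy names.

```python
import json
from typing import Iterable, Optional


def count_heads_in_weights(param_names: Iterable[str]) -> int:
    """Count distinct head indices in Medusa parameter names.

    ASSUMPTION: every parameter of head i is named '<i>.<...>', e.g. '0.0.linear.weight'.
    """
    indices = set()
    for name in param_names:
        prefix = name.split(".", 1)[0]
        if prefix.isdigit():
            indices.add(int(prefix))
    return len(indices)


def resolve_num_medusa_heads(config_text: str, fixed_num_heads: Optional[int] = None) -> int:
    """An explicit override (like --fixed_num_medusa_heads) takes precedence over config.json."""
    if fixed_num_heads is not None:
        return fixed_num_heads
    return json.loads(config_text)["medusa_num_heads"]


config_text = '{"medusa_num_heads": 2}'  # what config.json declares
names = ["0.0.linear.weight", "1.1.weight", "2.1.weight", "3.0.linear.bias"]  # toy names
print(resolve_num_medusa_heads(config_text))                     # 2
print(count_heads_in_weights(names))                             # 4
print(resolve_num_medusa_heads(config_text, fixed_num_heads=4))  # 4
```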

Here is the example:

# Convert and Build Medusa decoding support for vicuna-7b-v1.3
python convert_checkpoint.py --model_dir ./vicuna-7b-v1.3 \
                            --medusa_model_dir medusa-vicuna-7b-v1.3 \
                            --output_dir ./tllm_checkpoint_1gpu_medusa \
                            --dtype float16 \
                            --fixed_num_medusa_heads 4

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_medusa \
             --output_dir ./tmp/medusa/7B/trt_engines/fp16/1-gpu/ \
             --gemm_plugin float16 \
             --max_batch_size 8

# Convert and Build Medusa decoding support for vicuna-13b-v1.3 with 4-way tensor parallelism.
# Assumes vicuna-13b-v1.3 and medusa-vicuna-13b-v1.3 were downloaded the same way as the 7B weights.
python convert_checkpoint.py --model_dir ./vicuna-13b-v1.3 \
                            --medusa_model_dir medusa-vicuna-13b-v1.3 \
                            --output_dir ./tllm_checkpoint_4gpu_medusa \
                            --dtype float16 \
                            --fixed_num_medusa_heads 4 \
                            --tp_size 4 \
                            --workers 4

trtllm-build --checkpoint_dir ./tllm_checkpoint_4gpu_medusa \
             --output_dir ./tmp/medusa/13B/trt_engines/fp16/4-gpu/ \
             --gemm_plugin float16 \
             --max_batch_size 8
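The checkpoint directory, engine directory and model directories must agree across convert_checkpoint.py, trtllm-build and the run commands, which is easy to get wrong when editing commands by hand. A hypothetical sketch (medusa_build_commands is not part of TensorRT-LLM) that derives every path from the model size and TP size and emits only the flags used in this section. It keeps --fixed_num_medusa_heads 4 for every size as the commands here do; confirm the head count for your Medusa weights.

```python
from typing import List, Tuple


def medusa_build_commands(model_size: str, tp_size: int = 1,
                          max_batch_size: int = 8) -> Tuple[List[str], List[str]]:
    """Return (convert_cmd, build_cmd) argv lists with matching directories."""
    size = model_size.lower()  # e.g. "7b" or "13b"
    ckpt_dir = f"./tllm_checkpoint_{tp_size}gpu_medusa"
    engine_dir = f"./tmp/medusa/{size.upper()}/trt_engines/fp16/{tp_size}-gpu/"
    convert = [
        "python", "convert_checkpoint.py",
        "--model_dir", f"./vicuna-{size}-v1.3",
        "--medusa_model_dir", f"medusa-vicuna-{size}-v1.3",
        "--output_dir", ckpt_dir,
        "--dtype", "float16",
        "--fixed_num_medusa_heads", "4",
    ]
    if tp_size > 1:
        # Tensor parallelism: split the weights and convert with one worker per rank
        convert += ["--tp_size", str(tp_size), "--workers", str(tp_size)]
    build = [
        "trtllm-build",
        "--checkpoint_dir", ckpt_dir,  # same directory convert_checkpoint.py wrote to
        "--output_dir", engine_dir,
        "--gemm_plugin", "float16",
        "--max_batch_size", str(max_batch_size),
    ]
    return convert, build


convert_cmd, build_cmd = medusa_build_commands("13b", tp_size=4)
print(" ".join(build_cmd))
```

Each list can be passed to subprocess.run(cmd, check=True) from examples/medusa.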

Run

To run a TensorRT-LLM model with Medusa decoding support, use the ../run.py script with an engine built with Medusa decoding support and the additional argument --medusa_choices, which has type list[list[int]].
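Each inner list of --medusa_choices is a path in a tree of draft tokens. Following the convention of the original Medusa implementation, the element at depth d is the rank (0 = most likely) of the token taken from the (d+1)-th Medusa head, so [0, 1] means the top token of the first head followed by the second-ranked token of the second head. The sketch below assumes the tree must be prefix-closed (every path's parent is also listed) and no deeper than the number of heads. parse_medusa_choices and describe_tree are hypothetical helpers; whether run.py performs the same checks is not stated here.

```python
import ast
from typing import Dict, List


def parse_medusa_choices(text: str) -> List[List[int]]:
    """Parse the --medusa_choices string into list[list[int]] without calling eval()."""
    choices = ast.literal_eval(text)
    for path in choices:
        if not (isinstance(path, list) and path
                and all(isinstance(i, int) and i >= 0 for i in path)):
            raise ValueError(f"invalid path: {path!r}")
    return choices


def describe_tree(choices: List[List[int]], num_medusa_heads: int) -> Dict[str, int]:
    """Check the tree is prefix-closed and fits the available heads; return its size."""
    paths = {tuple(p) for p in choices}
    for p in paths:
        # Every node below the first level needs its parent in the tree
        if len(p) > 1 and p[:-1] not in paths:
            raise ValueError(f"path {list(p)} is missing its parent {list(p[:-1])}")
    depth = max(len(p) for p in paths)
    if depth > num_medusa_heads:
        raise ValueError(f"tree depth {depth} exceeds {num_medusa_heads} Medusa heads")
    return {"nodes": len(paths), "depth": depth}


choices = parse_medusa_choices("[[0], [0, 0], [1], [0, 1], [0, 0, 0]]")
print(describe_tree(choices, num_medusa_heads=4))  # {'nodes': 5, 'depth': 3}
```

For the list used in the commands below, the deepest paths ([0, 0, 0, 0] and [0, 0, 0, 1]) have length 4, matching the 4 heads set with --fixed_num_medusa_heads.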

Note: Medusa decoding is currently supported only by the Python runtime, so --use_py_session is required.

Note: Medusa decoding currently supports only greedy decoding with temperature=1.0, so --temperature 1.0 is also required.
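Under greedy decoding, a drafted token is kept only if it equals the base model's argmax at that position, and each step still gains at least one token from the base model itself. A simplified sketch of this verification, assuming every candidate path has already been scored by the base model in one forward pass; accept_greedy is a hypothetical helper and ignores how the tree is batched and attended to in the actual implementation.

```python
from typing import List


def accept_greedy(candidates: List[List[int]], base_argmax: List[List[int]]) -> List[int]:
    """Return the tokens accepted in one Medusa step under greedy verification.

    candidates[i]:  draft tokens along tree path i
    base_argmax[i]: base model's greedy token at each position of path i; it has
                    len(candidates[i]) + 1 entries, entry j predicting the token
                    that follows candidates[i][:j]
    """
    best: List[int] = []
    for draft, preds in zip(candidates, base_argmax):
        n = 0
        # Walk down the path while the draft agrees with the base model
        while n < len(draft) and draft[n] == preds[n]:
            n += 1
        # Keep the matching prefix plus the base model's own next token
        accepted = draft[:n] + [preds[n]]
        if len(accepted) > len(best):
            best = accepted
    return best


print(accept_greedy([[5, 7, 9], [5, 8]], [[5, 7, 2, 4], [5, 8, 3]]))  # [5, 7, 2]
print(accept_greedy([[1]], [[2, 9]]))                                 # [2]
```

Because only tokens the base model would have chosen anyway are accepted, the output matches plain greedy decoding; the speedup comes from accepting several tokens per step.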

# Medusa decoding using vicuna-7b-v1.3 model with 1 GPU
python ../run.py --engine_dir ./tmp/medusa/7B/trt_engines/fp16/1-gpu/ \
                 --tokenizer_dir ./vicuna-7b-v1.3/ \
                 --max_output_len=100 \
                 --medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
                 --use_py_session \
                 --temperature 1.0 \
                 --input_text "Once upon"

# Medusa decoding using vicuna-13b-v1.3 with 4 GPUs
mpirun -np 4 --allow-run-as-root --oversubscribe \
    python ../run.py --engine_dir ./tmp/medusa/13B/trt_engines/fp16/4-gpu/ \
                     --tokenizer_dir ./vicuna-13b-v1.3/ \
                     --max_output_len=100 \
                     --medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
                     --use_py_session \
                     --temperature 1.0 \
                     --input_text "Once upon"

If the run succeeds, you will see output like this:

......
Input [Text 0]: "<s> Once upon"
Output [Text 0 Beam 0]: "a time, there was a young girl who loved to read. She would spend hours in the library, devouring books of all genres. She had a special love for fairy tales, and would often dream of living in a magical world where she could meet princes and princesses, and have adventures with talking animals.
One day, while she was reading a book, she came across a passage that spoke to her heart. It said, "You are the author of"

Summarization using Medusa decoding

# Medusa decoding using vicuna-7b-v1.3 model with 1 GPU
python ../summarize.py --engine_dir ./tmp/medusa/7B/trt_engines/fp16/1-gpu/ \
                       --hf_model_dir ./vicuna-7b-v1.3/ \
                       --tokenizer_dir ./vicuna-7b-v1.3/ \
                       --test_trt_llm \
                       --data_type fp16 \
                       --medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
                       --use_py_session \
                       --temperature 1.0 \
                       --batch_size 1

# Medusa decoding using vicuna-13b-v1.3 with 4 GPUs
mpirun -np 4 --allow-run-as-root --oversubscribe \
    python ../summarize.py --engine_dir ./tmp/medusa/13B/trt_engines/fp16/4-gpu/ \
                           --hf_model_dir ./vicuna-13b-v1.3/ \
                           --tokenizer_dir ./vicuna-13b-v1.3/ \
                           --test_trt_llm \
                           --data_type fp16 \
                           --medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
                           --use_py_session \
                           --temperature 1.0 \
                           --batch_size 1