Llama accelerate tutorial (NVIDIA#720)
* tutorial and doc fixes

Signed-off-by: Sudhakar Singh <[email protected]>

* remove extra code

Signed-off-by: Sudhakar Singh <[email protected]>

* fix typos

Signed-off-by: Sudhakar Singh <[email protected]>

---------

Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
sudhakarsingh27 authored and ksivaman committed Apr 3, 2024
1 parent fa61eb5 commit a84d021
Showing 3 changed files with 65 additions and 49 deletions.
7 changes: 4 additions & 3 deletions docs/examples/te_llama/te_llama.py
@@ -21,12 +21,12 @@
from transformers.utils.hub import get_checkpoint_shard_files

@contextmanager
def replace_decoder(te_decodder_cls):
def replace_decoder(te_decoder_cls):
"""
Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
"""
original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decodder_cls
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
try:
yield
finally:
@@ -56,6 +56,7 @@ def __init__(self, config, *args, **kwargs):
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads,
)
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
@@ -84,7 +85,7 @@ class is monkey-patched with `TELlamaDecoderLayer` class before
"""

def __new__(cls, config: LlamaConfig):
with replace_decoder(te_decodder_cls=TELlamaDecoderLayer):
with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
llama_for_causal_lm = LlamaForCausalLM(config)
return llama_for_causal_lm

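The `replace_decoder` context manager above is the heart of this file: it temporarily rebinds the `LlamaDecoderLayer` attribute on `transformers.models.llama.modeling_llama`, so any `LlamaForCausalLM(config)` built inside the `with` block assembles its decoder stack from `TELlamaDecoderLayer`, and the original class is restored on exit. A minimal self-contained sketch of the same pattern, using a stand-in namespace instead of `transformers` (the names below are illustrative, not from the commit):

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for transformers.models.llama.modeling_llama (hypothetical, for illustration).
modeling_llama = SimpleNamespace()

class LlamaDecoderLayer:      # the original implementation
    backend = "hf"

class TELlamaDecoderLayer:    # the drop-in replacement
    backend = "te"

modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer

@contextmanager
def replace_decoder(te_decoder_cls):
    """Temporarily monkey-patch the decoder class, restoring it on exit."""
    original_cls = modeling_llama.LlamaDecoderLayer
    modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        modeling_llama.LlamaDecoderLayer = original_cls

with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
    print(modeling_llama.LlamaDecoderLayer.backend)   # -> "te" (patched)
print(modeling_llama.LlamaDecoderLayer.backend)       # -> "hf" (restored)
```

In the commit itself the patched attribute is `transformers.models.llama.modeling_llama.LlamaDecoderLayer`, and the `__new__` hook shown above wraps `LlamaForCausalLM(config)` in this context manager so the swap lasts exactly as long as model construction.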
74 changes: 38 additions & 36 deletions docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "1f37565e",
"id": "2cac9d39",
"metadata": {},
"source": [
"# Accelerating a Hugging Face Llama 2 model with Transformer Engine\n",
@@ -11,14 +11,14 @@
"\n",
"<b>Goal</b>\n",
"\n",
"This tutorial showcases how accelerate finetuning a full Llama 2 model from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-hf) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
"This tutorial showcases how to accelerate finetuning a full Llama 2 model from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-hf) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
"\n",
"</div>\n"
]
},
{
"cell_type": "markdown",
"id": "ab4c0b82",
"id": "401f7fb1",
"metadata": {},
"source": [
"## Dependencies for this tutorial\n",
@@ -35,7 +35,7 @@
},
{
"cell_type": "markdown",
"id": "466ff515",
"id": "33bdb5fe",
"metadata": {},
"source": [
"## Table of contents\n",
@@ -53,7 +53,7 @@
},
{
"cell_type": "markdown",
"id": "8e84bcaa",
"id": "7645f176",
"metadata": {},
"source": [
"## From \"Transformer\" to \"Llama\" \n",
@@ -89,7 +89,7 @@
},
{
"cell_type": "markdown",
"id": "e31303c7",
"id": "d0cfa787",
"metadata": {},
"source": [
"## Hugging Face's `LlamaModel`\n",
@@ -166,7 +166,7 @@
},
{
"cell_type": "markdown",
"id": "686df4ef",
"id": "f4f21369",
"metadata": {},
"source": [
"## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
@@ -190,7 +190,7 @@
},
{
"cell_type": "markdown",
"id": "107a8146",
"id": "24a8d0a5",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
@@ -206,16 +206,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "975f9184",
"execution_count": 1,
"id": "e36ff380",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 289 milliseconds\n"
"Average time taken per step: 315 milliseconds\n"
]
}
],
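The source of this baseline cell is collapsed in the diff; the loop it drives is `finetune_model` in `utils.py` (changed later in this commit), which steps through the dataloader under Hugging Face Accelerate. A self-contained sketch of that per-step pattern, with a tiny linear model and random tensors standing in for the Llama model and the tokenized dataset (assumptions for illustration only):

```python
import torch
from accelerate import Accelerator

# Stand-ins for the tutorial's model, optimizer, scheduler, and dataloader.
accelerator = Accelerator(gradient_accumulation_steps=1)
model = torch.nn.Linear(128, 128)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)
dataset = torch.utils.data.TensorDataset(torch.randn(80, 128), torch.randn(80, 128))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dataloader, lr_scheduler
)

for inputs, targets in dataloader:
    with accelerator.accumulate(model):                # handles gradient accumulation
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)                     # device/precision-aware backward
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```

The real `finetune_model` does the same thing with `outputs = model(**batch)` and `loss = outputs.loss`, plus the timing logic shown in the `utils.py` diff below.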
@@ -247,19 +247,19 @@
},
{
"cell_type": "markdown",
"id": "c2d5b174",
"id": "a64f0f33",
"metadata": {},
"source": [
"Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 289 | 1 |"
"| HF (baseline) | BF16 | 315 | 1 |"
]
},
{
"cell_type": "markdown",
"id": "a7d436bf",
"id": "d9898383",
"metadata": {},
"source": [
"## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
@@ -322,6 +322,7 @@
" normalization=\"RMSNorm\",\n",
" activation=\"swiglu\",\n",
" attn_input_format=\"bshd\",\n",
" num_gqa_groups=config.num_key_value_heads,\n",
" )\n",
" te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n",
" self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n",
@@ -339,10 +340,11 @@
"8. `fuse_qkv_params`: if set to True, the TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument fuse_wgrad_accumulation.\n",
"9. `normalization`: type of normalization applied. Default is `LayerNorm`.\n",
"10. `activation`: type of activation used in the MLP block. Default is `gelu`.\n",
"11. `attn_input_format`: controls whether the dimension ordering of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. \n",
"11. `attn_input_format`: controls whether the dimension ordering of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules.\n",
"12. `num_gqa_groups`: number of GQA groups in the transformer layer. Grouped Query Attention is described in [this paper](https://arxiv.org/pdf/2305.13245.pdf). This only affects the keys and values, not the queries. GQA-1 is equivalent to Multi-Query Attention ([MQA](https://arxiv.org/pdf/1911.02150.pdf)), while GQA-H is equivalent to Multi-Head Attention, i.e. `num_gqa_groups = num_attention_heads`.\n",
"\n",
"\n",
"Further, note that `RotaryPositionEmbedding` is defined as part of the TE's `TransformerLayer` itself since it expects this rope cache if RoPE is used in the model. \n",
"Further, note that `RotaryPositionEmbedding` is defined as part of the `TELlamaDecoderLayer` (a wrapper around TE's `TransformerLayer`) itself, since the underlying `TransformerLayer` expects this RoPE cache whenever RoPE is used in the model. \n",
"\n",
"Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:\n",
"```\n",
@@ -422,12 +424,12 @@
"\n",
"```\n",
"@contextmanager\n",
"def replace_decoder(te_decodder_cls):\n",
"def replace_decoder(te_decoder_cls):\n",
" \"\"\"\n",
" Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.\n",
" \"\"\"\n",
" original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer\n",
" transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decodder_cls\n",
" transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls\n",
" try:\n",
" yield\n",
" finally:\n",
@@ -446,7 +448,7 @@
" \"\"\"\n",
"\n",
" def __new__(cls, config: LlamaConfig):\n",
" with replace_decoder(te_decodder_cls=TELlamaDecoderLayer):\n",
" with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):\n",
" llama_for_causal_lm = LlamaForCausalLM(config)\n",
" return llama_for_causal_lm\n",
".\n",
@@ -530,15 +532,15 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "48dc8935",
"id": "4974b738",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 242 milliseconds\n"
"Average time taken per step: 252 milliseconds\n"
]
}
],
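The constructor options walked through earlier in this section map directly onto TE's `TransformerLayer`. A minimal sketch of building a single Llama-style layer by hand; the sizes are the usual Llama-2-7B values and, like the import path and the forward signature, are assumptions for illustration rather than guaranteed API details (requires Transformer Engine and a CUDA GPU):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding  # assumed import path

# Llama-2-7B-style sizes (assumed for illustration).
hidden_size, ffn_hidden_size, n_heads, n_kv_heads = 4096, 11008, 32, 32

layer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    n_heads,
    bias=False,
    layernorm_epsilon=1e-5,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    fuse_qkv_params=False,
    normalization="RMSNorm",        # Llama uses RMSNorm instead of LayerNorm
    activation="swiglu",            # Llama's gated MLP activation
    attn_input_format="bshd",       # batch-first hidden states
    num_gqa_groups=n_kv_heads,      # number of key/value heads
).cuda()

# Precompute the RoPE cache once, as TELlamaDecoderLayer does in te_llama.py.
rope = RotaryPositionEmbedding(hidden_size // n_heads)
rope_emb = rope(max_seq_len=4096).cuda()

x = torch.randn(2, 256, hidden_size, device="cuda")   # (batch, seq, hidden) for "bshd"
y = layer(x, rotary_pos_emb=rope_emb)
print(y.shape)                                         # same shape as the input
```

The tutorial's `TELlamaDecoderLayer` is essentially this, with the sizes read from the Hugging Face `LlamaConfig` and a `forward` that passes the precomputed RoPE cache through.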
@@ -570,20 +572,20 @@
},
{
"cell_type": "markdown",
"id": "3c3d228a",
"id": "85c78c7f",
"metadata": {},
"source": [
"Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Hugging Face's `LlamaDecoderLayer` gives a speedup of **19%** even when using only BF16 precision!\n",
"Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Hugging Face's `LlamaDecoderLayer` gives a speedup of **25%** even when using only BF16 precision!\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 289 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 242 | 1.19 |"
"| HF (baseline) | BF16 | 315 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |"
]
},
{
"cell_type": "markdown",
"id": "b92d6792",
"id": "e2fb88e9",
"metadata": {},
"source": [
"## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
@@ -608,16 +610,16 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6bba7cc1",
"execution_count": 1,
"id": "8f2b752e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 231 milliseconds\n"
"Average time taken per step: 226 milliseconds\n"
]
}
],
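The FP8 run differs from the BF16 run only in how the forward pass executes: TE modules run under Transformer Engine's FP8 autocast. A minimal sketch of that mechanism on a single TE module; the recipe values are illustrative assumptions, not the tutorial's exact settings, and an FP8-capable GPU (e.g. Hopper) is required:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Illustrative FP8 recipe; the tutorial's exact recipe is not shown in this diff.
fp8_recipe = recipe.DelayedScaling(
    fp8_format=recipe.Format.HYBRID,   # E4M3 for forward tensors, E5M2 for gradients
    amax_history_len=16,
    amax_compute_algo="max",
)

layer = te.Linear(4096, 4096).cuda()            # any TE module or TE-based model works here
inp = torch.randn(16, 4096, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)                            # forward GEMMs run in FP8
out.sum().backward()                            # backward is invoked outside the context
```

How the tutorial itself toggles FP8 is not visible in this diff; the sketch only illustrates the underlying TE mechanism that the `FP8` rows in the table below rely on.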
@@ -649,27 +651,27 @@
},
{
"cell_type": "markdown",
"id": "602239d7",
"id": "67ec126c",
"metadata": {},
"source": [
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 289 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 242 | 1.19 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 231 | 1.25 |\n",
"| HF (baseline) | BF16 | 315 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 226 | 1.39 |\n",
"\n",
"\n",
"After turning on FP8 precision, we get even more speedup of **25%**!"
"After turning on FP8 precision, we get even more speedup of almost **40%**!"
]
},
{
"cell_type": "markdown",
"id": "372867d5",
"id": "41b80b0f",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides speedup over Hugging Face's native Llama 2 implementation. This needs careful initializing of model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!"
"Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 implementation. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!"
]
}
],
33 changes: 23 additions & 10 deletions docs/examples/te_llama/utils.py
@@ -26,7 +26,9 @@ def __init__(self):
self.batch_size = 8
self.max_seq_length = 256
self.gradient_accumulation_steps = 1
self.num_warmup_steps=5
self.num_training_steps=10


hyperparams = HyperParameters()

@@ -132,11 +134,9 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer,
optimizer.zero_grad()
train_dataloader = enumerate(train_dataloader)

time_vals = []

for _ in range(hyperparams.num_training_steps):
# Warmup iters
for _ in range(hyperparams.num_warmup_steps):
step, batch = next(train_dataloader)
start_time = time.time()
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
@@ -146,15 +146,28 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer,
lr_scheduler.step()
optimizer.zero_grad()

end_time = time.time()
total_time = end_time - start_time
time_vals.append(total_time)
# Get the timers ready
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()

start.record()
# Training iters
for _ in range(hyperparams.num_training_steps):
step, batch = next(train_dataloader)
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
total_loss += loss.detach().float()
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
torch.cuda.synchronize()
end.record()
accelerator.end_training()

# ignore the first couple of time vals
time_vals = time_vals[2:]
print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: {(sum(time_vals)/len(time_vals)) * 1000:.0f} milliseconds")
print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds")

def restart_jupyter_notebook():
# Try restarting the Jupyter kernel
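The `finetune_model` change above replaces wall-clock `time.time()` measurements with CUDA events: a few warmup steps run first (`num_warmup_steps`), then the timed steps are bracketed by `torch.cuda.Event` records with explicit synchronization so that queued GPU work is fully counted. The same idea in isolation, as a small self-contained sketch with a stand-in matmul instead of the training step (requires a CUDA GPU; the ordering here — record the end event, then synchronize — is the conventional pattern):

```python
import torch

def time_gpu_workload(step_fn, num_warmup_steps=5, num_training_steps=10):
    """Average milliseconds per step, measured with CUDA events."""
    for _ in range(num_warmup_steps):        # warmup: kernel autotuning, allocator, etc.
        step_fn()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()                 # drain any outstanding GPU work first

    start.record()
    for _ in range(num_training_steps):
        step_fn()
    end.record()
    torch.cuda.synchronize()                 # wait until both events have completed
    return start.elapsed_time(end) / num_training_steps   # elapsed_time is in ms

a = torch.randn(4096, 4096, device="cuda")

def step():
    return a @ a

print(f"Average time taken per step: {time_gpu_workload(step):.0f} milliseconds")
```

Timing this way avoids under-counting asynchronous CUDA kernels that `time.time()` around a Python loop would miss, which is why the reported per-step numbers in the notebook changed along with this commit.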
