torch_compile.py (forked from huggingface/huggingface-llama-recipes)
# This example showcases how to accelerate inference of the Llama 3.1 8B Instruct
# model using torch.compile.
#
# You need a CUDA GPU and torch >= 2.3 to run this example.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling
device = "cuda"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
prompt = "Why are dogs so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
# Specify the max length (including both the prompt and the response).
# When calling `generate` with `cache_implementation="static"` later, this is also used to create a
# `StaticCache` object with sequence length = `max_length`. The longer it is, the more the cache can be re-used.
model.generation_config.max_length = 128
# without `torch.compile`: each call takes ~ 5.0 seconds (on A100 80G + torch 2.3)
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
print(response)
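# --- Illustrative timing sketch (not part of the original recipe) ---
# If you want to reproduce the rough latency numbers quoted in the comments, you can
# wrap `generate` like this. `torch.cuda.synchronize()` is needed because CUDA kernels
# are launched asynchronously, so wall-clock timing without it would be misleading.
# The helper name `timed_generate` is our own, purely for illustration.
import time

def timed_generate(model, inputs):
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(**inputs, do_sample=False)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

# Example usage (uncomment to measure the eager baseline):
# outputs, seconds = timed_generate(model, inputs)
# print(f"eager generate took {seconds:.2f}s")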
# Compiling the whole model with `torch.compile(model, ...)` is not recommended, as that would
# also compile the full `generate` loop and its callbacks. For now, we recommend compiling only
# the forward pass.
# "reduce-overhead" will use cudagraphs.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"
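# (Hedged note, not in the original recipe: recent transformers releases also accept
# `cache_implementation="static"` directly as a keyword argument to `generate(...)`;
# setting it on `generation_config` as above should be equivalent, but check the
# documentation of your installed version.)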
# with `torch.compile` (on A100 80G + torch 2.3)
# 1st call: ~ 90 seconds
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
# 2nd call: ~ 60 seconds
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
# 3rd call: ~ 1.5 seconds
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
print(response)
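# --- Illustrative follow-up sketch (not part of the original recipe) ---
# `batch_decode(outputs)` returns the prompt together with the completion. To print
# only the newly generated text, slice off the prompt tokens before decoding:
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))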