# assisted_decoding.py (forked from huggingface/huggingface-llama-recipes)
# This example showcases assisted generation with two Llama 3.1 checkpoints: a big one (the 405B)
# and a small one (the 8B).
#
# In brief, assisted generation uses a smaller model to draft candidate tokens, which are then
# validated or rejected by the larger model.
#
# We recommend this blogpost to dive into assisted generation: https://huggingface.co/blog/assisted-generation
#
# In order to run this example, you will need enough memory for both models.
#
# The result **should** match the original model's generation.
# CAVEAT 1: sampling ruins this property, even with seeding (because the assistant model consumes the rng state as well).
# CAVEAT 2: due to the nature of fp ops, there are tiny fluctuations in the logits, which may lead to different
# text results. These two properties should hold, nonetheless: a) the quality of the generated text is the same,
# and b) the logits on the first mismatched token are very close to each other.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

WARMUP = 2 # number of non-timed warmup runs
MAX_NEW_TOKENS = 10
DO_SAMPLE = True
ATOL = 1e-6 # ~1e-6 for fp32, up to ~1e-3 for 16 bit vars [these are NORMALIZED logits, post-softmax]; see the caveats above
TORCH_DTYPE = torch.float32
PROMPT = "Alice and Bob "
CHECKPOINT = "meta-llama/Meta-Llama-3.1-405B"  # <--- big llama here
ASSISTED_CHECKPOINT = "meta-llama/Meta-Llama-3.1-8B" # <--- small llama here
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT, device_map="auto", torch_dtype=TORCH_DTYPE)
assistant_model = AutoModelForCausalLM.from_pretrained(ASSISTED_CHECKPOINT, device_map="auto", torch_dtype=TORCH_DTYPE)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
inputs = tokenizer(PROMPT, return_tensors="pt").to(model.device)

# Warmup runs (not timed)
for _ in range(WARMUP):
    model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=MAX_NEW_TOKENS, do_sample=DO_SAMPLE)

start = time.time()
assisted_outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=MAX_NEW_TOKENS, do_sample=DO_SAMPLE)
end = time.time()
assisted_gen_text = tokenizer.batch_decode(assisted_outputs, skip_special_tokens=True)
print(assisted_gen_text)
print(f"\nAssisted time taken: {end - start:.2f}s")