speculative decoding in LLM
From https://towardsdatascience.com/speculative-decoding-for-faster-inference-with-mixtral-8x7b-and-gemma-f5b1487f5714
In speculative decoding for LLMs, there are two models. The "main model" is the large LLM whose output we want, and the "draft model" is a relatively small model.
In the context of speculative decoding, the main model can be seen as supervising or overseeing the draft model's output.
These notes cover speculative decoding and its application in accelerating inference for large language models (LLMs).
Speculative decoding offers a promising approach to balancing output quality and computational efficiency,
especially for scenarios where larger LLMs deliver superior results but at the cost of reduced inference speed.
The essence of speculative decoding lies in leveraging a smaller LLM to generate token suggestions during inference,
which are then validated or corrected by a larger, more accurate LLM.
By doing so, speculative decoding aims to exploit the computational efficiency of the smaller model while benefiting from the superior capabilities of the larger model.
# Key factors influencing the effectiveness of speculative decoding include:
Difference in model sizes: The draft model (smaller LLM) should be significantly smaller in size compared to the main model (larger LLM).
Ideally, the main model should be much larger than the draft model, often by an order of magnitude or more, so that drafting costs little compared with a main-model forward pass.
Similar architectures and training data: The main and draft models should have similar architectures and be trained on comparable data tokenized with the same tokenizer.
This ensures compatibility and reduces the need for extensive corrections during inference.
Validation of token suggestions: The draft model's token suggestions should be sufficiently accurate to minimize the corrective workload for the main model.
If a large portion of the draft model's suggestions requires extensive correction by the main model,
the efficiency gains of speculative decoding may diminish.
Memory consumption: Since speculative decoding involves running two models simultaneously during inference, it incurs additional memory overhead.
Optimizing memory consumption, such as through model quantization or careful selection of model pairs, is crucial for practical deployment,
particularly on resource-constrained hardware.
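As a rough illustration of the memory overhead, the sketch below estimates the memory taken by the weights of both models. It is a back-of-the-envelope calculation under loud assumptions: the helper name `weight_memory_gb` is hypothetical, the parameter counts (13B and 1.1B) are nominal sizes matching the pair used in the example code at the end of these notes, and the estimate ignores the KV cache, activations, and any quantization metadata.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights only, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumed sizes: a 13B-parameter main model quantized to 4 bits per weight
# and a 1.1B-parameter draft model kept at 16 bits per weight.
main_gb = weight_memory_gb(13e9, 4)
draft_gb = weight_memory_gb(1.1e9, 16)
total_gb = main_gb + draft_gb

print(f"main: {main_gb:.1f} GB, draft: {draft_gb:.1f} GB, total: {total_gb:.1f} GB")
```

Under these assumptions the draft model adds roughly a third on top of the quantized main model's weights, which is why quantizing the main model or choosing a smaller draft model matters on a single GPU.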
The experiments and analysis in the source article show the nuanced performance trade-offs associated with speculative decoding across different pairs of LLMs.
While speculative decoding proves effective in certain configurations, its benefits may vary depending on factors such as model size, architecture, and tokenization.
Moreover, the practical feasibility of speculative decoding depends on the availability of appropriately sized draft models
and the computational resources required for simultaneous inference.
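Before running experiments, a simple cost model can give a feel for when a model pair is worth trying. The sketch below is a simplification, not a measurement: it assumes each draft token is accepted independently with probability `alpha`, that verifying all drafted tokens costs one main-model forward pass, and that one draft forward pass costs `cost_ratio` times a main-model pass. All three parameter names and both function names are hypothetical.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens produced per draft-and-verify step.

    alpha: assumed probability that each draft token is accepted (independent).
    gamma: number of tokens the draft model proposes per step.
    """
    # 1 + alpha + alpha^2 + ... + alpha^gamma: the accepted prefix plus the
    # one token the main model always contributes (a correction or a bonus).
    return sum(alpha ** k for k in range(gamma + 1))


def estimated_speedup(alpha: float, gamma: int, cost_ratio: float) -> float:
    """Estimated speedup over decoding with the main model alone.

    cost_ratio: assumed cost of one draft forward pass relative to one
    main-model forward pass (0.1 means ten times cheaper).
    """
    cost_per_step = gamma * cost_ratio + 1.0  # gamma draft passes + 1 main pass
    return expected_tokens_per_step(alpha, gamma) / cost_per_step


for alpha in (0.0, 0.5, 0.8, 0.9):
    print(alpha, round(estimated_speedup(alpha, gamma=5, cost_ratio=0.1), 2))
```

When `alpha` is low the estimate drops below 1, meaning a slowdown: the draft passes are paid for but almost nothing is accepted. This is consistent with the point above that the benefit depends heavily on how well the draft model matches the main model.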
Overall, speculative decoding represents a valuable strategy for accelerating inference with large language models,
offering a balance between performance and computational efficiency.
However, careful consideration of model characteristics and thorough experimentation are necessary to realize its full potential in real-world applications.
#####
Speculative decoding is a technique used to accelerate inference for large language models (LLMs) by employing two models simultaneously during the decoding process.
In speculative decoding, a smaller LLM, referred to as the "draft model," generates token suggestions,
which are then validated or corrected by a larger, more accurate LLM, known as the "main model."
## How speculative decoding works:
Draft Model Generation: The draft model, which is smaller and faster, suggests tokens during the decoding process.
It proposes a short run of candidate next tokens that continue the input sequence.
Main Model Validation: The token suggestions provided by the draft model are then passed to the main model for validation.
The main model checks the suggested tokens, typically all in one forward pass, keeps those that agree with its own predictions, and replaces the first one that does not.
Output Refinement: After the main model validates the token suggestions, the final output is determined.
If the draft model's suggestions are mostly accurate, the corrections required by the main model are minimal,
resulting in a faster overall inference process.
Efficiency Considerations: Speculative decoding can significantly accelerate inference if the draft model is substantially smaller and
if its token suggestions are mostly correct, minimizing the corrective workload for the main model.
However, it may lead to increased memory consumption due to running two models simultaneously during inference.
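The steps above can be sketched with toy stand-in models. This is a minimal illustration of the greedy variant only, not the implementation used by any library: `main_next` and `draft_next` are hypothetical functions that map a token sequence to the next token, and the verification loop calls the main model once per position for clarity, whereas a real implementation scores all drafted positions in a single forward pass. Sampling-based variants use a probabilistic accept/reject rule that is not shown here.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]  # maps a token sequence to the next token


def speculative_greedy_decode(
    main_next: NextToken,
    draft_next: NextToken,
    prompt: List[int],
    max_new_tokens: int,
    num_draft_tokens: int = 4,
) -> List[int]:
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # 1) Draft: the small model proposes a short run of tokens.
        draft: List[int] = []
        for _ in range(num_draft_tokens):
            draft.append(draft_next(tokens + draft))

        # 2) Validate: compare each proposal with the main model's own choice.
        accepted: List[int] = []
        for proposed in draft:
            expected = main_next(tokens + accepted)
            if proposed == expected:
                accepted.append(proposed)
            else:
                # 3) Refine: replace the first mismatch and discard the rest.
                accepted.append(expected)
                break
        else:
            # Every draft token was accepted: the main model adds one more.
            accepted.append(main_next(tokens + accepted))

        tokens.extend(accepted)
    return tokens[:target_len]


def plain_greedy_decode(main_next: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(main_next(tokens))
    return tokens


# Toy models (assumptions for illustration): the main model counts upward
# modulo 10; the draft model agrees except that it is wrong after token 2.
def toy_main(seq: List[int]) -> int:
    return (seq[-1] + 1) % 10


def toy_draft(seq: List[int]) -> int:
    return 0 if seq[-1] == 2 else (seq[-1] + 1) % 10


spec_out = speculative_greedy_decode(toy_main, toy_draft, prompt=[0], max_new_tokens=12)
plain_out = plain_greedy_decode(toy_main, prompt=[0], max_new_tokens=12)
print(spec_out == plain_out)
```

In this greedy sketch the output matches what the main model would produce on its own; the draft model only changes how many main-model steps are needed, not the result.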
In summary, speculative decoding leverages a smaller, faster LLM to generate initial token suggestions, which are then refined by a larger,
more accurate LLM. By combining the efficiency of the draft model with the accuracy of the main model,
speculative decoding aims to speed up inference without sacrificing the quality of the outputs.
###################################### Example code ##############################################
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed

set_seed(42)  # For reproducibility

# Main (large) model and draft (small) model; both use the same Llama tokenizer.
checkpoint = "meta-llama/Llama-2-13b-hf"
assistant_checkpoint = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Quantize the main model to 4-bit NF4 to reduce memory; compute in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", quantization_config=bnb_config)
# The draft model is small enough to load unquantized.
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, device_map="auto")

prompt = "Tell me about gravity."
# Put the inputs on the main model's device instead of hard-coding "cuda:0".
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Passing assistant_model enables assisted (speculative) generation.
output = model.generate(**model_inputs, assistant_model=assistant_model, max_length=500)[0]
output_decoded = tokenizer.decode(output, skip_special_tokens=True)
print(output_decoded)
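
Since the main and draft models are expected to share a tokenizer, it can be worth checking a candidate pair before loading the full models. The helper below is a hypothetical sketch: `vocab_mismatches` is not part of any library, and it relies only on the `get_vocab()` method that Hugging Face tokenizers provide. Matching vocabularies are a useful sanity check, not a guarantee that two tokenizers behave identically. The `FakeTokenizer` class is a stand-in so the snippet runs without downloading anything.

```python
def vocab_mismatches(main_tokenizer, draft_tokenizer, limit: int = 5):
    """Return up to `limit` (token, main_id, draft_id) differences between two vocabularies."""
    main_vocab = main_tokenizer.get_vocab()
    draft_vocab = draft_tokenizer.get_vocab()
    diffs = []
    for token in sorted(set(main_vocab) | set(draft_vocab)):
        main_id = main_vocab.get(token)    # None if the token is missing
        draft_id = draft_vocab.get(token)
        if main_id != draft_id:
            diffs.append((token, main_id, draft_id))
            if len(diffs) >= limit:
                break
    return diffs


class FakeTokenizer:
    """Stand-in exposing only get_vocab(), for demonstration purposes."""

    def __init__(self, vocab):
        self._vocab = dict(vocab)

    def get_vocab(self):
        return dict(self._vocab)


same_a = FakeTokenizer({"<s>": 0, "hello": 1, "world": 2})
same_b = FakeTokenizer({"<s>": 0, "hello": 1, "world": 2})
other = FakeTokenizer({"<s>": 0, "hello": 2, "there": 1})

print(vocab_mismatches(same_a, same_b))  # empty list: vocabularies agree
print(vocab_mismatches(same_a, other))   # lists the tokens whose ids differ
```

With real checkpoints, the two arguments would be `AutoTokenizer.from_pretrained(checkpoint)` and `AutoTokenizer.from_pretrained(assistant_checkpoint)`; an empty result suggests the pair is a reasonable candidate.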