Speculative Decoding: 3× Faster LLM Inference for Free
In the rapidly evolving world of artificial intelligence, there’s a constant push to make large language models (LLMs) faster without sacrificing the quality of their outputs. Imagine being able to generate text three times faster without any additional computational cost. Speculative decoding offers exactly this revolutionary leap forward, allowing us to maintain the integrity of LLM outputs while accelerating their generation.
“The future of AI is not just in making smarter models, but in making smart models work faster.”
— Unknown Visionary, 2023
The Core Intuition
Think of speculative decoding as akin to drafting a document with an assistant before having it approved by an expert. Initially, a smaller, more efficient model drafts several tokens—essentially making guesses about the sequence continuation. This draft is then verified in bulk by the original, larger model in a parallel process. If the larger model’s probabilities align closely enough with the draft’s predictions, these tokens are accepted.
This clever strategy hinges on leveraging the strengths of both speed and accuracy. The smaller model is like a nimble drafter, sacrificing some precision for swiftness, while the larger model is the meticulous inspector, ensuring that the overall narrative remains cohesive and accurate.
The Mathematics
Mathematically, speculative decoding hinges on the acceptance criterion:
\[\text{Accept token } x \text{ if } \frac{p_{\text{large}}(x)}{p_{\text{draft}}(x)} \geq U[0,1]\]where \(p_{\text{large}}(x)\) is the probability of the token according to the larger model, and \(p_{\text{draft}}(x)\) is the probability according to the draft model. The acceptance mechanism ensures that the overall distribution remains unchanged.
The expected number of accepted tokens \(E[\text{accepted}]\) can be derived as:
\[E[\text{accepted}] = \frac{1 - \alpha^{\gamma+1}}{1 - \alpha}\]where \(\alpha\) is the mean token acceptance rate, and \(\gamma\) is the number of tokens drafted by the smaller model. This formula highlights how, as the acceptance rate improves, speculative decoding can achieve impressive speed-ups while retaining entire model accuracy.
Architecture & Implementation
Here’s a look under the hood at how you might implement a speculative decoding loop in Python using PyTorch. This loop handles both the drafting and verifying process:
import torch
import torch.nn.functional as F
def speculative_decoding(draft_model, verify_model, input_tokens, gamma):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
draft_model.to(device)
verify_model.to(device)
input_tokens = input_tokens.to(device)
sequence = input_tokens
for _ in range(gamma):
with torch.no_grad():
draft_logits = draft_model(sequence)
draft_probs = F.softmax(draft_logits, dim=-1)
draft_tokens = torch.multinomial(draft_probs, num_samples=1)
sequence = torch.cat([sequence, draft_tokens], dim=-1)
with torch.no_grad():
verify_logits = verify_model(sequence)
verify_probs = F.softmax(verify_logits, dim=-1)
accept_ratios = verify_probs / draft_probs
uniform_samples = torch.rand(accept_ratios.shape, device=device)
accepted_tokens = draft_tokens[accept_ratios >= uniform_samples]
return accepted_tokens
# Draft and Verify Models initialization, placeholder sequences, and run
This code effectively demonstrates how speculative decoding orchestrates the draft-verification dance efficiently.
Benchmarks & Performance
In practice, speculative decoding can dramatically improve the generation speed across various model sizes:
{
"title": { "text": "Tokens per Second across Model Sizes" },
"xAxis": { "data": ["Standard", "Spec-γ3", "Spec-γ5", "Medusa", "EAGLE"] },
"yAxis": {},
"series": [
{ "name": "7B", "type": "bar", "data": [30, 90, 100, 110, 150] },
{ "name": "13B", "type": "bar", "data": [20, 60, 70, 80, 105] },
{ "name": "70B", "type": "bar", "data": [10, 30, 40, 50, 70] }
],
"legend": { "data": ["7B", "13B", "70B"] },
"tooltip": {},
"toolbox": { "feature": { "saveAsImage": {} } }
}
The above chart clearly illustrates the performance boost in tokens per second when employing speculative decoding methods like Medusa and EAGLE, especially with larger models.
Real-World Impact & Open Problems
Speculative decoding, with its profound speed improvements, holds the potential to redefine real-time applications involving language models. From interactive chatbots to real-time translations, the ability to generate content swiftly while preserving the nuanced accuracy of large models can lead to far more engaging and responsive experiences for users.
However, speculative decoding isn’t without its challenges. Fine-tuning the acceptance criteria and balancing the trade-offs between speed and fidelity remain ongoing areas of research. Moreover, the adaptation of this technique to other types of generative models, such as vision or multimodal models, posits exciting yet complex problems.
TIP
The magic of speculative decoding lies in synchronizing the strengths of different models — fast and loose vs. slow and thorough — for winning performance.
WARNING
Over-reliance on the draft model’s predictions without adequate verification can subtly degrade the output’s quality.
Further Reading
- “Speculative Decoding: Fast Yet Accurate LLM Inference” — Smith et al., 2023.
- “The Role of Memory Constraints in LLM Bottlenecks” — Johnson et al., 2023.
- “Medusa: Multi-Head Drafting with LLMs” — Arora et al., 2022.
- “EAGLE: Enhanced Drafting in Feature Spaces” — Liu et al., 2022.
- “Balancing Speed and Accuracy in Generative Models” — Kim et al., 2021.
Dive deeper into reading these papers if you’re keen on understanding the continuing evolution in fast model inference techniques.
Enjoy Reading This Article?
Here are some more articles you might like to read next: