In the previous post, we saw that our baseline inference engine is compute-bound. This means that we are bottlenecked by the compute capability of the GPU. The main reason for this is that we recalculate all of our attention values for every single generation step.
Time for the attention equation:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
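In PyTorch this is just two matrix multiplies and a softmax. A minimal single-head sketch (no masking, batch dimensions left implicit) might look like:

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    # Q, K, V: (..., seq_len, d_k); a decoder would additionally apply a causal mask to `scores`
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., seq_len, seq_len)
    return F.softmax(scores, dim=-1) @ V               # (..., seq_len, d_k)
```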
So for every step, we have the following:
- Input token at position $i$: $x_i$ (the embedding)
- Projections:
  - $Q_i = x_i W_Q$ (query projection)
  - $K_i = x_i W_K$ (key projection)
  - $V_i = x_i W_V$ (value projection)
- Attention computation: $Q_i$ attends to all projected keys $[K_1, K_2, ..., K_{i-1}, K_i]$ and values $[V_1, V_2, ..., V_{i-1}, V_i]$
One key insight is that we only need to compute the new $Q_i$ vector for the current position, since all we want is the next token. $Q_i$ still attends to all keys $[K_1, K_2, ..., K_i]$ and values $[V_1, V_2, ..., V_i]$, but $K_1, ..., K_{i-1}$ and $V_1, ..., V_{i-1}$ were already computed in previous steps, so we can cache them and re-use them; only $K_i$ and $V_i$ are new.
However, our naive implementation recalculates this entire attention matrix for every step. Since an LLM is usually used to generate many tokens, this is a lot of wasted computation. We can do better!
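Concretely, a naive decode loop feeds the ever-growing sequence through the model at every step, so step $i$ re-projects all $i$ tokens and rebuilds the full $i \times i$ attention matrix just to read off the last row. A rough sketch (greedy sampling, with a hypothetical `model` that maps token ids to logits):

```python
import torch

@torch.no_grad()
def generate_naive(model, prompt_ids, max_new_tokens):
    ids = prompt_ids                                     # (1, T)
    for _ in range(max_new_tokens):
        logits = model(ids)                              # re-runs attention over ALL tokens, every step
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)           # the sequence keeps growing
    return ids
```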
The solution is pretty simple: we just need to cache the previous keys and values. Since our rudimentary inference engine is only used once, we can easily keep track of our KV cache by modifying our attention class:
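Something along these lines; this is a sketch rather than the exact class from the engine, so the names (`CachedSelfAttention`, `k_cache`, `v_cache`) are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedSelfAttention(nn.Module):
    """Multi-head causal self-attention with a simple KV cache (illustrative sketch)."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        # caches grow along the sequence dimension as we generate
        self.k_cache = None  # (B, n_heads, T_so_far, d_head)
        self.v_cache = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape  # prefill: T = prompt length, decode: T = 1
        q, k, v = self.qkv(x).split(C, dim=2)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        if self.k_cache is not None:
            # re-use the keys/values projected in earlier steps
            k = torch.cat([self.k_cache, k], dim=2)
            v = torch.cat([self.v_cache, v], dim=2)
        self.k_cache, self.v_cache = k, v  # single request, so we never need to reset

        # prefill (T > 1) needs the causal mask; a single decode token (T == 1)
        # simply attends to everything accumulated in the cache
        y = F.scaled_dot_product_attention(q, k, v, is_causal=T > 1)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)
```

We'll come back to that `is_causal = T > 1` split further down, where the two phases of inference are discussed.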
Note that this happens on every layer, so for our gpt2-xl model we have 48 layers with 25 attention heads each and a hidden dimension of 1600, i.e. $1600/25 = 64$ values per head per layer. Across all layers we need to cache $2 \cdot 48 \cdot 25 \cdot 64 \approx 153\text{k}$ values per position (the factor of 2 covers both keys and values). Assuming half precision (2 bytes per value), this is $153\text{k} \cdot 2 \approx 306\text{kB}$ per position. This quickly adds up and makes long sequences (or large batches) challenging, but at least we're not recalculating everything from scratch every step. [1]
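Written out as code, the same back-of-the-envelope calculation looks like this (the 1024-token context and batch size 64 at the end are just illustrative values):

```python
n_layers, n_heads, d_model = 48, 25, 1600    # gpt2-xl
d_head = d_model // n_heads                  # 1600 / 25 = 64

values_per_position = 2 * n_layers * n_heads * d_head  # keys + values: 153,600 values
bytes_per_position = 2 * values_per_position           # half precision: 307,200 B (~0.3 MB)

# for example, a full 1024-token gpt2 context at batch size 64:
print(bytes_per_position * 1024 * 64 / 1e9)            # ~20.1 GB just for the KV cache
```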
As we can see, we're now able to process more requests in parallel, and the throughput per request stays relatively stable until batch size 64. We're no longer bottlenecked by the GPU's compute capability, but by its memory bandwidth.
Prefill and decode
One interesting aspect of caching our KV's is that we suddenly get two phases of inference: the prefill and the decode. When we receive a request, we first need to calculate (prefill) the KV's for the first $n$ tokens of the request (the user's prompt), and then we can decode the rest of the tokens. The prefill phase is what we had before and is still compute-bound, but the decode phase is now memory-bound.
During the prefill phase, we have our regular attention computation with a causal mask (i.e. we only attend to the previous tokens). During the decode phase, we only care about the new $Q_i$ vector, so we can calculate attention between that one vector and the cached $K$ and $V$ vectors. This is why the code splits up the phases with `is_causal = T > 1`, where `T` is the current input length. During decode we only pass the previously generated token as input.
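Put together, the naive generation loop from earlier turns into two distinct phases, roughly like this (again with a hypothetical `model` whose attention layers maintain their own KV cache, as in the class sketch above):

```python
import torch

@torch.no_grad()
def generate_cached(model, prompt_ids, max_new_tokens):
    # prefill: run the whole prompt at once (T > 1, causal attention, compute-bound)
    logits = model(prompt_ids)                           # also fills the KV caches
    next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
    out = [prompt_ids, next_id]

    # decode: only the newest token goes in; earlier positions come from the cache
    for _ in range(max_new_tokens - 1):
        logits = model(next_id)                          # T == 1, memory-bound step
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        out.append(next_id)

    return torch.cat(out, dim=1)
```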
Notice how both phases nicely align with the metrics we care about: the prefill phase is everything that has to happen before the first token comes out, so it determines time to first token, while the decode phase determines inter-token latency. If the user submits a huge prompt, the compute-bound prefill phase will take longer and push out time to first token, but since decode is mostly memory-bound anyway, inter-token latency isn't affected much. [2]