Warning (AI Transparency)
All the ideas, notes, and technical content are mine. I use AI (GPT-5.5 from
Codex) as a writing assistant and to help generate Manim videos from my notes.
The Manim animation source code is in vietfood/manim-lib.
Why do we need KV caching? If you already know how Transformer inference works, you probably know the short answer. (If you read Vietnamese, you may want to read my previous Transformer posts first: Một chút về Transformers (phần 1) and Một chút về Transformers (phần 2).) But just in case you don’t, let’s start with a small example:
- Suppose the model reads the prompt:

    A B C D

- It will predict the next token. Suppose that token is:

    E

- Now the model needs to continue from:

    A B C D E

A naive implementation would process the whole sequence again. But A, B, C, and D have not changed. In a causal Transformer, future tokens are allowed to look backward, but past tokens are not allowed to look forward. So the key and value tensors for A, B, C, and D can be reused.
Those reusable tensors are the KV cache: during decoding, compute the key and value vectors for each token once, store them, and reuse them when later tokens attend to the past.
1. The part of attention that matters
For this post, we only need one fact about attention: a token reads previous tokens by comparing a query against their keys, then mixing their values.
Let token $i$ have embedding $x_i$. More precisely, if we are not in the first layer, $x_i$ is the token’s hidden state at that layer. An attention layer projects it into three vectors:

$$q_i = x_i W_Q, \qquad k_i = x_i W_K, \qquad v_i = x_i W_V$$
These names are useful if we read them literally:
- Query: what the current token is looking for.
- Key: what a token offers to be matched against.
- Value: the information that gets copied into the output after the match.
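For token $i$, attention scores a previous token $j$ with a scaled dot product, where $d_k$ is the key dimension:

$$s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}}$$

Because this is a causal language model, token $i$ can only attend to tokens at positions $j \le i$. After softmax, the attention output is:

$$o_i = \sum_{j \le i} \alpha_{ij}\, v_j, \qquad \alpha_{ij} = \frac{\exp(s_{ij})}{\sum_{m \le i} \exp(s_{im})}$$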
So the execution order is:
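    current query -> compare with previous keys -> weight previous values -> output

In matrix form, for a sequence $X$:

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V$$

and

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\operatorname{mask}\left(\frac{Q K^\top}{\sqrt{d_k}}\right)\right) V$$

Causal masking (the $\operatorname{mask}$ function) tells the softmax to ignore future tokens. In practice, we add a very negative value such as $-\infty$ to masked positions before softmax, so their probability becomes 0.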
This full-sequence computation is reasonable when the model first reads the prompt. The problem starts after the prompt, when generation becomes token-by-token.
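As a minimal sketch of the matrix form above, here is single-head causal attention in NumPy. The function name `causal_attention`, the toy sizes, and the random weights are assumptions for illustration only; a real model adds multiple heads, an output projection, and positional information.

```python
import numpy as np

def causal_attention(X, W_Q, W_K, W_V):
    """Single-head causal self-attention over a full sequence X of shape (T, d_model)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # (T, T): row i scores every position j
    future = np.triu(np.ones_like(scores, dtype=bool), k=1)  # True where j > i
    scores = np.where(future, -np.inf, scores)      # masked positions get -inf
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

# Toy setup (assumed sizes): 5 tokens standing in for A B C D E.
rng = np.random.default_rng(0)
T, d_model, d_k = 5, 8, 4
X = rng.normal(size=(T, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = causal_attention(X, W_Q, W_K, W_V)
print(np.round(weights, 2))  # lower-triangular: each row only weights past positions
```

Row `i` of `weights` is zero for every column `j > i`, which is the property the rest of the post relies on.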
2. The waste during decoding
LLMs usually have two inference phases:
| Phase | Input shape | What the model does |
|---|---|---|
| Prefill | Many prompt tokens at once | Process the prompt and prepare internal state. |
| Decode | One new token per sequence at a time | Extend the sequence by one token, then repeat. |
During prefill, the model can process the prompt in parallel. If the prompt has 1,000 tokens, a large matrix operation can build attention states for those 1,000 tokens.
During decode, only the newest token is actually new. If the current sequence is A B C D and the next token is E, the model does not need new outputs for A, B, C, and D. It only needs the output for E.
Without caching, the model still rebuilds old projections:
    A B C D -> build K/V for A, B, C, D
    A B C D E -> build K/V for A, B, C, D, E again
    A B C D E F -> build K/V for A, B, C, D, E, F again

The repeated work gets worse as the generated sequence grows.
The key observation is not merely that recomputation is expensive. The key observation is that the recomputation is unnecessary. Because of causal masking, old tokens cannot incorporate information from new tokens. Their keys and values at each layer are already fixed.
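To get a feel for how quickly the repeated work grows, here is a small counting sketch. It assumes a simplified cost model that only counts how many token positions have their K/V projected in one layer; the function name `kv_token_projections` is hypothetical.

```python
def kv_token_projections(prompt_len: int, num_generated: int, use_cache: bool) -> int:
    """Count token positions whose K/V get projected (per layer) while generating."""
    total = 0
    for step in range(num_generated):
        seq_len = prompt_len + step      # tokens in the sequence at this forward pass
        if use_cache and step > 0:
            total += 1                   # only the newest token is projected
        else:
            total += seq_len             # prefill, or a full recompute without a cache
    return total

# Prompt "A B C D", then three forward passes, as in the listing above.
print(kv_token_projections(4, 3, use_cache=False))   # 15 = 4 + 5 + 6
print(kv_token_projections(4, 3, use_cache=True))    # 6  = 4 + 1 + 1

# A longer run under the same cost model.
print(kv_token_projections(1000, 1000, use_cache=False))  # 1499500
print(kv_token_projections(1000, 1000, use_cache=True))   # 1999
```

Without a cache the count grows roughly quadratically with the number of generated tokens; with a cache it grows linearly.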
3. Why the cache stores K and V
Now look at the exact work needed for the new token E.
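E needs its own query:

$$q_E = x_E W_Q$$

Then E compares that query with all available keys:

$$s_{E,j} = \frac{q_E \cdot k_j}{\sqrt{d_k}}, \qquad j \in \{A, B, C, D, E\}$$

Finally, E mixes the values:

$$o_E = \sum_{j \in \{A, B, C, D, E\}} \alpha_{E,j}\, v_j, \qquad \alpha_{E,j} = \operatorname{softmax}_j(s_{E,j})$$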
So the new token needs old keys and old values. It does not need old queries, because we are not recomputing the attention outputs for A, B, C, and D.
That is the reason for the name KV cache (not QK cache or QKV cache). A KV cache stores the key and value tensors produced for previous tokens at each Transformer layer, so future tokens can attend to them without recomputing them.
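A quick way to convince yourself that old queries are not needed is to compare the two paths numerically. The sketch below assumes a single layer with a single head, random toy weights, and made-up hidden states for A to E. It computes E’s output once from cached K/V and once from a full recomputation.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d_model, d_k = 8, 4                       # assumed toy sizes
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))
X = rng.normal(size=(5, d_model))         # stand-in hidden states for A, B, C, D, E

# Prefill on A B C D: keep only K and V, discard the queries.
k_cache = X[:4] @ W_K
v_cache = X[:4] @ W_V

# Decode step for E: project one token, append its key/value, attend.
x_E = X[4:5]
q_E = x_E @ W_Q
k_all = np.concatenate([k_cache, x_E @ W_K], axis=0)
v_all = np.concatenate([v_cache, x_E @ W_V], axis=0)
out_cached = softmax(q_E @ k_all.T / np.sqrt(d_k)) @ v_all

# Reference: recompute everything for A B C D E with a causal mask.
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
scores = Q @ K.T / np.sqrt(d_k)
scores = np.where(np.triu(np.ones((5, 5), dtype=bool), k=1), -np.inf, scores)
out_full = softmax(scores) @ V

# E's row of the full computation matches the cached path.
print(np.allclose(out_cached[0], out_full[-1]))
```

The queries for A, B, C, and D appear only in the reference path, and only in rows of `out_full` that decoding never reads.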
When token E arrives, the model only computes:
    q_E, k_E, v_E

Then it attends with:

    q_E against [cached k_A, cached k_B, cached k_C, cached k_D, k_E]

and mixes:

    [cached v_A, cached v_B, cached v_C, cached v_D, v_E]

4. The decode loop with a KV cache
With KV caching, prefill and decode have different jobs.
During prefill, the model reads the prompt and fills the cache:
    prompt tokens -> create and fill K/V tensors for every layer -> cache them

During decode, each step does a small update:

    new token -> new K/V tensors -> append to cache
    new query -> attend over cached K/V -> compute next-token logits

Written as simplified PyTorch-like code:
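    import math
    import torch

    past_k, past_v = kv_cache  # each (t, d_k): keys/values of the t previous tokens

    # Project only the newest token; x_new has shape (1, d_model).
    q_new = x_new @ W_Q
    k_new = x_new @ W_K
    v_new = x_new @ W_V

    # Append the new key/value rows to the cached ones.
    k_all = torch.cat([past_k, k_new], dim=0)
    v_all = torch.cat([past_v, v_new], dim=0)

    # No causal mask needed: every cached position is already in the past.
    d_k = k_all.shape[-1]
    output_new = torch.softmax(q_new @ k_all.T / math.sqrt(d_k), dim=-1) @ v_all

    kv_cache = (k_all, v_all)

5. What KV caching is good for, and what it is not good for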
KV caching saves the work of rebuilding old key and value tensors. Each token’s K/V tensors are computed once per layer and then reused for later tokens.
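The "once per layer" part is easier to see with more than one layer. The sketch below is a toy under stated assumptions: `ToyLayer`, `make_layers`, and `run` are hypothetical helpers, each layer is a single attention head with a residual connection, and there are no MLPs, norms, or positional encodings. It prefills on four tokens, decodes two more one at a time, and compares against processing all six tokens in one pass.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

class ToyLayer:
    """One single-head attention layer that owns its K/V cache (hypothetical helper)."""

    def __init__(self, d_model, rng):
        self.W_Q, self.W_K, self.W_V = (
            rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(3)
        )
        self.k_cache = np.zeros((0, d_model))
        self.v_cache = np.zeros((0, d_model))

    def forward(self, x):
        """x: (n_new, d_model) hidden states of tokens not yet in this layer's cache."""
        n_past = self.k_cache.shape[0]
        self.k_cache = np.concatenate([self.k_cache, x @ self.W_K], axis=0)
        self.v_cache = np.concatenate([self.v_cache, x @ self.W_V], axis=0)
        scores = (x @ self.W_Q) @ self.k_cache.T / np.sqrt(x.shape[-1])
        # A new token at absolute position n_past + i must not see later positions.
        pos_q = n_past + np.arange(scores.shape[0])[:, None]
        pos_k = np.arange(scores.shape[1])[None, :]
        scores = np.where(pos_k > pos_q, -np.inf, scores)
        return x + softmax(scores) @ self.v_cache   # residual connection

def make_layers(seed, d_model, n_layers):
    rng = np.random.default_rng(seed)   # same seed -> same weights, fresh caches
    return [ToyLayer(d_model, rng) for _ in range(n_layers)]

def run(layers, x):
    for layer in layers:
        x = layer.forward(x)
    return x

d_model, n_layers = 8, 3
tokens = np.random.default_rng(0).normal(size=(6, d_model))  # stand-ins for A..F

# Cached path: prefill on A B C D, then decode E and F one token at a time.
cached_layers = make_layers(1, d_model, n_layers)
run(cached_layers, tokens[:4])
out_E = run(cached_layers, tokens[4:5])
out_F = run(cached_layers, tokens[5:6])

# Reference path: all six tokens in one causal pass.
out_all = run(make_layers(1, d_model, n_layers), tokens)

print(np.allclose(out_E[0], out_all[4]), np.allclose(out_F[0], out_all[5]))
print([layer.k_cache.shape for layer in cached_layers])
```

Each layer ends up with six cached rows, and every row was projected exactly once, either during prefill or during the decode step that introduced it.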
The clean way to state the advantage is that KV caching reduces repeated computation during decoding. But it increases memory usage because the model stores K/V tensors for previous tokens. It does not make long-context attention free. The newest token still needs to compare its query with previous keys, and it still needs to read previous values. If the context is long, the cache can be large and expensive to read. This is why serving systems use techniques such as PagedAttention to manage KV cache memory more efficiently. (If you know about LLM inference, you may know about the vLLM framework. One of vLLM’s core ideas is PagedAttention, which manages KV cache memory in fixed-size blocks. You can read more in Efficient Memory Management for Large Language Model Serving with PagedAttention and vLLM’s documentation.)
Every layer stores key and value tensors for every token in the sequence, prompt and generated alike. The cache grows with:
- Batch size: Each sequence in the batch keeps its own K and V tensors, so memory usage is linear in the batch size.
- Sequence length: Each token position adds one key and one value vector per layer and KV head, so memory usage is also linear in the sequence length.
- Number of layers: Each Transformer layer has its own attention module, so each layer has its own K and V tensors.
- Number of KV heads and head dimension: Each KV head stores its own K and V tensors. In grouped-query attention, the number of KV heads can be smaller than the number of query heads.
- Numeric precision: Each element in the K and V tensors takes a fixed number of bytes. This is why KV cache quantization can reduce memory usage.
A rough shape for the cache is:

$$[\,2,\ B,\ L,\ H_{kv},\ T,\ d_{head}\,]$$

where:
- $2$ is for K and V
- $B$ is batch size
- $L$ is the number of layers
- $H_{kv}$ is the number of key/value heads
- $T$ is sequence length
- $d_{head}$ is the head dimension
This is why serving long-context models is often memory-limited. The model weights may fit on the GPU, but the cache can still become the bottleneck as batch size and context length grow. Many inference optimizations are different ways of managing this cost: paged KV cache, grouped-query attention, multi-query attention, quantized KV cache, sliding-window attention, and cache eviction. They differ in implementation, but they all start from the same fact: past keys and values are useful enough to store. You can also use tools such as KV Cache Calculator by gaunernst (the GOAT) to sanity-check these numbers.
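The cache shape translates directly into a byte count. Here is a minimal sketch; the helper name `kv_cache_bytes` and the config are hypothetical, picked only to show how the number of KV heads and the numeric precision change the total.

```python
def kv_cache_bytes(batch_size, num_layers, num_kv_heads, seq_len, head_dim, bytes_per_element):
    """Bytes for a cache shaped [2, batch, layers, kv_heads, seq_len, head_dim]."""
    elements = 2 * batch_size * num_layers * num_kv_heads * seq_len * head_dim
    return elements * bytes_per_element

MiB = 1024 ** 2

# Hypothetical config, not taken from any specific model.
base = dict(batch_size=1, num_layers=32, seq_len=4096, head_dim=128)

mha_16bit = kv_cache_bytes(num_kv_heads=32, bytes_per_element=2, **base)  # one KV head per query head
gqa_16bit = kv_cache_bytes(num_kv_heads=8, bytes_per_element=2, **base)   # grouped-query attention
gqa_8bit = kv_cache_bytes(num_kv_heads=8, bytes_per_element=1, **base)    # plus an 8-bit cache

print(mha_16bit // MiB)  # 2048
print(gqa_16bit // MiB)  # 512
print(gqa_8bit // MiB)   # 256
```

Every factor enters the product linearly, so halving any one of them halves the cache.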
Example: Calculating KV cache memory usage
Before we finish, let’s solve a simple problem on KV caching. Suppose we have a model with 10 layers, 16 KV heads, head dimension 64, sequence length 1024, batch size 1, and float32 precision. How much memory will the KV cache take? Using the formula above, we have:

$$2 \times 1 \times 10 \times 16 \times 1024 \times 64 = 20{,}971{,}520$$

So we need to store 20,971,520 elements, each taking 4 bytes. The total memory usage is 20,971,520 × 4 = 83,886,080 bytes, which is about 83.89 MB (exactly 80 MiB).
Okay cool, for a small model like that, we need roughly 80 MiB of memory for the KV cache. But what if we use a real-world model like Llama 3.1 8B? Its config is still simple enough to calculate by hand:
- Number of layers: 32.
- Number of KV heads: 8.
- Precision: bfloat16 (2 bytes).
- Sequence length: 2048.
- Head dimension: 128.
- Batch size: 1.
So the memory usage is:

$$2 \times 1 \times 32 \times 8 \times 2048 \times 128 = 134{,}217{,}728$$

So we need to store 134,217,728 elements, each taking 2 bytes. The total memory usage is 134,217,728 × 2 = 268,435,456 bytes, which is about 268.44 MB (exactly 256 MiB).
Llama 3.1 8B supports much longer contexts than 2048 tokens. If we scale the same setup to 100k tokens, the KV cache becomes roughly:
    256 MiB * (100000 / 2048) = 12500 MiB ≈ 12.21 GiB

This is only for batch size 1. Larger batches multiply the cache size again.
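Another way to check these numbers is to think in bytes per token, then multiply by the context length. The sketch below uses the configs from this example; the helper name `kv_bytes_per_token` is hypothetical.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element):
    """Bytes one token of one sequence adds to the cache (K and V, all layers)."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element

KiB, MiB, GiB = 1024, 1024 ** 2, 1024 ** 3

# Toy model: 10 layers, 16 KV heads, head dimension 64, float32, 1024 tokens.
toy_total = kv_bytes_per_token(10, 16, 64, 4) * 1024
print(toy_total, toy_total / MiB)        # 83886080 80.0

# Llama 3.1 8B numbers from the list above: 32 layers, 8 KV heads, head dim 128, bfloat16.
llama_per_token = kv_bytes_per_token(32, 8, 128, 2)
print(llama_per_token / KiB)             # 128.0 (KiB per token)

llama_2k = llama_per_token * 2048
llama_100k = llama_per_token * 100_000
print(llama_2k / MiB)                    # 256.0
print(round(llama_100k / GiB, 2))        # 12.21
```

With this config each token costs 128 KiB of cache per sequence, so the total is a single multiplication by context length and batch size.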