duahaulaplanh 🍉
May 17, 2026
9 min read

Why do we need KV caching?

Warning (AI Transparency)

All the ideas, notes, and technical content are mine. I use AI (GPT-5.5 from
Codex) as a writing assistant and to help generate Manim videos from my notes.
The Manim animation source code is in vietfood/manim-lib.

Why do we need KV caching? If you already know how Transformer inference works, you probably know the short answer. (If you read Vietnamese, you may want to read my previous Transformer posts first: Một chút về Transformers (phần 1) and Một chút về Transformers (phần 2).) But just in case you don't, let's start with a small example:

  1. Suppose the model reads the prompt:
A B C D
  2. It will predict the next token. Suppose that token is:
E
  3. Now the model needs to continue from:
A B C D E

A naive implementation would process the whole sequence again. But A, B, C, and D have not changed. In a causal Transformer, future tokens are allowed to look backward, but past tokens are not allowed to look forward. So the key and value tensors for A, B, C, and D can be reused.

Those reusable tensors are the KV cache: during decoding, we compute the key and value vectors for each token once, store them, and reuse them when later tokens attend to the past.

1. The part of attention that matters

For this post, we only need one fact about attention: a token reads previous tokens by comparing a query against their keys, then mixing their values.

Let token $i$ have embedding $\mathbf{x}_i \in \mathbb{R}^{d}$. More precisely, if we are not in the first layer, $\mathbf{x}_i$ is the token's hidden state at that layer. An attention layer projects it into three vectors:

$$\mathbf{q}_i = \mathbf{x}_i\mathbf{W}^Q,\qquad \mathbf{k}_i = \mathbf{x}_i\mathbf{W}^K,\qquad \mathbf{v}_i = \mathbf{x}_i\mathbf{W}^V$$

These names are useful if we read them literally:

  • Query: what the current token is looking for.
  • Key: what a token offers to be matched against.
  • Value: the information that gets copied into the output after the match.
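Here is a minimal NumPy sketch of the projection step, assuming a single head with random weights; the sizes (`d = 8`, `d_k = 4`) are illustrative, not taken from a real model. It also checks that projecting the whole sequence at once gives the same rows as projecting each token on its own, so each token's key and value depend only on that token's own hidden state at this layer.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_k, n_tokens = 8, 4, 4              # illustrative sizes, not from a real model

# One projection matrix per role; in a real layer these are learned weights.
W_Q = rng.standard_normal((d, d_k))
W_K = rng.standard_normal((d, d_k))
W_V = rng.standard_normal((d, d_k))

X = rng.standard_normal((n_tokens, d))  # hidden states for tokens A, B, C, D

# Per-token projection: k_i = x_i W^K, computed one token at a time.
per_token_k = [X[i] @ W_K for i in range(n_tokens)]

# Whole-sequence projection: row i of K depends only on x_i.
Q = X @ W_Q
K = X @ W_K
V = X @ W_V

print(Q.shape, K.shape, V.shape)        # (4, 4) (4, 4) (4, 4)
print(all(np.allclose(K[i], per_token_k[i]) for i in range(n_tokens)))  # True
```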

For token $i$, attention scores each previous token $j$ with a scaled dot product, where $d_k$ is the dimension of the key vectors:

$$\text{score}(\mathbf{q}_i, \mathbf{k}_j) = \frac{\mathbf{q}_i \cdot \mathbf{k}_j}{\sqrt{d_k}}$$

Because this is a causal language model, token $i$ can only attend to tokens at positions $j \leq i$. After softmax, the attention output is:

$$\mathbf{a}_i = \sum_{j \leq i} \alpha_{ij}\mathbf{v}_j, \qquad \alpha_{ij} = \frac{\exp(\text{score}(\mathbf{q}_i, \mathbf{k}_j))}{\sum_{m \leq i}\exp(\text{score}(\mathbf{q}_i, \mathbf{k}_m))}$$

So the execution order is:

current query -> compare with previous keys -> weight previous values -> output
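The formulas above map almost line by line onto a small function. This is a sketch, assuming NumPy and random toy vectors; `attend_one` is a hypothetical name. It also checks the causal property: changing the key and value of a later token does not change an earlier token's output.

```python
import numpy as np

def attend_one(i, Q, K, V):
    """Attention output a_i for token i, reading only positions j <= i."""
    d_k = Q.shape[-1]
    scores = np.array([Q[i] @ K[j] / np.sqrt(d_k) for j in range(i + 1)])
    scores = scores - scores.max()          # for numerical stability; softmax is unchanged
    alpha = np.exp(scores) / np.exp(scores).sum()
    return sum(alpha[j] * V[j] for j in range(i + 1))

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 3)) for _ in range(3))  # tokens A, B, C, D with d_k = 3

print(np.allclose(attend_one(0, Q, K, V), V[0]))  # True: A can only attend to itself

# Changing D's key and value must not affect C's output (C cannot look forward).
K_changed, V_changed = K.copy(), V.copy()
K_changed[3] += 100.0
V_changed[3] += 100.0
print(np.allclose(attend_one(2, Q, K_changed, V_changed), attend_one(2, Q, K, V)))  # True
```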

In matrix form, for a sequence $\mathbf{X} \in \mathbb{R}^{N \times d}$:

$$\mathbf{Q} = \mathbf{X}\mathbf{W}^Q,\qquad \mathbf{K} = \mathbf{X}\mathbf{W}^K,\qquad \mathbf{V} = \mathbf{X}\mathbf{W}^V$$

and, with causal masking (the $\text{mask}$ function tells the softmax to ignore future tokens; in practice, we add a very negative value such as $-\infty$ to masked positions before softmax, so their probability becomes 0):

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\text{mask}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\right)\mathbf{V}$$
A normal attention pass for four tokens: A, B, C, and D. The layer builds Q, K, and V for every token, then computes causal attention.
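The matrix form can be sketched the same way, assuming NumPy; `causal_attention` is a hypothetical helper. The mask is built with `np.triu`, so every entry above the diagonal (a future position) gets $-\infty$ before softmax and ends up with weight 0.

```python
import numpy as np

def causal_attention(Q, K, V):
    """softmax(mask(Q K^T / sqrt(d_k))) V for a whole sequence."""
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True where j > i
    scores = np.where(future, -np.inf, scores)           # masked positions get -inf
    scores = scores - scores.max(axis=-1, keepdims=True) # the diagonal is never masked, so the max is finite
    weights = np.exp(scores)                              # exp(-inf) = 0
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 3)) for _ in range(3))  # tokens A, B, C, D
out, weights = causal_attention(Q, K, V)

# Lower-triangular weights: zeros above the diagonal, each row sums to 1.
print(np.round(weights, 2))
```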

This full-sequence computation is reasonable when the model first reads the prompt. The problem starts after the prompt, when generation becomes token-by-token.

2. The waste during decoding

LLMs usually have two inference phases:

| Phase | Input shape | What the model does |
| --- | --- | --- |
| Prefill | Many prompt tokens at once | Process the prompt and prepare internal state. |
| Decode | One new token per sequence at a time | Extend the sequence by one token, then repeat. |

During prefill, the model can process the prompt in parallel. If the prompt has 1,000 tokens, a large matrix operation can build attention states for those 1,000 tokens.

During decode, only the newest token is actually new. If the current sequence is A B C D and the next token is E, the model does not need new outputs for A, B, C, and D. It only needs the output for E.

Without caching, the model still rebuilds old projections:

A B C D -> build K/V for A, B, C, D
A B C D E -> build K/V for A, B, C, D, E again
A B C D E F -> build K/V for A, B, C, D, E, F again

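The repeated work gets worse as the generated sequence grows: a step that sees $t$ tokens rebuilds K/V for all $t$ of them, so the total projection work grows quadratically with the sequence length instead of linearly.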

The naive next-token pass recomputes projections and attention for old tokens, even though only token E is new.
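To put a number on the waste, here is a rough count of per-layer K/V projections for a prompt of `prompt_len` tokens followed by `new_tokens` decode steps. The helper names are hypothetical, and the count ignores every other cost (queries, attention scores, MLPs).

```python
def naive_kv_projections(prompt_len: int, new_tokens: int) -> int:
    """No cache: every step rebuilds K/V for every token seen so far."""
    total = prompt_len                  # prefill builds K/V for the prompt
    for step in range(1, new_tokens + 1):
        total += prompt_len + step      # rebuild K/V for the whole current sequence
    return total

def cached_kv_projections(prompt_len: int, new_tokens: int) -> int:
    """With a cache: K/V for each token is built exactly once."""
    return prompt_len + new_tokens

# The A B C D example: prefill (4), then A..E (5), then A..F (6).
print(naive_kv_projections(4, 2), cached_kv_projections(4, 2))          # 15 6
print(naive_kv_projections(1000, 1000), cached_kv_projections(1000, 1000))  # 1501500 2000
```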

The key observation is not merely that recomputation is expensive. The key observation is that the recomputation is unnecessary. Because of causal masking, old tokens cannot incorporate information from new tokens. Their keys and values at each layer are already fixed.
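A toy check of this claim, assuming NumPy, two stacked single-head attention layers with residual connections, and no MLP or normalization (those act on each token independently, so they would not change the argument). Appending E leaves the keys and values of A, B, C, and D unchanged at every layer, not just the first one.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
# Two toy attention layers, each a tuple (W_Q, W_K, W_V) of random weights.
layers = [tuple(rng.standard_normal((d, d)) for _w in range(3)) for _layer in range(2)]

def causal_attention(Q, K, V):
    n = Q.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[1])
    scores = np.where(np.triu(np.ones((n, n), dtype=bool), k=1), -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def layer_kv(X):
    """Return the (K, V) pair that each layer computes for the input sequence X."""
    kvs = []
    for W_Q, W_K, W_V in layers:
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V
        kvs.append((K, V))
        X = X + causal_attention(Q, K, V)  # residual connection feeds the next layer
    return kvs

X_abcd = rng.standard_normal((4, d))                         # A B C D
X_abcde = np.vstack([X_abcd, rng.standard_normal((1, d))])  # A B C D E

for (K_old, V_old), (K_new, V_new) in zip(layer_kv(X_abcd), layer_kv(X_abcde)):
    # Rows A..D match at every layer; only row E is new. Prints True for both layers.
    print(np.allclose(K_old, K_new[:4]) and np.allclose(V_old, V_new[:4]))
```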

3. Why the cache stores K and V

Now look at the exact work needed for the new token E.

E needs its own query:

$$\mathbf{q}_E$$

Then E compares that query with all available keys:

$$\mathbf{q}_E \cdot \mathbf{k}_A,\quad \mathbf{q}_E \cdot \mathbf{k}_B,\quad \mathbf{q}_E \cdot \mathbf{k}_C,\quad \mathbf{q}_E \cdot \mathbf{k}_D,\quad \mathbf{q}_E \cdot \mathbf{k}_E$$

Finally, E mixes the values:

$$\mathbf{a}_E = \alpha_{EA}\mathbf{v}_A + \alpha_{EB}\mathbf{v}_B + \alpha_{EC}\mathbf{v}_C + \alpha_{ED}\mathbf{v}_D + \alpha_{EE}\mathbf{v}_E$$

So the new token needs old keys and old values. It does not need old queries, because we are not recomputing the attention outputs for A, B, C, and D.

That is the reason for the name KV cache (not QK cache or QKV cache): a KV cache stores the key and value tensors produced for previous tokens at each Transformer layer, so future tokens can attend to them without recomputing them.
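A quick NumPy check of this reasoning with toy random tensors: the output for E computed from only $\mathbf{q}_E$ and the stored keys and values matches the last row of a full causal attention pass. The queries of A, B, C, and D are never used on the cached path.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 3
Q, K, V = (rng.standard_normal((5, d_k)) for _ in range(3))  # rows: A, B, C, D, E

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Full recompute: causal attention for all five tokens, then keep only E's row.
mask = np.triu(np.ones((5, 5), dtype=bool), k=1)
full = softmax(np.where(mask, -np.inf, Q @ K.T / np.sqrt(d_k))) @ V
a_E_full = full[-1]

# Cached path: only q_E is needed; Q[:4] (the old queries) is never read.
q_E = Q[-1]
a_E_cached = softmax(q_E @ K.T / np.sqrt(d_k)) @ V  # no mask: nothing is in E's future

print(np.allclose(a_E_full, a_E_cached))  # True
```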

When token E arrives, the model only computes:

q_E, k_E, v_E

Then it attends with:

q_E against [cached k_A, cached k_B, cached k_C, cached k_D, k_E]

and mixes:

[cached v_A, cached v_B, cached v_C, cached v_D, v_E]
Only token E is new. The model can reuse the cached K and V tensors for A, B, C, and D, then append the new K and V for E.

4. The decode loop with a KV cache

With KV caching, prefill and decode have different jobs.

During prefill, the model reads the prompt and fills the cache:

prompt tokens -> create and fill K/V tensors for every layer -> cache them

During decode, each step does a small update:

new token -> new K/V tensors -> append to cache
new query -> attend over cached K/V -> compute next-token logits

Written as simplified PyTorch-like code:

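import math
import torch
d, d_k = 16, 16  # one layer, one head; illustrative sizes
W_Q, W_K, W_V = (torch.randn(d, d_k) for _ in range(3))  # stand-ins for learned weights
kv_cache = (torch.zeros(0, d_k), torch.zeros(0, d_k))  # empty cache: no past tokens yet
x_new = torch.randn(1, d)  # hidden state of the newest token, shape (1, d)
past_k, past_v = kv_cache
q_new = x_new @ W_Q  # only the new token needs a query
k_new = x_new @ W_K
v_new = x_new @ W_V
k_all = torch.cat([past_k, k_new], dim=0)  # append along the sequence axis
v_all = torch.cat([past_v, v_new], dim=0)
scores = q_new @ k_all.T / math.sqrt(d_k)  # no mask needed: every cached key is in the past
output_new = torch.softmax(scores, dim=-1) @ v_all
kv_cache = (k_all, v_all)  # the grown cache is reused at the next step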
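The snippet above shows a single step. To check that a whole prefill-then-decode loop matches full recomputation, here is a self-contained NumPy sketch with one layer and one head. Assumptions: random weights, and the hidden states of the generated tokens are drawn at random instead of coming from sampling and embedding, since the cache logic does not depend on where those inputs come from. `prefill`, `decode_step`, and `full_attention` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_k = 8, 8
W_Q, W_K, W_V = (rng.standard_normal((d, d_k)) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_attention(X):
    """No cache: rebuild Q, K, V for every token and apply a causal mask."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    n = X.shape[0]
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    return softmax(np.where(mask, -np.inf, Q @ K.T / np.sqrt(d_k))) @ V

def prefill(X_prompt):
    """Process the prompt once and fill the cache with its keys and values."""
    return full_attention(X_prompt), (X_prompt @ W_K, X_prompt @ W_V)

def decode_step(x_new, kv_cache):
    """Project only the new token, append its K/V, and attend over the cache."""
    past_k, past_v = kv_cache
    q_new, k_new, v_new = x_new @ W_Q, x_new @ W_K, x_new @ W_V
    k_all = np.concatenate([past_k, k_new[None, :]], axis=0)
    v_all = np.concatenate([past_v, v_new[None, :]], axis=0)
    out = softmax(q_new @ k_all.T / np.sqrt(d_k)) @ v_all
    return out, (k_all, v_all)

X_prompt = rng.standard_normal((4, d))     # stands in for the hidden states of A B C D
X_generated = rng.standard_normal((3, d))  # stands in for E, F, G (no real sampling here)

_, cache = prefill(X_prompt)
X_seen = X_prompt
for x_new in X_generated:
    out_cached, cache = decode_step(x_new, cache)
    X_seen = np.vstack([X_seen, x_new])
    out_full = full_attention(X_seen)[-1]  # naive: recompute everything, keep the last row
    print(np.allclose(out_cached, out_full), cache[0].shape)
# True (5, 8), then True (6, 8), then True (7, 8)
```

The cached path does one projection per new token per step, while `full_attention` redoes all of them, yet the outputs agree.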

5. What KV caching is good for, and what it is not good for

KV caching saves the work of rebuilding old key and value tensors. Each token’s K/V tensors are computed once per layer and then reused for later tokens.

The takeaway: KV caching avoids recomputing what causal attention already made fixed.

The clean way to state the advantage is that KV caching reduces repeated computation during decoding, at the cost of extra memory, because the model has to store K/V tensors for previous tokens. It does not make long-context attention free. The newest token still needs to compare its query with every previous key and read every previous value, so each decode step still does work proportional to the context length. If the context is long, the cache can be large and expensive to read. This is why serving systems use techniques such as PagedAttention to manage KV cache memory more efficiently. (If you know about LLM inference, you may know about the vLLM framework. One of vLLM's core ideas is PagedAttention, which manages KV cache memory in fixed-size blocks. You can read more in Efficient Memory Management for Large Language Model Serving with PagedAttention and vLLM's documentation.)

For every token in the sequence, prompt or generated, every layer stores key and value tensors. The cache grows with:

  • Batch size: We store previous tokens' K and V tensors for every sequence in the batch, so memory usage is linear in the batch size.
  • Sequence length: Like batch size, every previous token contributes its own K and V tensors, so memory usage is linear in the sequence length.
  • Number of layers: Each Transformer layer has its own attention module, so each layer has its own K and V tensors.
  • Number of KV heads and head dimension: Each KV head stores its own K and V tensors. In grouped-query attention, the number of KV heads can be smaller than the number of query heads.
  • Numeric precision: Each element in the K and V tensors takes a fixed number of bytes. This is why KV cache quantization can reduce memory usage.
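A minimal sketch of why grouped-query attention shrinks the cache, assuming the head counts of Llama 3.1 8B from its public config (32 query heads, 8 KV heads, head dimension 128) and a toy sequence length of 16. Only the 8 KV heads are stored; at attention time each KV head is shared by a group of 32 / 8 = 4 query heads. Expanding with `np.repeat` is just one simple way to express that sharing; optimized kernels typically avoid materializing the copy.

```python
import numpy as np

n_q_heads, n_kv_heads, d_head, T = 32, 8, 128, 16  # T is a toy sequence length
group = n_q_heads // n_kv_heads                    # query heads per KV head (4 here)

rng = np.random.default_rng(0)
k_cache = rng.standard_normal((n_kv_heads, T, d_head)).astype(np.float32)  # what is stored
q = rng.standard_normal((n_q_heads, 1, d_head)).astype(np.float32)         # newest token's queries

# Query head h reads KV head h // group; repeating along the head axis expresses that.
k_expanded = np.repeat(k_cache, group, axis=0)               # (32, T, d_head), built on the fly
scores = q @ k_expanded.transpose(0, 2, 1) / np.sqrt(d_head)  # (32, 1, T)

print(k_expanded.shape, scores.shape)                 # (32, 16, 128) (32, 1, 16)
print(np.allclose(k_expanded[5], k_cache[5 // group]))  # True: query head 5 uses KV head 1
print(k_cache.nbytes, k_expanded.nbytes)              # 65536 262144 (stored vs. 32 separate KV heads)
```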

A rough shape for the cache is:

$$\text{KV cache shape} \approx 2 \times B \times L \times H_{kv} \times T \times d_{head}$$

where:

  • $2$ is for K and V
  • $B$ is batch size
  • $L$ is the number of layers
  • $H_{kv}$ is the number of key/value heads
  • $T$ is sequence length
  • $d_{head}$ is the head dimension

This is why serving long-context models is often memory-limited. The model weights may fit on the GPU, but the cache can still become the bottleneck as batch size and context length grow. Many inference optimizations are different ways of managing this cost: paged KV cache, grouped-query attention, multi-query attention, quantized KV cache, sliding-window attention, and cache eviction. They differ in implementation, but they all start from the same fact: past keys and values are useful enough to store. You can also use tools such as KV Cache Calculator by gaunernst (the GOAT) to sanity-check these numbers.
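The size formula turns directly into a small helper. `kv_cache_bytes` is a hypothetical name, and the byte count assumes every element is stored at the same precision with no padding, paging, or allocator overhead. The two configs below are the ones worked out by hand in the example that follows.

```python
def kv_cache_bytes(batch: int, layers: int, kv_heads: int, seq_len: int,
                   head_dim: int, bytes_per_elem: int) -> int:
    """Approximate KV cache size: 2 (K and V) x B x L x H_kv x T x d_head elements."""
    elements = 2 * batch * layers * kv_heads * seq_len * head_dim
    return elements * bytes_per_elem

MiB = 1024 ** 2

# Small float32 model: 10 layers, 16 KV heads, T = 1024, d_head = 64.
small = kv_cache_bytes(batch=1, layers=10, kv_heads=16, seq_len=1024, head_dim=64, bytes_per_elem=4)
# Llama 3.1 8B-style config in bfloat16: 32 layers, 8 KV heads, T = 2048, d_head = 128.
llama = kv_cache_bytes(batch=1, layers=32, kv_heads=8, seq_len=2048, head_dim=128, bytes_per_elem=2)

print(small, small / MiB)  # 83886080 80.0
print(llama, llama / MiB)  # 268435456 256.0
```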

Example: Calculating KV cache memory usage

Before we finish, let's solve a simple problem on KV caching. Suppose we have a model with 10 layers, 16 KV heads, head dimension 64, sequence length 1024, and batch size 1, and we store the cache in float32 precision. How much memory will the KV cache take? Using the formula above, we get:

$$\text{KV cache shape} = 2 \times 1 \times 10 \times 16 \times 1024 \times 64 = 20971520$$

So we need to store 20971520 elements, and each element takes 4 bytes. The total memory usage is 20971520 * 4 = 83886080 bytes, which is about 83.89 MB (exactly 80 MiB).

Okay cool, for a small model like that, we need roughly 80 MiB of memory for the KV cache. But what about a real-world model like Llama 3.1 8B? Its config is still simple enough to calculate by hand:

  • Number of layers: 32.
  • Number of KV heads: 8.
  • Precision: bfloat16 (2 bytes).
  • Sequence length: 2048.
  • Head dimension: 128.
  • Batch size: 1.

So the memory usage is:

$$\text{KV cache shape} = 2 \times 1 \times 32 \times 8 \times 2048 \times 128 = 134217728$$

So we need to store 134217728 elements, and each element takes 2 bytes. The total memory usage is 134217728 * 2 = 268435456 bytes, which is about 268.44 MB (exactly 256 MiB).

Llama 3.1 8B supports much longer contexts than 2048 tokens. If we scale the same setup to 100k tokens, the KV cache becomes roughly:

256 MiB * (100000 / 2048) = 12500 MiB = 12.21 GiB

This is only for batch size 1. Larger batches multiply the cache size again.
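A small sketch extending the same arithmetic to 100k tokens for a few batch sizes, assuming the same Llama 3.1 8B-style config in bfloat16 and ignoring paging or allocator overhead. `kv_cache_gib` is a hypothetical helper.

```python
def kv_cache_gib(batch, seq_len, layers=32, kv_heads=8, head_dim=128, bytes_per_elem=2):
    """KV cache size in GiB; defaults follow the Llama 3.1 8B-style config above (bfloat16)."""
    return 2 * batch * layers * kv_heads * seq_len * head_dim * bytes_per_elem / 1024 ** 3

for batch in (1, 4, 16):
    print(batch, round(kv_cache_gib(batch, seq_len=100_000), 2))
# 1 12.21
# 4 48.83
# 16 195.31
```

The cache is linear in both batch size and sequence length, so at 100k tokens even a modest batch needs tens to hundreds of GiB just for K and V.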