Qianliang's blog

Learning Transformer and KV Cache as an AI Newbie

Transformer

Overall, the Transformer consists of two parts: an encoder and a decoder.

Encoder

Each input will go through this simplified process:

(1) Each token of the input sequence (prompt) is mapped to an embedding (just a vector).

Denote the sequence length as P and the embedding dimension as d; the input is then a P x d matrix X, with one row per token.
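A tiny NumPy sketch of step (1), assuming a toy vocabulary of 10 tokens and d = 8 (both made up for illustration). The embedding table is random here, whereas a real model learns it; looking up each token id turns the prompt into the P x d matrix X.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, d = 10, 8                               # hypothetical toy vocabulary size and embedding dim
embedding_table = rng.normal(size=(vocab_size, d))  # learned in a real model; random here

token_ids = np.array([3, 1, 4, 1, 5])               # hypothetical tokenized prompt
X = embedding_table[token_ids]                      # row i is the embedding of token i

P = len(token_ids)
print(X.shape)                                      # (5, 8), i.e. P x d
```

The same token id (1 appears twice) always maps to the same row, so identical tokens start with identical embeddings.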

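(2) X is fed into the self-attention layer, which is built from three matrices Q, K, V (query, key, value):

$Q = XW^Q,\quad K = XW^K,\quad V = XW^V$

where $W^Q, W^K, W^V$ are d x $d_k$ matrices, $d_q = d_k = d/h$, and h is the number of heads in multi-head attention (discussed below).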

All parameter matrices W will be learned during training.
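A minimal sketch of the projections in step (2) for a single head, with toy sizes P = 5, d = 8, h = 2 chosen only for illustration and random weights standing in for the learned W matrices.

```python
import numpy as np

rng = np.random.default_rng(0)

P, d, h = 5, 8, 2               # hypothetical toy sizes
d_k = d // h                    # d_q = d_k = d / h

X = rng.normal(size=(P, d))     # stand-in for the P x d input embeddings

# Projection matrices for ONE head: random here, learned during training in a real model
W_Q = rng.normal(size=(d, d_k))
W_K = rng.normal(size=(d, d_k))
W_V = rng.normal(size=(d, d_k))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V   # each has shape (P, d_k)
print(Q.shape, K.shape, V.shape)      # (5, 4) (5, 4) (5, 4)
```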

Then the Attention can be computed as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

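Intuition: each token's query is compared (via dot product) with the keys of all tokens to score how relevant they are. Softmax turns each row of scores into weights that sum to 1, and the output for the token is the weighted sum of the value vectors. Dividing by $\sqrt{d_k}$ keeps the scores from growing with the dimension, which would make the softmax too peaked.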
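The attention formula and the intuition above can be sketched in NumPy as follows (single head, random toy inputs; the helper names softmax and attention are my own, not from any library).

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row max avoids overflow and does not change the result
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (P, P): every query scored against every key
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V, weights         # each output row is a weighted sum of value rows

rng = np.random.default_rng(0)
P, d_k = 5, 4                           # hypothetical toy sizes
Q, K, V = (rng.normal(size=(P, d_k)) for _ in range(3))
out, weights = attention(Q, K, V)
print(out.shape)                        # (5, 4)
```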

(3) Extension of the self-attention layer: Multi-Head Attention (MHA).

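With h heads, we have h sets of parameter matrices $\{W_i^Q, W_i^K, W_i^V\}$, $i = 1, \dots, h$. Each head computes attention independently, and the h outputs are concatenated (then multiplied by an output projection $W^O$).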

Intuition: multiple heads let the model attend to different kinds of relationships, in different representation subspaces, at the same time, which gives it more capacity than a single head.
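A sketch of multi-head attention under the same toy assumptions: each head gets its own projection matrices and runs attention independently, then the outputs are concatenated and multiplied by an output projection W_O (as in the original Transformer). All weights are random placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    # W_Q, W_K, W_V: lists of h per-head matrices, each (d, d/h); W_O: (d, d)
    heads = [attention(X @ wq, X @ wk, X @ wv) for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O   # concat h x (P, d/h) -> (P, d), then project

rng = np.random.default_rng(0)
P, d, h = 5, 8, 2                  # hypothetical toy sizes
d_k = d // h
X = rng.normal(size=(P, d))
W_Q = [rng.normal(size=(d, d_k)) for _ in range(h)]
W_K = [rng.normal(size=(d, d_k)) for _ in range(h)]
W_V = [rng.normal(size=(d, d_k)) for _ in range(h)]
W_O = rng.normal(size=(d, d))
out = multi_head_attention(X, W_Q, W_K, W_V, W_O)
print(out.shape)                   # (5, 8)
```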

(4) Feed-Forward Network (FFN): skipped here.

Decoder

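The decoder architecture is almost the same as the encoder's, except for an extra encoder-decoder (cross-)attention layer. This layer also computes attention with Q, K, V matrices, but its K and V come from the encoder output, while Q comes from the decoder. Since the encoder's K and V are fixed during generation, they can be computed once and reused. However, the KV Cache in today's LLMs is mainly about the decoder's (masked) self-attention, which is the only kind of attention in decoder-only models such as GPT and LLaMA.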
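A single-head sketch of the encoder-decoder attention described above, with made-up sizes and random weights. It shows that K and V depend only on the encoder output, so they can be computed once per input and reused at every decoding step.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
P_enc, P_dec, d = 6, 3, 8                 # hypothetical toy sizes
enc_out = rng.normal(size=(P_enc, d))     # encoder output for the source sequence
dec_h = rng.normal(size=(P_dec, d))       # decoder hidden states
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

# K and V depend only on enc_out, so they can be computed once and reused
K_enc, V_enc = enc_out @ W_K, enc_out @ W_V
Q_dec = dec_h @ W_Q                       # queries come from the decoder side

cross = softmax(Q_dec @ K_enc.T / np.sqrt(d)) @ V_enc
print(cross.shape)                        # (3, 8): one row per decoder token
```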

Notice: the decoder's self-attention uses a causal mask, so the attention output for token t depends only on tokens 1, 2, ..., t (itself and the tokens before it), never on later tokens. The mathematical form is:

$$[S(QK^T)V]_t = \sum_{j=1}^{t} S_t(q_t \cdot k_j)\, v_j \qquad (*)$$

Here S is the scaled softmax, and $S_t$ normalizes over j = 1, ..., t only. The output has P rows, and row t is the result for token t.
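To check that a causal mask really implements formula (*), the sketch below sets the upper triangle of the score matrix to -inf and compares one row against a direct computation that only looks at tokens up to t (0-based indices in code; toy random inputs).

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    P, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # Entries with j > t (upper triangle) become -inf, so softmax gives them weight 0
    mask = np.triu(np.ones((P, P), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ V

rng = np.random.default_rng(0)
P, d_k = 5, 4                       # hypothetical toy sizes
Q, K, V = (rng.normal(size=(P, d_k)) for _ in range(3))
out = causal_attention(Q, K, V)

# Row t straight from formula (*): only keys/values of tokens 0..t take part
t = 2
w = softmax(Q[t] @ K[: t + 1].T / np.sqrt(d_k))
row_t = w @ V[: t + 1]
print(np.allclose(out[t], row_t))   # True
```

A side effect worth noticing: the first token can only attend to itself, so its output is exactly its own value vector.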

KV Cache

In the LLM world, serving a user prompt basically involves two stages:

Prefill

The prefill stage takes the user's entire prompt as input and generates the first token of the response in one step (a single forward pass).

Note: all input tokens are available at once, so following formula (*), the K and V of all P prompt tokens can be computed in parallel and then cached.
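A single-layer, single-head sketch of the prefill stage, assuming random weights and a toy prompt of 6 tokens. The kv_cache dictionary is a hypothetical stand-in; a real cache stores K and V for every layer and head.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
P, d = 6, 8                                # hypothetical prompt length and model dim
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
X_prompt = rng.normal(size=(P, d))         # embeddings of the whole prompt

# Prefill: all P tokens are known, so K and V come from one matmul each
Q = X_prompt @ W_Q
K = X_prompt @ W_K
V = X_prompt @ W_V
kv_cache = {"K": K, "V": V}                # hypothetical single-layer cache

mask = np.triu(np.ones((P, P), dtype=bool), k=1)
scores = np.where(mask, -np.inf, Q @ K.T / np.sqrt(d))
out = softmax(scores) @ V                  # (P, d)
# In a real model, the last row (after the remaining layers and the output head)
# predicts the first token of the response
last_hidden = out[-1]
print(kv_cache["K"].shape)                 # (6, 8)
```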

Decode

The decode stage takes one token as input (starting with the token produced by the prefill stage) and generates the next token; that output is fed back as the next input, and this repeats token by token. This is why the process is called auto-regressive.

Note: each token at position t of the decode stage still needs its own Q, K, V in every attention layer. By formula (*), $q_t, k_t, v_t$ are newly computed for token t, but the attention also needs the K and V of all (P + t - 1) previous tokens.

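This is where the KV Cache comes in: we cache the K and V of all prompt tokens during prefill, and append the newly computed K and V of each token during decode. Every subsequent token then reuses all previous K and V instead of recomputing them, so each decode step only processes the one new token. (Previous Qs are never needed again, which is why only K and V are cached.)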
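The sketch below, under the same toy single-layer, single-head assumptions, decodes 3 tokens with a growing KV Cache and checks that the result matches recomputing causal attention over the whole sequence from scratch. Random vectors stand in for the embeddings of the generated tokens.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    P = Q.shape[0]
    mask = np.triu(np.ones((P, P), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, Q @ K.T / np.sqrt(Q.shape[-1]))
    return softmax(scores) @ V

rng = np.random.default_rng(0)
P, steps, d = 4, 3, 8                  # hypothetical prompt length, decode steps, model dim
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
X = rng.normal(size=(P + steps, d))    # prompt + decoded-token embeddings (random stand-ins)

# Prefill: cache K and V of the P prompt tokens
K_cache, V_cache = X[:P] @ W_K, X[:P] @ W_V

cached_outputs = []
for t in range(P, P + steps):
    x_t = X[t]
    q_t, k_t, v_t = x_t @ W_Q, x_t @ W_K, x_t @ W_V   # only the new token is projected
    K_cache = np.vstack([K_cache, k_t])               # append instead of recomputing
    V_cache = np.vstack([V_cache, v_t])
    w = softmax(q_t @ K_cache.T / np.sqrt(d))          # attends to all t + 1 tokens so far
    cached_outputs.append(w @ V_cache)

# Reference: recompute everything from scratch with a causal mask
full = causal_attention(X @ W_Q, X @ W_K, X @ W_V)
print(np.allclose(np.array(cached_outputs), full[P:]))   # True
```

With the cache, each step projects only the new token and computes one row of attention; without it, every step would redo the projections for the whole sequence.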

Problem with KV Cache

With an extremely long prompt, this caching strategy consumes a large amount of memory: the KV Cache grows linearly with the number of tokens in the context (and with the number of requests served at once).

In the vLLM paper, when serving a 13B-parameter LLM on an NVIDIA A100 GPU with 40GB of HBM, about 65% of the memory holds the model parameters and more than 30% holds the KV Cache.
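A back-of-the-envelope sketch of why the KV Cache gets so large. The config (40 layers, 40 heads of dimension 128, FP16) is an assumption resembling a 13B model such as LLaMA-13B, not a figure taken from the vLLM paper, and the 30% budget simply mirrors the share quoted above.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    # Factor 2: one K vector and one V vector per head, per layer
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem

# Assumed 13B-class config: 40 layers, 40 heads of dim 128 (hidden size 5120), FP16
per_token = kv_cache_bytes_per_token(n_layers=40, n_kv_heads=40, d_head=128)
print(per_token)                      # 819200 bytes = 800 KiB per token

budget = 40 * 10**9 * 30 // 100       # ~30% of a 40 GB GPU reserved for the KV Cache
max_tokens = budget // per_token
print(max_tokens)                     # 14648 tokens, shared by ALL requests in a batch
```

Under these assumptions, a handful of requests with a few thousand tokens each already fills the cache budget, which is why the optimizations below target the per-token size.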

Optimizations

Many methods have been proposed to reduce the memory usage of the KV Cache. Here are some popular ones:

MQA (Multi-Query Attention)

The idea is intuitive: in MHA, each head produces its own K and V, whereas MQA lets all heads in a layer share a single K and V (each head still has its own Q). The KV Cache of MQA is therefore 1/h the size of MHA's. Used by PaLM and Gemini.
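A minimal MQA sketch with toy sizes and random weights (output projection omitted for brevity): each of the h heads has its own query projection, but all of them attend over one shared K and V, which is all that needs to be cached.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
P, d, h = 5, 8, 4                     # hypothetical toy sizes
d_h = d // h
X = rng.normal(size=(P, d))

W_Q = rng.normal(size=(h, d, d_h))    # one query projection per head
W_K = rng.normal(size=(d, d_h))       # MQA: a single K projection shared by all heads
W_V = rng.normal(size=(d, d_h))       # MQA: a single V projection shared by all heads

K, V = X @ W_K, X @ W_V               # only these (P, d_h) tensors go into the KV Cache
heads = []
for i in range(h):
    Q_i = X @ W_Q[i]
    heads.append(softmax(Q_i @ K.T / np.sqrt(d_h)) @ V)
out = np.concatenate(heads, axis=-1)  # (P, d)

mha_cache = 2 * P * h * d_h           # MHA would cache K and V for every head
mqa_cache = K.size + V.size           # MQA caches one K and one V
print(mha_cache // mqa_cache)         # 4, i.e. h
```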

GQA (Grouped-Query Attention)

A hybrid of MQA and MHA: split the h query heads into g groups, and let all heads in the same group share one K and V (h must be divisible by g). With g = 1 this is MQA, and with g = h it is MHA.

The idea is to find the trade-off between memory usage and model capability.

Used by DeepSeek-V1 and LLaMA 2/3.
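A GQA sketch with toy sizes (h = 8 query heads, g = 2 KV groups) and random weights. The rule that consecutive heads share a group (head i uses group i // (h // g)) is one common layout and is assumed here. Only the g sets of K and V would be cached.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gqa(X, W_Q, W_K, W_V):
    # W_Q: (h, d, d_h), one per query head; W_K, W_V: (g, d, d_h), one per KV group
    h, g = W_Q.shape[0], W_K.shape[0]
    assert h % g == 0, "h must be divisible by g"
    heads_per_group = h // g
    K = np.einsum("pd,gde->gpe", X, W_K)    # (g, P, d_h): this is what gets cached
    V = np.einsum("pd,gde->gpe", X, W_V)
    outs = []
    for i in range(h):
        grp = i // heads_per_group           # consecutive heads share one KV group
        Q_i = X @ W_Q[i]
        outs.append(softmax(Q_i @ K[grp].T / np.sqrt(Q_i.shape[-1])) @ V[grp])
    return np.concatenate(outs, axis=-1), K, V

rng = np.random.default_rng(0)
P, d, h, g = 5, 16, 8, 2                     # hypothetical toy sizes
d_h = d // h
X = rng.normal(size=(P, d))
W_Q = rng.normal(size=(h, d, d_h))
W_K = rng.normal(size=(g, d, d_h))
W_V = rng.normal(size=(g, d, d_h))
out, K, V = gqa(X, W_Q, W_K, W_V)
print(out.shape, K.shape)                    # (5, 16) (2, 5, 2)
```

Copying each group's K and V weights h/g times gives an ordinary MHA with exactly the same output, which is a handy way to see GQA as MHA whose KV heads are shared within each group.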

MLA (Multi-Head Latent Attention)

MLA can be seen as a more general version of GQA: instead of directly partitioning the KVs into g groups, it uses a down-projection matrix to do a low-rank compression. The input $h_t$ is multiplied by the projection matrix $W^{DKV}$ to get a latent vector $c_t^{KV} = W^{DKV} h_t$, whose dimension $d_c$ is much smaller than that of the full multi-head K and V. Up-projection matrices then generate the keys and values of all heads from this latent.

So we only need to cache the latent $c_t^{KV}$ instead of the original K and V.
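A simplified sketch of MLA's key/value compression, with toy sizes and random weights. It uses the row-vector convention (H @ W_DKV instead of $W^{DKV} h_t$) and leaves out two parts of DeepSeek-V2's actual design: the decoupled RoPE key that is cached alongside the latent, and absorbing the up-projections into the query and output projections so that K and V never have to be materialized.

```python
import numpy as np

rng = np.random.default_rng(0)
P, d, n_h, d_h, d_c = 5, 32, 4, 8, 6      # hypothetical toy sizes; d_c << n_h * d_h

H = rng.normal(size=(P, d))               # input hidden states h_t, one row per token
W_DKV = rng.normal(size=(d, d_c))         # down-projection (compression)
W_UK = rng.normal(size=(n_h, d_c, d_h))   # per-head up-projection for keys
W_UV = rng.normal(size=(n_h, d_c, d_h))   # per-head up-projection for values

C_KV = H @ W_DKV                          # (P, d_c): latent c_t^KV, the ONLY thing cached

# At attention time, per-head keys and values are rebuilt from the cached latent
K = np.einsum("pc,hce->hpe", C_KV, W_UK)  # (n_h, P, d_h)
V = np.einsum("pc,hce->hpe", C_KV, W_UV)  # (n_h, P, d_h)

mha_elems_per_token = 2 * n_h * d_h       # 64: K and V for every head
mla_elems_per_token = d_c                 # 6: just the latent vector
print(K.shape, mha_elems_per_token, mla_elems_per_token)   # (4, 5, 8) 64 6
```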

Comparison of the KV Cache size per token (number of cached elements):

Here $n_h$ is the number of heads, $d_h$ the dimension of each head, and $l$ the number of layers.

MLA was proposed by DeepSeek-V2.
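A small calculator for the per-token cache sizes, using the formulas from DeepSeek-V2's comparison: MHA $2 n_h d_h l$, GQA $2 n_g d_h l$ (with $n_g$ the number of KV groups), MQA $2 d_h l$, and MLA $(d_c + d_h^R) l$, where $d_h^R$ is the dimension of the decoupled RoPE key that MLA caches next to the latent. The model config is hypothetical; $d_c = 4 d_h$ and $d_h^R = d_h / 2$ are, to my understanding, the settings DeepSeek-V2 reports.

```python
def kv_elems_per_token(kind, n_h, d_h, l, n_g=None, d_c=None, d_rope=None):
    # Number of cached elements per token, summed over all l layers
    if kind == "MHA":
        return 2 * n_h * d_h * l
    if kind == "GQA":
        return 2 * n_g * d_h * l
    if kind == "MQA":
        return 2 * d_h * l
    if kind == "MLA":
        return (d_c + d_rope) * l   # latent c_t^KV plus the decoupled RoPE key
    raise ValueError(f"unknown attention type: {kind}")

n_h, d_h, l = 32, 128, 32           # hypothetical model config
cfg = dict(n_h=n_h, d_h=d_h, l=l, n_g=8, d_c=4 * d_h, d_rope=d_h // 2)
for kind in ["MHA", "GQA", "MQA", "MLA"]:
    print(kind, kv_elems_per_token(kind, **cfg))
# MHA 262144, GQA 65536, MQA 8192, MLA 18432
```

With these settings MLA caches $\frac{9}{2} d_h l$ elements per token: more than MQA, but far less than GQA with 8 groups or full MHA.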