LLM Inference-KV Cache Quantization Scheme
How to quantize KV cache?
--
1 Background
First, two questions need to be clarified: why the KV Cache is introduced, and why it needs to be quantized. Both are explained in more detail in LLM Decoding Attention-KV Cache Int8 Quantization. In short, the KV Cache is an engineering technique that avoids the steadily growing recomputation during LLM (Large Language Model) inference, and it significantly reduces latency in the decoding stage. Moreover, the KV Cache usually occupies a larger share of GPU memory than the model weights do, so quantizing the KV Cache yields a larger GPU memory saving than quantizing the model (weight quantization, etc.).
For LLM inference, KV Cache quantization means that more KV Cache fits in the same GPU memory, and more KV Cache means a larger batch size, that is, higher model throughput. Model throughput is typically negatively correlated with LLM inference cost: the higher the throughput, the lower the cost per token. This is why, in LLM inference deployment, GPU memory is often a scarcer resource than compute, and it is one of the important reasons why Nvidia launched the H200 chip with 141 GB of GPU memory. The era of discussing compute while ignoring GPU memory is gradually coming to an end.
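As a sanity check on that claim, here is a minimal back-of-envelope sketch in pure Python. The 7B-class model shape, batch size and sequence length are illustrative assumptions, not numbers from this article; the point is only that at a moderate batch size the FP16 KV Cache can dwarf the FP16 weights, and Int8 halves it.

```python
# Rough estimate: KV Cache size versus weight size for a ~7B-class model.
# All shapes below are assumptions for illustration only.

def kv_cache_bytes(batch, seq_len, num_layers, num_kv_heads, head_dim, bytes_per_elem):
    # K and V each store [batch, seq_len, num_kv_heads, head_dim] per layer.
    return 2 * num_layers * batch * seq_len * num_kv_heads * head_dim * bytes_per_elem

num_layers, num_heads, head_dim = 32, 32, 128   # assumed 7B-class configuration
weight_bytes = 7e9 * 2                          # ~7B parameters in FP16

fp16_cache = kv_cache_bytes(batch=32, seq_len=4096, num_layers=num_layers,
                            num_kv_heads=num_heads, head_dim=head_dim, bytes_per_elem=2)
int8_cache = fp16_cache // 2                    # Int8 halves the KV Cache footprint

print(f"weights : {weight_bytes / 2**30:.1f} GiB")   # ~13.0 GiB
print(f"KV FP16 : {fp16_cache   / 2**30:.1f} GiB")   # ~64.0 GiB
print(f"KV Int8 : {int8_cache   / 2**30:.1f} GiB")   # ~32.0 GiB
```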
2 KV Cache Quantization Scheme
With those two questions settled, the next one is how to quantize the KV Cache, that is, how to actually save the GPU memory. Following the idea of model quantization (weight quantization, etc.), the crudest answer is "Int8 is all you need". However, KV Cache quantization differs considerably from model quantization: the numerical distributions of the two are very different, and naive KV Cache quantization easily causes the model's accuracy to drop, so schemes with higher quantization accuracy are needed. To quantize the KV Cache well, we need to start from the LLM inference process.
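To make the accuracy point concrete, the sketch below (an illustration, not the kernel from the referenced articles) compares two Int8 scale granularities on a K-like tensor whose heads have different magnitudes. The round-trip is simulated in float for brevity; finer scales, e.g. one per (token, head), usually track the KV distribution better than a single per-tensor scale, which is why a blunt "Int8 everywhere" approach can hurt accuracy.

```python
import torch

# Simulated Int8 round-trip (round + clamp in float) at two scale granularities.
def int8_roundtrip(x, scale):
    q = torch.round(x / scale).clamp(-127, 127)
    return q * scale

# K-like tensor [seq_len, num_heads, head_dim]; heads get different magnitudes.
k = torch.randn(1024, 32, 128) * torch.linspace(0.5, 4.0, 32).view(1, 32, 1)

per_tensor     = k.abs().max() / 127.0                       # one scale overall
per_token_head = k.abs().amax(dim=-1, keepdim=True) / 127.0  # one scale per (token, head)

for name, scale in [("per-tensor", per_tensor), ("per-token-per-head", per_token_head)]:
    err = (int8_roundtrip(k, scale) - k).abs().mean()
    print(f"{name:>20}: mean abs error {err.item():.5f}")
```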
2.1 Inference Process
The LLM inference process is simplified as follows. Here we only focus on the computation before and after the MHA (Multi Head Attention) operator, and take Flash Attention FP16 as the example of a non-quantized MHA operator.
- Prefill Stage: the Prefill Input of the first inference pass is the input prompt. The Q, K and V required for the MHA computation are first obtained through the QKVLinear operator; K and V are then copied into the KV Cache for use by the Flash Attention FP16 computation, and the result passes through the FFN (Feedforward Network) to obtain the output token, completing the first inference pass.
- Decoding Stage: the Decoding Input of the second inference pass is the token produced by the first pass, and the computation is the same as in the first pass, with the new K and V appended to the existing KV Cache. This decoding pass is then repeated until generation ends (a minimal sketch of this loop follows the list).
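The sketch below condenses this flow into a single-head toy model. The tiny random projections stand in for QKVLinear, the FFN and Flash Attention FP16 (the tensors here are FP32 for brevity); only the KV Cache handling, copy on prefill and append on each decode step, is the point.

```python
import torch

torch.manual_seed(0)
dim = 64
w_q, w_k, w_v, w_ffn = (torch.randn(dim, dim) * 0.02 for _ in range(4))

def attention(q, k, v):
    # Stand-in for the MHA operator (Flash Attention in the real pipeline).
    return torch.softmax(q @ k.T / dim ** 0.5, dim=-1) @ v

def step(hidden, k_cache, v_cache):
    # QKVLinear, then append K/V to the cache (the "KV Cache copy").
    q, k, v = hidden @ w_q, hidden @ w_k, hidden @ w_v
    k_cache = torch.cat([k_cache, k]); v_cache = torch.cat([v_cache, v])
    out = attention(q, k_cache, v_cache) @ w_ffn   # FFN stand-in
    return out[-1:], k_cache, v_cache              # keep only the last token

prompt = torch.randn(16, dim)                      # Prefill: whole prompt at once
hidden, k_cache, v_cache = step(prompt, torch.empty(0, dim), torch.empty(0, dim))
for _ in range(8):                                 # Decoding: one token per pass
    hidden, k_cache, v_cache = step(hidden, k_cache, v_cache)
print(k_cache.shape)                               # torch.Size([24, 64])
```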
2.2 KV Cache Quantization Inference Process
The LLM inference process after KV Cache quantization is simplified as follows. The non-quantized MHA operator is again Flash Attention FP16, and the quantized MHA operator is Decoding Attention Quantization. Compared with the process above, there are two differences:
- Prefill Stage: during the first inference pass, the output of QKVLinear goes directly into the Flash Attention FP16 computation, and the KV Cache is quantized afterwards. Why is the KV Cache quantization not followed by a quantized MHA operator here? Mainly because the prompt is very important in LLM inference, so the MHA over the prompt uses the non-quantized Flash Attention FP16 computation to preserve the prompt's output accuracy as much as possible. In addition, since the prefill stage is executed only once in the entire inference process, this does not affect the KV Cache quantization used in the subsequent decoding stage.
- Decoding Stage: during the second inference pass, the KV Cache copy is replaced by KV Cache quantization, and MHA uses the quantized operator Decoding Attention Quantization. A scheme with higher quantization accuracy is required here; see LLM Decoding Attention-KV Cache Int8 Quantization and LLM Decoding Attention-KV Cache FP8 Quantization. All subsequent inference passes also use the quantized MHA operator (a sketch of this quantized flow follows the list).
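The following sketch mirrors the previous toy loop with an Int8 KV Cache; it is illustrative only, not the Decoding Attention Quantization kernel from the referenced articles. Prefill runs attention on the unquantized K/V and only then quantizes them into the cache; decoding quantizes new K/V on append and dequantizes the cache before the dot products.

```python
import torch

def quantize(x):                        # symmetric Int8, one scale per cached row
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.round(x / scale).clamp(-127, 127).to(torch.int8), scale

def dequantize(q, scale):
    return q.float() * scale

torch.manual_seed(0)
dim = 64
w_q, w_k, w_v = (torch.randn(dim, dim) * 0.02 for _ in range(3))

def attention(q, kq, ks, vq, vs):
    # Stand-in for Decoding Attention Quantization: dequantize, then attend.
    k, v = dequantize(kq, ks), dequantize(vq, vs)
    return torch.softmax(q @ k.T / dim ** 0.5, dim=-1) @ v

# Prefill: FP attention on the prompt's own K/V, then quantize into the cache.
prompt = torch.randn(16, dim)
q, k, v = prompt @ w_q, prompt @ w_k, prompt @ w_v
out = torch.softmax(q @ k.T / dim ** 0.5, dim=-1) @ v
kq, ks = quantize(k); vq, vs = quantize(v)

# Decoding: quantize the new K/V on append, attend over the Int8 cache.
hidden = out[-1:]
for _ in range(8):
    q, k, v = hidden @ w_q, hidden @ w_k, hidden @ w_v
    nkq, nks = quantize(k); nvq, nvs = quantize(v)
    kq, ks = torch.cat([kq, nkq]), torch.cat([ks, nks])
    vq, vs = torch.cat([vq, nvq]), torch.cat([vs, nvs])
    hidden = attention(q, kq, ks, vq, vs)
print(kq.dtype, kq.shape)               # torch.int8 torch.Size([24, 64])
```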
3 Other
3.1 MHA Inference and KV Cache Quantization
This article provides a high-precision KV Cache quantization scheme. For more information about MHA inference and KV Cache quantization, please refer to:
- Flash Attention-Inference Performance Exploring
- Nvidia CUDA Core-LLM Decoding Attention Inference Optimization
- LLM Decoding Attention-KV Cache Int8 Quantization
- LLM Decoding Attention-KV Cache FP8 Quantization
3.2 KV Cache Quantization
A further direction is Int4 quantization optimization.