LLM Decoding Attention-KV Cache Int8 Quantization

How to quantize KV cache to int8?

Bruce-Lee-LY
7 min read · Nov 8, 2023

1 Background

When a traditional CNN handles image detection or classification, the inputs of consecutive frames are independent of each other. During LLM (Large Language Model) inference, however, the input of the N-th step consists of the original prompt plus the outputs of the previous N-1 steps, so the output of the N-th step depends on all of them. As a result, the input token sequence keeps growing during inference, and the computation and latency of each subsequent step gradually increase.

To address this, the KV cache was introduced in practice to avoid the gradually increasing computation during LLM inference. Put simply, the intermediate results K and V computed in the first N-1 steps are saved to the KV cache; in the N-th step, the K and V of the repeated part of the input no longer need to be recomputed and are read directly from the cache. In this case, except for the first step, the input sequence length of Q is always 1. Accordingly, LLM inference is divided into two stages: the Prefill stage (the first step) and the Decoding stage (all subsequent steps).
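
To make the two stages concrete, here is a minimal C++ sketch of the control flow of KV-cache-based generation. It is purely illustrative: the KVCache struct, the forward function, and the generate loop are hypothetical placeholders, not the API of flash_attention_inference.

// Minimal sketch of generation with a KV cache (hypothetical API, model forward pass omitted).
#include <vector>

struct KVCache {
    std::vector<float> k;  // cached keys:   one [num_heads, head_dim] block per token
    std::vector<float> v;  // cached values: same layout
};

// Placeholder for the real model forward pass; returns the next token id.
int forward(const std::vector<int>& new_tokens, KVCache& cache) {
    // A real implementation computes K/V for new_tokens, appends them to the cache,
    // and runs attention of the new Q against all cached K/V.
    return 0;
}

std::vector<int> generate(const std::vector<int>& prompt, int max_new_tokens) {
    KVCache cache;
    std::vector<int> output;
    // Prefill stage: the whole prompt is processed once and its K/V fill the cache.
    int token = forward(prompt, cache);
    output.push_back(token);
    // Decoding stage: every step feeds a single token, so Q always has sequence length 1,
    // while the K/V of all previous tokens are read back from the cache instead of recomputed.
    for (int i = 1; i < max_new_tokens; ++i) {
        token = forward({token}, cache);
        output.push_back(token);
    }
    return output;
}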

2 KV Cache Memory Usage

Taking llama 7B deployed on an RTX A6000 48 GB as an example, the FP16 model weights occupy about 14 GB of GPU memory and the intermediate activations of inference occupy about 2 GB. Almost all of the remaining 32 GB should be allocated to the KV cache to reduce inference latency and increase model throughput. Model throughput is typically negatively correlated with LLM inference cost: the higher the throughput, the lower the cost.

It follows that when deploying LLM inference, GPU memory is often more precious than compute. On the one hand, GPU memory determines whether a model can be deployed at all; on the other hand, it strongly affects the cost of LLM inference. Therefore, GPU cards with larger memory but lower compute power can be very cost-effective for LLM serving.

Generally, the KV cache is stored in FP16 and takes up a large share of GPU memory. If its storage format is changed to int8, the KV cache for the same sequence length needs only about half the memory, which effectively doubles the memory available to the KV cache. More requests can then keep their KV cache resident in GPU memory, greatly increasing model throughput and reducing inference cost.
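
As a rough back-of-the-envelope check (assuming the usual llama 7B configuration of 32 layers, 32 heads, and head dim 128; exact numbers depend on the model), the per-token KV cache footprint and the capacity gain from int8 can be estimated as follows:

// Back-of-the-envelope KV cache sizing for llama 7B (32 layers, 32 heads, head_dim 128).
#include <cstdio>

int main() {
    const long long layers = 32, heads = 32, head_dim = 128;
    const long long kv_per_token_fp16 = 2 /*K and V*/ * layers * heads * head_dim * 2 /*bytes per FP16*/;
    const long long kv_per_token_int8 = kv_per_token_fp16 / 2;  // 1 byte per element; per-head scales are negligible

    const long long budget = 32LL * 1024 * 1024 * 1024;  // ~32 GB left for the KV cache on a 48 GB card
    printf("FP16: %lld KB/token, ~%lld cached tokens\n", kv_per_token_fp16 / 1024, budget / kv_per_token_fp16);
    printf("int8: %lld KB/token, ~%lld cached tokens\n", kv_per_token_int8 / 1024, budget / kv_per_token_int8);
    return 0;
}

Under these assumptions the FP16 cache costs about 512 KB per token (roughly 65K cached tokens in 32 GB), while int8 halves that to about 256 KB per token, doubling the number of tokens that fit in the same budget.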

However, KV cache quantization is quite different from model (weight) quantization: the numerical distributions of the two differ significantly, and quantizing the KV cache can easily cause an accuracy drop in the model, so a scheme with higher quantization accuracy needs to be adopted.

3 KV Cache Int8 Quantization

Building on Nvidia CUDA Core-LLM Decoding Attention Inference Optimization, this article adds int8 quantization, while still supporting GQA (Group Query Attention)/MQA (Multi Query Attention) and ALiBi (Attention with Linear Biases) inference scenarios. The code is open source in flash_attention_inference.

3.1 Int8 Quantization

There are generally four int8 quantization schemes for the KV cache:

  • Per Tensor: one scale for the K (or V) of all tokens
  • Per Token: one scale for the K (or V) of each token
  • Per Head: one scale for the K (or V) of each head of each token
  • Per Group: group-wise scales within the K (or V) of each head of each token

The Per Tensor scheme offers the largest GPU memory benefit, but its accuracy may drop significantly; the Per Group scheme has the highest accuracy, but it stores many scales, so its memory benefit is small. To balance quantization accuracy and memory gain, the Per Head scheme is chosen here (a simplified scalar sketch of this scheme is given after the kernel below). The quantization code is as follows.

// Each thread group quantizes the K and V of one head.
#pragma unroll
for (size_t base_h_idx = warp_id * groups_per_warp; base_h_idx < params.h_k; base_h_idx += groups_per_block) {
    size_t h_idx = base_h_idx + group_id;
    half RK[thread_elem_nums];
    half RV[thread_elem_nums];

    int8_t RQK[thread_elem_nums];
    int8_t RQV[thread_elem_nums];

    half k_scale = 0.0;
    half v_scale = 0.0;

    if (h_idx < params.h_k) {
        // 128-bit vectorized loads of the FP16 K and V of this head into registers.
#pragma unroll
        for (size_t i = 0; i < thread_iters; ++i) {
            *(int4 *)(&RK[i * thread_copy_elem_nums]) =
                *(int4 *)(&params.k_ptr[binfo.k_offset(params.k_row_stride, h_idx, params.k_head_stride,
                                                       (i * group_size + group_lane_id) * thread_copy_elem_nums)]);
            *(int4 *)(&RV[i * thread_copy_elem_nums]) =
                *(int4 *)(&params.v_ptr[binfo.k_offset(params.v_row_stride, h_idx, params.v_head_stride,
                                                       (i * group_size + group_lane_id) * thread_copy_elem_nums)]);
        }

        // Per-thread absmax of K and V.
#pragma unroll
        for (size_t i = 0; i < thread_elem_nums; ++i) {
            k_scale = (k_scale > __habs(RK[i])) ? k_scale : __habs(RK[i]);
            v_scale = (v_scale > __habs(RV[i])) ? v_scale : __habs(RV[i]);
        }
    }

    // Reduce the absmax across the group to obtain the per-head absmax.
#pragma unroll
    for (size_t i = group_size / 2; i >= 1; i /= 2) {
        k_scale = fmaxf(k_scale, __shfl_xor_sync(shfl_mask, k_scale, i));
        v_scale = fmaxf(v_scale, __shfl_xor_sync(shfl_mask, v_scale, i));
    }

    if (h_idx < params.h_k) {
        // Turn the absmax into a quantization scale and clamp it to min_scale.
        k_scale /= max_int8;
        v_scale /= max_int8;
        k_scale = (k_scale > min_scale) ? k_scale : min_scale;
        v_scale = (v_scale > min_scale) ? v_scale : min_scale;

        // Quantize to int8 with round-to-nearest.
#pragma unroll
        for (size_t i = 0; i < thread_elem_nums; ++i) {
            RQK[i] = static_cast<int8_t>(__half2short_rn(RK[i] / k_scale));
            RQV[i] = static_cast<int8_t>(__half2short_rn(RV[i] / v_scale));
        }

        // 64-bit vectorized stores of the int8 K and V.
#pragma unroll
        for (size_t i = 0; i < thread_iters; ++i) {
            *(int2 *)(&params.k_int8_ptr[binfo.k_offset(params.k_row_stride, h_idx, params.k_head_stride,
                                                        (i * group_size + group_lane_id) * thread_copy_elem_nums)]) =
                *(int2 *)(&RQK[i * thread_copy_elem_nums]);
            *(int2 *)(&params.v_int8_ptr[binfo.k_offset(params.v_row_stride, h_idx, params.v_head_stride,
                                                        (i * group_size + group_lane_id) * thread_copy_elem_nums)]) =
                *(int2 *)(&RQV[i * thread_copy_elem_nums]);
        }

        // Store the per-head scales for dequantization in Decoding Attention.
        params.k_scale_ptr[binfo.k_scale_offset(params.h_k, h_idx)] = k_scale;
        params.v_scale_ptr[binfo.k_scale_offset(params.h_k, h_idx)] = v_scale;
    }
}
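
For reference, the per-head absmax scheme that the kernel above implements can be sketched on the host side in a few lines of plain C++. This is illustrative only: the function names are hypothetical, max_int8 is assumed to be 127, and the minimum-scale value is an arbitrary placeholder.

// Scalar reference of per-head absmax int8 quantization (illustrative, not the kernel itself).
#include <algorithm>
#include <cmath>
#include <cstdint>

void quantize_head(const float* x, int head_dim, int8_t* q, float& scale) {
    // One scale per (token, head): absmax over the head_dim elements of K (or V).
    float absmax = 0.0f;
    for (int i = 0; i < head_dim; ++i) absmax = std::max(absmax, std::fabs(x[i]));
    scale = std::max(absmax / 127.0f, 1e-6f);  // clamp to a small minimum, like min_scale in the kernel (value assumed)
    for (int i = 0; i < head_dim; ++i)
        q[i] = static_cast<int8_t>(std::lrintf(x[i] / scale));  // round to nearest, matching __half2short_rn
}

void dequantize_head(const int8_t* q, int head_dim, float scale, float* x) {
    for (int i = 0; i < head_dim; ++i) x[i] = static_cast<float>(q[i]) * scale;
}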

3.2 Int8 Dequantization and Decoding Attention

Here, int8 dequantization is integrated into Decoding Attention. The code is as follows.

// S = Q * K^T
half RQ[thread_elem_nums];

// 128-bit vectorized loads of the FP16 Q into registers.
#pragma unroll
for (size_t i = 0; i < thread_q_iters; ++i) {
    *(int4 *)(&RQ[i * thread_copy_q_elem_nums]) = *(int4 *)(&params.q_ptr[binfo.q_offset(
        params.q_row_stride, params.q_head_stride, (i * group_size + group_lane_id) * thread_copy_q_elem_nums)]);
}

extern __shared__ float S_smem[];
float S_max = -std::numeric_limits<float>::max();

// Each thread group computes one attention score: load int8 K, dequantize it with the
// per-head scale, and accumulate the dot product with Q.
#pragma unroll
for (size_t base_seqlen_k = warp_id * groups_per_warp; base_seqlen_k < binfo.actual_seqlen_k;
     base_seqlen_k += groups_per_block) {
    size_t seqlen_k = base_seqlen_k + group_id;
    int8_t RQK[thread_elem_nums];
    float RK_scale = 0.0;

    float tmp = 0.0;
    if (seqlen_k >= binfo.actual_seqlen_k) {
        memset(RQK, 0, sizeof(RQK));
    } else {
#pragma unroll
        for (size_t i = 0; i < thread_k_iters; ++i) {
            *(int4 *)(&RQK[i * thread_copy_k_elem_nums]) = *(int4 *)(&params.k_int8_ptr[binfo.k_offset(
                seqlen_k, params.k_row_stride, params.k_head_stride,
                (i * group_size + group_lane_id) * thread_copy_k_elem_nums)]);
        }

        RK_scale = __half2float(params.k_scale_ptr[binfo.k_scale_offset(seqlen_k, params.h_k)]);

        // Dequantize K on the fly: float(K_int8) * per-head scale.
#pragma unroll
        for (size_t i = 0; i < thread_elem_nums; ++i) {
            tmp += (__half2float(RQ[i]) * (static_cast<float>(RQK[i]) * RK_scale));
        }
    }

    // Reduce the partial dot products within the group.
#pragma unroll
    for (size_t i = group_size / 2; i >= 1; i /= 2) {
        tmp += __shfl_xor_sync(shfl_mask, tmp, i);
    }

    if (group_lane_id == 0 && seqlen_k < binfo.actual_seqlen_k) {
        tmp *= params.scale_softmax;

        // Add the ALiBi bias if enabled.
        if (IsAlibi) {
            tmp += (binfo.h_slope * (static_cast<int>(seqlen_k) - binfo.actual_seqlen_q - binfo.row_shift));
        }

        S_smem[seqlen_k] = tmp;
        S_max = fmaxf(tmp, S_max);
    }
}

// P = Softmax(S)
__shared__ float softmax_smem[warps_per_block];

// Block-wide max reduction of S: within each warp, then across warps via shared memory.
#pragma unroll
for (size_t i = warp_size / 2; i >= 1; i /= 2) {
    S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
}

if (lane_id == 0) {
    softmax_smem[warp_id] = S_max;
}

__syncthreads();

if (lane_id < warps_per_block) {
    S_max = softmax_smem[lane_id];
} else {
    S_max = -std::numeric_limits<float>::max();
}

#pragma unroll
for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
    S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
}

S_max = __shfl_sync(shfl_mask, S_max, 0);

// Compute exp(S - S_max) and its block-wide sum.
float exp_sum = 0.0;
#pragma unroll
for (size_t seqlen_k = threadIdx.x; seqlen_k < binfo.actual_seqlen_k; seqlen_k += threads_per_block) {
    S_smem[seqlen_k] -= S_max;
    S_smem[seqlen_k] = exp(S_smem[seqlen_k]);
    exp_sum += S_smem[seqlen_k];
}

#pragma unroll
for (size_t i = warp_size / 2; i >= 1; i /= 2) {
    exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
}

if (lane_id == 0) {
    softmax_smem[warp_id] = exp_sum;
}

__syncthreads();

if (lane_id < warps_per_block) {
    exp_sum = softmax_smem[lane_id];
}

#pragma unroll
for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
    exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
}
exp_sum = __shfl_sync(shfl_mask, exp_sum, 0);

// Normalize to get the softmax probabilities P.
#pragma unroll
for (size_t seqlen_k = threadIdx.x; seqlen_k < binfo.actual_seqlen_k; seqlen_k += threads_per_block) {
    S_smem[seqlen_k] /= exp_sum;
}

__syncthreads();

// O = P * V
int8_t RQV[thread_elem_nums];
float RO[thread_elem_nums];

memset(RO, 0, sizeof(RO));

// Each thread group loads int8 V, dequantizes it with the per-head scale,
// and accumulates the probability-weighted sum P * V.
#pragma unroll
for (size_t base_seqlen_k = warp_id * groups_per_warp; base_seqlen_k < binfo.actual_seqlen_k;
     base_seqlen_k += groups_per_block) {
    size_t seqlen_k = base_seqlen_k + group_id;

    if (seqlen_k < binfo.actual_seqlen_k) {
#pragma unroll
        for (size_t i = 0; i < thread_k_iters; ++i) {
            *(int4 *)(&RQV[i * thread_copy_k_elem_nums]) = *(int4 *)(&params.v_int8_ptr[binfo.k_offset(
                seqlen_k, params.v_row_stride, params.v_head_stride,
                (i * group_size + group_lane_id) * thread_copy_k_elem_nums)]);
        }

        float RV_scale = __half2float(params.v_scale_ptr[binfo.k_scale_offset(seqlen_k, params.h_k)]);

#pragma unroll
        for (size_t i = 0; i < thread_elem_nums; ++i) {
            RO[i] += (S_smem[seqlen_k] * (static_cast<float>(RQV[i]) * RV_scale));
        }
    }
}

// Reduce the partial outputs across groups within each warp.
#pragma unroll
for (size_t i = 0; i < thread_elem_nums; ++i) {
#pragma unroll
    for (size_t j = group_size; j <= warp_size / 2; j *= 2) {
        RO[i] += __shfl_xor_sync(shfl_mask, RO[i], j);
    }
}

__syncthreads();

// Accumulate the per-warp partial outputs in shared memory.
#pragma unroll
for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
    S_smem[i] = 0.0;
}

__syncthreads();

if (lane_id < group_size) {
#pragma unroll
    for (size_t i = 0; i < thread_k_iters; ++i) {
#pragma unroll
        for (size_t j = 0; j < thread_copy_k_elem_nums; ++j) {
            atomicAdd(S_smem + (i * group_size + lane_id) * thread_copy_k_elem_nums + j,
                      RO[i * thread_copy_k_elem_nums + j]);
        }
    }
}

__syncthreads();

// Write the FP16 output O.
#pragma unroll
for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
    params.o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2half(S_smem[i]);
}
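
The essence of the fusion is that dequantization reduces to a per-head scale multiplication folded into the dot products, so K and V never need to be materialized back to FP16 in global memory. A scalar view of the two fused steps, with hypothetical names and for illustration only:

// Scalar view of the fused dequantization (illustrative only).
#include <cstdint>

// S = Q * K^T with int8 K: the per-head scale is applied inside the dot product.
float score(const float* q, const int8_t* k_q, float k_scale, int head_dim, float softmax_scale) {
    float s = 0.0f;
    for (int i = 0; i < head_dim; ++i)
        s += q[i] * (static_cast<float>(k_q[i]) * k_scale);  // dequantize K on the fly
    return s * softmax_scale;                                // same role as params.scale_softmax
}

// O = P * V with int8 V: the probability-weighted sum dequantizes V the same way.
void accumulate(float p, const int8_t* v_q, float v_scale, int head_dim, float* o) {
    for (int i = 0; i < head_dim; ++i)
        o[i] += p * (static_cast<float>(v_q[i]) * v_scale);
}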

3.3 Latency

Because of the added int8 quantization kernel and the dequantization fused into Decoding Attention, the effective TFLOPS of Decoding Attention decreases and its latency increases. However, compared with the nearly doubled KV cache capacity in GPU memory, this increase in Decoding Attention latency is entirely acceptable.

(1) Test Conditions

  • MHA: O = Softmax(Q * K^T) * V
  • CUDA: 11.8
  • GPU: RTX3090
  • Flash Attention: v1.0.9
  • Flash Attention v2: v2.1.0
  • Cutlass: v3.1.0
  • Head Num: 32
  • Head Dim: 128

(2) Decoding Seq Len

  • Batch Size: 128
  • Seq Q: 1
  • Seq K: Seq Len

4 Other

FP8 and int4 quantization optimization.
