LLM Decoding Attention-KV Cache FP8 Quantization
How to quantize KV cache to FP8?
--
1 Background
Building on LLM Decoding Attention-KV Cache Int8 Quantization, this article continues the search for a better KV Cache quantization method. The goal remains the same: larger GPU memory savings and better quantization accuracy.
The int8 Per Head quantization scheme needs additional GPU memory to store the scales, which eats into the expected memory savings. Nvidia's FP8 data formats retain 2~3 mantissa bits and can be converted to and from FP16 directly, which not only simplifies the quantization and dequantization operations but also removes the need for extra scale storage.
2 FP8 Data Format
Nvidia introduced two FP8 data types:
- E5M2: Has 5 exponent bits, 2 mantissa bits and 1 sign bit
- E4M3: Has 4 exponent bits, 3 mantissa bits and 1 sign bit
E4M3 offers a smaller dynamic range but higher precision, while E5M2 provides a wider dynamic range at lower precision. Compared with FP16, FP8 halves the required data storage.
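As a quick illustration (a standalone sketch, not part of the kernel below), the following code uses the __nv_fp8_e4m3 and __nv_fp8_e5m2 types from <cuda_fp8.h> (shipped since CUDA 11.8) to round-trip a few FP16 values through FP8, making the precision trade-off between the two formats visible:
// Standalone sketch: round-trip FP16 values through both FP8 formats.
// Build with: nvcc fp8_roundtrip.cu -o fp8_roundtrip
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

int main() {
    const float values[] = {0.1234f, -0.5678f, 3.1415f};
    for (float v : values) {
        __half h = __float2half(v);
        // E4M3: 1 sign bit, 4 exponent bits, 3 mantissa bits
        __nv_fp8_e4m3 e4m3(h);
        // E5M2: 1 sign bit, 5 exponent bits, 2 mantissa bits
        __nv_fp8_e5m2 e5m2(h);
        printf("fp16: % .6f  e4m3: % .6f  e5m2: % .6f\n",
               __half2float(h), static_cast<float>(e4m3), static_cast<float>(e5m2));
    }
    return 0;
}
For in-range values E4M3 stays closer to the FP16 input, while E5M2 trades one mantissa bit for an FP16-like dynamic range.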
3 KV Cache FP8 Quantization
Based on Nvidia CUDA Core-LLM Decoding Attention Inference Optimization, this article continues with FP8 quantization optimization, while also supporting GQA (Group Query Attention)/MQA (Multi Query Attention) and ALiBi (Attention with Linear Biases) inference scenarios. The code is open-sourced in flash_attention_inference.
3.1 FP8 Quantization
Since FP8 quantization and dequantization are just direct conversions between the FP16 and FP8 data types, they can be integrated directly into Decoding Attention. The code is as follows.
// S = Q * K^T
half RQ[thread_elem_nums];
#pragma unroll
for (size_t i = 0; i < thread_iters; ++i) {
*(int4 *)(&RQ[i * thread_copy_elem_nums]) = *(int4 *)(&params.q_ptr[binfo.q_offset(
params.q_row_stride, params.q_head_stride, (i * group_size + group_lane_id) * thread_copy_elem_nums)]);
}
extern __shared__ float S_smem[];
float S_max = -std::numeric_limits<float>::max();
#pragma unroll
for (size_t base_seqlen_k = warp_id * groups_per_warp; base_seqlen_k < binfo.actual_seqlen_k;
base_seqlen_k += groups_per_block) {
size_t seqlen_k = base_seqlen_k + group_id;
half RK[thread_elem_nums];
fp8_t RQK[thread_elem_nums];
float tmp = 0.0;
if (seqlen_k >= binfo.actual_seqlen_k) {
memset(RQK, 0, sizeof(RQK));
} else {
#pragma unroll
for (size_t i = 0; i < thread_iters; ++i) {
*(int4 *)(&RK[i * thread_copy_elem_nums]) =
*(int4 *)(&params.k_ptr[binfo.k_offset(seqlen_k, params.k_row_stride, params.k_head_stride,
(i * group_size + group_lane_id) * thread_copy_elem_nums)]);
}
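// Quantize K from FP16 to FP8, then accumulate Q * K^T in FP32 using the dequantized FP8 values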
#pragma unroll
for (size_t i = 0; i < thread_elem_nums; ++i) {
RQK[i] = static_cast<fp8_t>(RK[i]);
tmp += (__half2float(RQ[i]) * static_cast<float>(RQK[i]));
}
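// Write the FP8-quantized K back to the FP8 KV cache (the same element count now occupies half the bytes, hence int2 instead of int4)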
#pragma unroll
for (size_t i = 0; i < thread_iters; ++i) {
*(int2 *)(&params.k_fp8_ptr[binfo.k_offset(seqlen_k, params.k_row_stride, params.k_head_stride,
(i * group_size + group_lane_id) * thread_copy_elem_nums)]) =
*(int2 *)(&RQK[i * thread_copy_elem_nums]);
}
}
#pragma unroll
for (size_t i = group_size / 2; i >= 1; i /= 2) {
tmp += __shfl_xor_sync(shfl_mask, tmp, i);
}
if (group_lane_id == 0 && seqlen_k < binfo.actual_seqlen_k) {
tmp *= params.scale_softmax;
if (IsAlibi) {
tmp += (binfo.h_slope * (static_cast<int>(seqlen_k) - binfo.actual_seqlen_q - binfo.row_shift));
}
S_smem[seqlen_k] = tmp;
S_max = fmaxf(tmp, S_max);
}
}
// P = Softmax(S)
__shared__ float softmax_smem[warps_per_block];
#pragma unroll
for (size_t i = warp_size / 2; i >= 1; i /= 2) {
S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
}
if (lane_id == 0) {
softmax_smem[warp_id] = S_max;
}
__syncthreads();
if (lane_id < warps_per_block) {
S_max = softmax_smem[lane_id];
} else {
S_max = -std::numeric_limits<float>::max();
}
#pragma unroll
for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
}
S_max = __shfl_sync(shfl_mask, S_max, 0);
float exp_sum = 0.0;
#pragma unroll
for (size_t seqlen_k = threadIdx.x; seqlen_k < binfo.actual_seqlen_k; seqlen_k += threads_per_block) {
S_smem[seqlen_k] -= S_max;
S_smem[seqlen_k] = exp(S_smem[seqlen_k]);
exp_sum += S_smem[seqlen_k];
}
#pragma unroll
for (size_t i = warp_size / 2; i >= 1; i /= 2) {
exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
}
if (lane_id == 0) {
softmax_smem[warp_id] = exp_sum;
}
__syncthreads();
if (lane_id < warps_per_block) {
exp_sum = softmax_smem[lane_id];
}
#pragma unroll
for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
}
exp_sum = __shfl_sync(shfl_mask, exp_sum, 0);
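// Normalize the scores by the block-wide softmax denominator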
#pragma unroll
for (size_t seqlen_k = threadIdx.x; seqlen_k < binfo.actual_seqlen_k; seqlen_k += threads_per_block) {
S_smem[seqlen_k] /= exp_sum;
}
__syncthreads();
// O = P * V
half RV[thread_elem_nums];
fp8_t RQV[thread_elem_nums];
float RO[thread_elem_nums];
memset(RO, 0, sizeof(RO));
#pragma unroll
for (size_t base_seqlen_k = warp_id * groups_per_warp; base_seqlen_k < binfo.actual_seqlen_k;
base_seqlen_k += groups_per_block) {
size_t seqlen_k = base_seqlen_k + group_id;
if (seqlen_k < binfo.actual_seqlen_k) {
#pragma unroll
for (size_t i = 0; i < thread_iters; ++i) {
*(int4 *)(&RV[i * thread_copy_elem_nums]) =
*(int4 *)(&params.v_ptr[binfo.k_offset(seqlen_k, params.v_row_stride, params.v_head_stride,
(i * group_size + group_lane_id) * thread_copy_elem_nums)]);
}
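// Quantize V from FP16 to FP8, then accumulate P * V in FP32 using the dequantized FP8 values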
#pragma unroll
for (size_t i = 0; i < thread_elem_nums; ++i) {
RQV[i] = static_cast<fp8_t>(RV[i]);
RO[i] += (S_smem[seqlen_k] * static_cast<float>(RQV[i]));
}
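// Write the FP8-quantized V back to the FP8 KV cache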
#pragma unroll
for (size_t i = 0; i < thread_iters; ++i) {
*(int2 *)(&params.v_fp8_ptr[binfo.k_offset(seqlen_k, params.v_row_stride, params.v_head_stride,
(i * group_size + group_lane_id) * thread_copy_elem_nums)]) =
*(int2 *)(&RQV[i * thread_copy_elem_nums]);
}
}
}
#pragma unroll
for (size_t i = 0; i < thread_elem_nums; ++i) {
#pragma unroll
for (size_t j = group_size; j <= warp_size / 2; j *= 2) {
RO[i] += __shfl_xor_sync(shfl_mask, RO[i], j);
}
}
__syncthreads();
#pragma unroll
for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
S_smem[i] = 0.0;
}
__syncthreads();
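// Each group's lanes accumulate their partial O results into shared memory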
if (lane_id < group_size) {
#pragma unroll
for (size_t i = 0; i < thread_iters; ++i) {
#pragma unroll
for (size_t j = 0; j < thread_copy_elem_nums; ++j) {
atomicAdd(S_smem + (i * group_size + lane_id) * thread_copy_elem_nums + j,
RO[i * thread_copy_elem_nums + j]);
}
}
}
__syncthreads();
#pragma unroll
for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
params.o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2half(S_smem[i]);
}
3.2 Test
(1) Test Conditions
- MHA: O = Softmax(Q * K^T) * V
- CUDA: 11.8
- GPU: RTX3090
- Flash Attention: v1.0.9
- Flash Attention v2: v2.1.0
- Cutlass: v3.1.0
- Head Num: 32
- Head Dim: 128
(2) Precision
With input data distributed in -1.0~1.0, the measured accuracy of FP8E5M2 quantization, FP8E4M3 quantization and int8 Per Head quantization ranks as follows:
FP8E4M3 > FP8E5M2 > int8 Per Head
E4M3 has one more mantissa bit than E5M2, so E4M3 is more accurate. Since int8 is a fixed-point format with only mantissa bits and no exponent bits, int8 Per Head quantization shows the worst accuracy.
(3) Latency
- Batch Size: 128
- Seq Q: 1
- Seq K: Seq Len
The measured latency of FP8E5M2 quantization, FP8E4M3 quantization and int8 Per Head quantization ranks as follows:
FP8E4M3 > int8 Per Head > FP8E5M2
The large latency gap between the two FP8 formats comes from the different cost of converting them to and from FP16 and FP32. E5M2 shares FP16's 5-bit exponent, so the conversion is little more than rounding away mantissa bits, whereas E4M3 also has to re-bias the exponent; E5M2 therefore converts more cheaply and has lower latency than E4M3.
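As a rough illustration of why FP16 to E5M2 is cheap (a standalone sketch, not from the article's kernel; the sample values are chosen so rounding is exact), the following code prints the raw bytes: the E5M2 byte is simply the high byte of the FP16 bit pattern, while E4M3 needs its exponent re-biased:
// Standalone sketch: compare FP16 bits with the raw FP8 storage bytes.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

int main() {
    const float values[] = {0.75f, -1.5f, 0.0625f};  // exactly representable in E5M2 and E4M3
    for (float v : values) {
        __half h = __float2half(v);
        uint16_t h_bits;
        std::memcpy(&h_bits, &h, sizeof(h_bits));
        __nv_fp8_e5m2 e5m2(h);
        __nv_fp8_e4m3 e4m3(h);
        // For these values the E5M2 byte matches the high byte of the FP16 bits exactly.
        printf("fp16 bits: 0x%04x  high byte: 0x%02x  e5m2: 0x%02x  e4m3: 0x%02x\n",
               (unsigned)h_bits, (unsigned)(h_bits >> 8), (unsigned)e5m2.__x, (unsigned)e4m3.__x);
    }
    return 0;
}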
4 Other
The next step is int4 quantization optimization.