Decoding Attention: LLM Inference Optimization
How to optimize MHA in the decoding stage of LLM inference?
1 Background
Thanks to FlashAttention's optimization of MHA (Multi Head Attention) computation, the performance of LLM training and inference has improved considerably. Building on Flash Attention v2, Tri Dao wrote Flash Decoding specifically to optimize inference. It mainly changes the block partitioning scheme: multiple thread blocks process the same attention head so that the KV Cache is loaded in parallel, and a final kernel rescales and merges the partial results. As a result, Flash Decoding greatly improves inference performance for small batch sizes with very long sequences.
However, when Flash Attention is applied to LLM inference, the following problems remain:
- Tensor Core utilization in the decoding stage is low. On the Ampere architecture (NVIDIA A100, RTX3090, etc.), for example, Tensor Core utilization is only 1/16.
- Only the FP16 and BF16 data types are supported; quantization formats such as FP8, Int8 and Int4 are not, which makes it harder to reduce inference cost.
- Only standard attention is supported; attention variants such as ALiBi are not.
Since seq_q is always 1 in the decoding stage of LLM inference, the attention computation of each MHA head reduces to two HGEMVs and one softmax, as illustrated below.
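Concretely, writing head_dim as d and the KV Cache length as n, the per-head computation is:
- S = Q * K^T: [1, d] x [d, n] -> [1, n] (HGEMV)
- P = Softmax(S): [1, n] -> [1, n] (softmax over a single row)
- O = P * V: [1, n] x [n, d] -> [1, d] (HGEMV)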
Note that on the RTX3090, with FP32 accumulation, Tensor Core FP16 throughput is only 2 times that of CUDA Core FP32. In this situation, a hand-written CUDA Core kernel for decoding MHA preserves accuracy while gaining in both latency and hardware utilization. On the other hand, a hand-written decoding MHA kernel also makes it easier to apply KV Cache quantization methods that improve the throughput of LLM inference and thus reduce its cost, and it can support more attention variants (such as ALiBi).
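A rough back-of-the-envelope check (assuming an FP16 KV Cache and ignoring Q and O traffic) makes the same point: per head, the two HGEMVs perform about 4 * n * d FLOPs while reading about 4 * n * d bytes of K and V, i.e. roughly 1 FLOP per byte. This is far below the roughly 38 FLOP/byte compute-to-bandwidth ratio of RTX3090 FP32 (35.6 TFLOPS / 936 GB/s), so decoding MHA is memory-bandwidth bound and giving up Tensor Cores costs essentially nothing.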
2 Result
This article mainly draws on ppl.llm.kernel.cuda and flash-attention, and uses CUDA Cores to optimize the performance of MHA in the decoding stage of LLM inference. Currently, in some inference decoding scenarios, Decoding Attention outperforms Flash Decoding (Flash Attention) and FlashInfer. Decoding Attention supports GQA (Group Query Attention) / MQA (Multi Query Attention) and ALiBi (Attention with Linear Biases) inference scenarios, as well as both the FP16 and BF16 data types. It provides a C++ API and a Python API, and the code is open source in decoding_attention.
2.1 Test Conditions
- MHA: O = Softmax(Q * K^T) * V
- CUDA: 12.1
- GPU: RTX3090
- Flash Attention: v2.6.3
- FlashInfer: v0.1.6
- Head Num: 32
- Head Dim: 128
- Data Type: FP16
2.2 Equipment Specifications
The RTX3090 device specifications relevant to this test are as follows (figures from the NVIDIA Ampere GA102 whitepaper).
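- CUDA Cores: 10496
- FP32: 35.6 TFLOPS
- FP16 Tensor Core (FP32 Accumulate): 71 TFLOPS
- Memory: 24 GB GDDR6X
- Memory Bandwidth: 936 GB/s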
2.3 RTX3090
(1) Seq Len
Decoding Attention performs better when the sequence length is below 1536, while Flash Decoding (Flash Attention) and FlashInfer perform better when the sequence length is above 1536.
- Batch Size: 1
- Seq Q: 1
- Seq K: Seq Len
(2) Batch Size
Regardless of batch size, Decoding Attention performs better than Flash Decoding (Flash Attention) and FlashInfer.
- Batch Size: Batch Size
- Seq Q: 1
- Seq K: 128
3 Decoding Attention
As mentioned above, seq_q is always 1 in the decoding stage of LLM inference, so the attention computation of each MHA head reduces to two HGEMVs and one softmax. Decoding Attention partitions thread blocks by batch and head. The kernel consists of three main parts: HGEMV (S = Q * K^T), softmax (P = Softmax(S)) and HGEMV (O = P * V). The source code is in decoding_attention; a plain CPU reference of the per-head computation is sketched below for comparison.
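The following is a minimal CPU reference of the per-head computation (FP32 throughout, single head, seq_q = 1), intended only to make the three steps explicit; the function name, memory layout and scale_softmax parameter are illustrative and are not the kernel's actual interface.

#include <cmath>
#include <vector>

// q: [head_dim], k/v: [seq_k, head_dim] (row-major), o: [head_dim]
void decode_mha_ref(const std::vector<float> &q, const std::vector<float> &k,
                    const std::vector<float> &v, std::vector<float> &o,
                    size_t seq_k, size_t head_dim, float scale_softmax) {
    std::vector<float> s(seq_k);
    float s_max = -INFINITY;

    // S = Q * K^T: one dot product per KV position (HGEMV)
    for (size_t j = 0; j < seq_k; ++j) {
        float acc = 0.0f;
        for (size_t d = 0; d < head_dim; ++d) {
            acc += q[d] * k[j * head_dim + d];
        }
        s[j] = acc * scale_softmax;
        s_max = std::fmax(s_max, s[j]);
    }

    // P = Softmax(S): subtract the row maximum for numerical stability
    float exp_sum = 0.0f;
    for (size_t j = 0; j < seq_k; ++j) {
        s[j] = std::exp(s[j] - s_max);
        exp_sum += s[j];
    }

    // O = P * V: probability-weighted sum of V rows (HGEMV)
    o.assign(head_dim, 0.0f);
    for (size_t j = 0; j < seq_k; ++j) {
        for (size_t d = 0; d < head_dim; ++d) {
            o[d] += (s[j] / exp_sum) * v[j * head_dim + d];
        }
    }
}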
3.1 HGEMV(S = Q * K^T)
Each group of threads processes one seqlen_k at a time (and several across the outer loop); the per-position computation is the same as an HGEMV, refer to cuda_hgemv. Each thread first loads its slice of Q into registers with vectorized int4 copies, then, for every assigned seq_k, loads the matching slice of K, accumulates a partial dot product in FP32 and reduces it across the group with warp shuffles.
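For concreteness, the constants used by the fragment can be read as follows; the values here are illustrative assumptions consistent with the test setup (head_dim = 128, FP16), not values taken from the repository.

#include <cuda_fp16.h>

// Assumed example configuration: 8 threads per group cover one 128-element K row,
// each thread holding 16 halves loaded with two 16-byte (int4) copies.
constexpr size_t head_dim              = 128;
constexpr size_t threads_per_group     = 8;
constexpr size_t thread_copy_elem_nums = 16 / sizeof(half);                         // 8 halves per int4
constexpr size_t thread_elem_nums      = head_dim / threads_per_group;              // 16 halves per thread
constexpr size_t thread_iters          = thread_elem_nums / thread_copy_elem_nums;  // 2 vectorized loads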
// S = Q * K^T
T RQ[thread_elem_nums];
#pragma unroll
for (size_t i = 0; i < thread_iters; ++i) {
    // Vectorized (int4) load of this thread's slice of Q into registers.
    *(int4 *)(&RQ[i * thread_copy_elem_nums]) =
        *(int4 *)(&q_ptr[binfo.q_offset(params.q_row_stride, params.q_head_stride,
                                        (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]);
}

extern __shared__ float S_smem[];
float S_max = -std::numeric_limits<float>::max();

#pragma unroll
for (size_t base_seq_k = warp_id * groups_per_warp; base_seq_k < binfo.actual_seq_k;
     base_seq_k += groups_per_block) {
    size_t seq_k = base_seq_k + group_id;
    T RK[thread_elem_nums];

    float acc = 0.0;
    if (seq_k < binfo.actual_seq_k) {
        // Vectorized load of this thread's slice of the K row for the current seq_k.
        #pragma unroll
        for (size_t i = 0; i < thread_iters; ++i) {
            *(int4 *)(&RK[i * thread_copy_elem_nums]) =
                *(int4 *)(&k_ptr[binfo.k_offset(seq_k, params.k_row_stride, params.k_head_stride,
                                                (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]);
        }

        // Partial dot product, accumulated in FP32.
        #pragma unroll
        for (size_t i = 0; i < thread_elem_nums; ++i) {
            if constexpr (std::is_same_v<T, half>) {
                acc += (__half2float(RQ[i]) * __half2float(RK[i]));
            } else {
                acc += (__bfloat162float(RQ[i]) * __bfloat162float(RK[i]));
            }
        }
    }

    // Reduce the partial dot products within the group via warp shuffles.
    #pragma unroll
    for (size_t i = threads_per_group / 2; i >= 1; i /= 2) {
        acc += __shfl_xor_sync(shfl_mask, acc, i);
    }

    if (group_lane_id == 0 && seq_k < binfo.actual_seq_k) {
        acc *= params.scale_softmax;
        // Optional ALiBi bias.
        if (IsAlibi) {
            acc += (binfo.h_slope * (static_cast<int>(seq_k) - binfo.actual_seq_q - binfo.row_shift));
        }

        S_smem[seq_k] = acc;
        // Track the running maximum of the scores written by this thread.
        S_max = fmaxf(acc, S_max);
    }
}
3.2 Softmax(P = Softmax(S))
First, the per-group maxima of S computed in the previous step are reduced to the maximum of the whole row, and then the softmax of each seqlen_k position is computed in place on S_smem. Subtracting the row maximum before exponentiation keeps the exponentials numerically stable.
// P = Softmax(S)
__shared__ float softmax_smem[warps_per_block];

// Reduce the per-thread maxima to a warp-level maximum.
#pragma unroll
for (size_t i = warp_size / 2; i >= 1; i /= 2) {
    S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
}

if (lane_id == 0) {
    softmax_smem[warp_id] = S_max;
}

__syncthreads();

// Reduce the warp maxima to the block-wide (row) maximum and broadcast it.
if (lane_id < warps_per_block) {
    S_max = softmax_smem[lane_id];
} else {
    S_max = -std::numeric_limits<float>::max();
}

#pragma unroll
for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
    S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
}

S_max = __shfl_sync(shfl_mask, S_max, 0);

// Exponentiate with the row maximum subtracted and accumulate per-thread partial sums.
float exp_sum = 0.0;
#pragma unroll
for (size_t seq_k = threadIdx.x; seq_k < binfo.actual_seq_k; seq_k += threads_per_block) {
    S_smem[seq_k] -= S_max;
    S_smem[seq_k] = exp(S_smem[seq_k]);
    exp_sum += S_smem[seq_k];
}

// Reduce the partial sums to the block-wide denominator, same pattern as the maximum.
#pragma unroll
for (size_t i = warp_size / 2; i >= 1; i /= 2) {
    exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
}

if (lane_id == 0) {
    softmax_smem[warp_id] = exp_sum;
}

__syncthreads();

if (lane_id < warps_per_block) {
    exp_sum = softmax_smem[lane_id];
}

#pragma unroll
for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
    exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
}

exp_sum = __shfl_sync(shfl_mask, exp_sum, 0);

// Normalize S in place to obtain P.
#pragma unroll
for (size_t seq_k = threadIdx.x; seq_k < binfo.actual_seq_k; seq_k += threads_per_block) {
    S_smem[seq_k] /= exp_sum;
}

__syncthreads();
3.3 HGEMV(O = P * V)
Because of how the V matrix is stored, each group scales one or more rows of V by the corresponding probability and accumulates them in registers; a final Reduce Sum then produces the output row.
// O = P * V
T RV[thread_elem_nums];
float RO[thread_elem_nums];

memset(RO, 0, sizeof(RO));

#pragma unroll
for (size_t base_seq_k = warp_id * groups_per_warp; base_seq_k < binfo.actual_seq_k;
     base_seq_k += groups_per_block) {
    size_t seq_k = base_seq_k + group_id;

    if (seq_k < binfo.actual_seq_k) {
        // Vectorized load of this thread's slice of the V row for the current seq_k.
        #pragma unroll
        for (size_t i = 0; i < thread_iters; ++i) {
            *(int4 *)(&RV[i * thread_copy_elem_nums]) =
                *(int4 *)(&v_ptr[binfo.k_offset(seq_k, params.v_row_stride, params.v_head_stride,
                                                (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]);
        }

        // Accumulate P[seq_k] * V[seq_k, :] into the per-thread output registers.
        #pragma unroll
        for (size_t i = 0; i < thread_elem_nums; ++i) {
            if constexpr (std::is_same_v<T, half>) {
                RO[i] += (S_smem[seq_k] * __half2float(RV[i]));
            } else {
                RO[i] += (S_smem[seq_k] * __bfloat162float(RV[i]));
            }
        }
    }
}

// Sum the partial outputs across the groups of each warp.
#pragma unroll
for (size_t i = 0; i < thread_elem_nums; ++i) {
    #pragma unroll
    for (size_t j = threads_per_group; j <= warp_size / 2; j *= 2) {
        RO[i] += __shfl_xor_sync(shfl_mask, RO[i], j);
    }
}

__syncthreads();

// Reuse S_smem as an FP32 accumulator for the output row.
#pragma unroll
for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
    S_smem[i] = 0.0;
}

__syncthreads();

// Accumulate the warp-level results across warps.
if (lane_id < threads_per_group) {
    #pragma unroll
    for (size_t i = 0; i < thread_iters; ++i) {
        #pragma unroll
        for (size_t j = 0; j < thread_copy_elem_nums; ++j) {
            atomicAdd(S_smem + (i * threads_per_group + lane_id) * thread_copy_elem_nums + j,
                      RO[i * thread_copy_elem_nums + j]);
        }
    }
}

__syncthreads();

// Convert back to FP16/BF16 and write the output row.
#pragma unroll
for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
    if constexpr (std::is_same_v<T, half>) {
        o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2half(S_smem[i]);
    } else {
        o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2bfloat16(S_smem[i]);
    }
}
4 Other
4.1 Next Plan
- Kernel Optimization
- KV Cache Quantization: FP8, Int8, Int4