Flash Attention: Exploring Inference Performance
Differences in inference performance between Flash Attention v1 and v2.
1 Background
Since the advent of the transformer, the attention mechanism has played a central role in LLMs (Large Language Models). However, because softmax requires a row-wise normalization over the full attention matrix, the computation of MHA (Multi-Head Attention) has long been severely memory bound. Exploiting the mathematical properties of softmax, Flash Attention fuses the entire MHA computation into a single operator and trades extra computation and fast SRAM accesses for fewer slow HBM accesses, which relieves the memory-bound pressure and greatly speeds up MHA.
Using the C++ interfaces of Flash Attention and Flash Attention v2, this article explores how the differences between their computation schemes affect MHA inference performance.
2 MHA
2.1 Calculation Process
The self-attention computation in MHA can be divided into the following three steps.
O = Softmax(Q * K^T) * V
Step 1: S = Q * K^T
Step 2: P = Softmax(S)
Step 3: O = P * V
The dimensions of Q, K, V, and O in MHA are listed below (see the layout sketch after the list). In Step 1, for each head of each sequence in the batch, a (seq_q * dim) * (dim * seq_k) matrix multiplication produces S. In Step 2, softmax over each row of S produces P. In Step 3, for each head of each sequence, a (seq_q * seq_k) * (seq_k * dim) matrix multiplication produces O.
- Q: total_q * hq * dim
- K: total_k * hk * dim
- V: total_k * hk * dim
- O: total_q * hq * dim
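To make the packed variable-length layout concrete, here is a minimal sketch of how the cumulative offset arrays cu_seq_q/cu_seq_k used by the reference code below can be built from per-batch sequence lengths. The helper name build_cu_seq and the example lengths are hypothetical, not part of the repository.
#include <cstddef>
#include <numeric>
#include <vector>

// Build cumulative sequence offsets: cu_seq[b] is the packed row where batch b starts,
// and cu_seq[batch] equals the total number of packed rows (total_q or total_k).
std::vector<int> build_cu_seq(const std::vector<int> &seq_lens) {
    std::vector<int> cu_seq(seq_lens.size() + 1, 0);
    std::partial_sum(seq_lens.begin(), seq_lens.end(), cu_seq.begin() + 1);
    return cu_seq;
}

int main() {
    // Hypothetical example: 3 requests of different lengths packed along the first dimension.
    std::vector<int> seq_q_lens = {5, 1, 3};   // prefill, decoding, prefill
    std::vector<int> seq_k_lens = {5, 17, 3};  // the decoding request has a longer KV history

    std::vector<int> cu_seq_q = build_cu_seq(seq_q_lens);  // {0, 5, 6, 9}   -> total_q = 9
    std::vector<int> cu_seq_k = build_cu_seq(seq_k_lens);  // {0, 5, 22, 25} -> total_k = 25

    // Q is then stored as a (total_q, hq, dim) tensor, K/V as (total_k, hk, dim) tensors,
    // and batch b occupies rows [cu_seq_q[b], cu_seq_q[b + 1]) of Q.
    return 0;
}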
The CPU reference implementation of MHA is shown below. To preserve accuracy, all intermediate results are kept in float. The source code is available in flash_attention_inference.
void mha_cpu(Tensor<half> *Q, Tensor<half> *K, Tensor<half> *V, Tensor<half> *O, Tensor<int> *cu_seq_q,
             Tensor<int> *cu_seq_k, size_t max_seq_k, bool is_causal, bool is_alibi) {
    size_t total_q = Q->getShape()[0];
    size_t head_q = Q->getShape()[1];
    size_t dim = Q->getShape()[2];
    size_t head_k = K->getShape()[1];
    size_t batch = cu_seq_q->getShape()[0] - 1;

    // GQA/MQA: several query heads share one K/V head
    FAI_CHECK_EQ(head_q % head_k, 0);
    const size_t head_ratio = head_q / head_k;

    half *q_ptr = Q->getHostPtr();
    half *k_ptr = K->getHostPtr();
    half *v_ptr = V->getHostPtr();
    half *o_ptr = O->getHostPtr();

    // cu_seq_q / cu_seq_k hold cumulative sequence offsets per batch
    int *cu_seq_q_ptr = cu_seq_q->getHostPtr();
    int *cu_seq_k_ptr = cu_seq_k->getHostPtr();

    // S = Q * K^T
    Tensor<float> *S = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor S");
    FAI_CHECK(S);
    float *s_ptr = S->getHostPtr();
    for (size_t b = 0; b < batch; ++b) {
        size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
        size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
        size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
        size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
        for (size_t h = 0; h < head_q; ++h) {
            size_t h_k = h / head_ratio;
            for (size_t sq = 0; sq < seq_q; ++sq) {
                for (size_t sk = 0; sk < seq_k; ++sk) {
                    float acc = 0.0;
                    for (size_t d = 0; d < dim; ++d) {
                        acc += __half2float(q_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d]) *
                               __half2float(k_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]);
                    }
                    s_ptr[sum_seq_q * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = acc;
                }
            }
        }
    }

    // P = Softmax(S)
    Tensor<float> *P = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor P");
    FAI_CHECK(P);
    float *p_ptr = P->getHostPtr();
    float scale = 1.0 / std::sqrt(dim);
    for (size_t b = 0; b < batch; ++b) {
        size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
        size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
        size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
        size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
        size_t row_shift = seq_k - seq_q;
        for (size_t h = 0; h < head_q; ++h) {
            // ALiBi slope for this head
            float h_slope = is_alibi ? (1.0 / exp2(8.0 * (h + 1) / head_q)) : 0.0;
            for (size_t sq = 0; sq < seq_q; ++sq) {
                // With causal masking, only the first sq + row_shift + 1 columns are attended
                size_t col_limit = is_causal ? std::min(seq_k, sq + row_shift + 1) : seq_k;
                // Max(S)
                std::vector<float> tmp_s(seq_k, 0.0);
                float max_s = -std::numeric_limits<float>::max();
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    tmp_s[sk] = s_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] * scale;
                    if (is_alibi && sk < sq + row_shift) {
                        tmp_s[sk] +=
                            (h_slope * (static_cast<int>(sk) - static_cast<int>(sq) - static_cast<int>(row_shift)));
                    }
                    max_s = std::max(max_s, tmp_s[sk]);
                }
                // Sum(S)
                float sum_s = 0.0;
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    tmp_s[sk] = std::exp(tmp_s[sk] - max_s);
                    sum_s += tmp_s[sk];
                }
                // Softmax(S)
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = tmp_s[sk] / sum_s;
                }
                // Causal(S): masked columns contribute zero probability
                if (is_causal) {
                    for (size_t sk = col_limit; sk < seq_k; ++sk) {
                        p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] = 0.0;
                    }
                }
            }
        }
    }

    // O = P * V
    for (size_t b = 0; b < batch; ++b) {
        size_t sum_seq_q = static_cast<size_t>(cu_seq_q_ptr[b]);
        size_t sum_seq_k = static_cast<size_t>(cu_seq_k_ptr[b]);
        size_t seq_q = static_cast<size_t>(cu_seq_q_ptr[b + 1]) - sum_seq_q;
        size_t seq_k = static_cast<size_t>(cu_seq_k_ptr[b + 1]) - sum_seq_k;
        for (size_t h = 0; h < head_q; ++h) {
            size_t h_k = h / head_ratio;
            for (size_t sq = 0; sq < seq_q; ++sq) {
                for (size_t d = 0; d < dim; ++d) {
                    float acc = 0.0;
                    for (size_t sk = 0; sk < seq_k; ++sk) {
                        acc += p_ptr[(sum_seq_q + sq) * (head_q * seq_k) + h * seq_k + sk] *
                               __half2float(v_ptr[(sum_seq_k + sk) * (head_k * dim) + h_k * dim + d]);
                    }
                    o_ptr[(sum_seq_q + sq) * (head_q * dim) + h * dim + d] = __float2half(acc);
                }
            }
        }
    }

    if (S) {
        delete S;
        S = nullptr;
    }
    if (P) {
        delete P;
        P = nullptr;
    }
}
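As a usage illustration, the sketch below shows how the reference function above might be invoked. It assumes the repository's Tensor<T> class (constructed from a shape and a name, with getHostPtr() returning a host buffer) is available through its headers, and the chosen shapes are hypothetical; treat it as an outline rather than the repository's actual test harness.
#include <algorithm>
#include <vector>

// Hypothetical driver: one prefill request (seq_q = seq_k = 128) and one decoding
// request (seq_q = 1, seq_k = 128), packed along the first dimension.
std::vector<int> cu_q = {0, 128, 129};
std::vector<int> cu_k = {0, 128, 256};
size_t total_q = cu_q.back(), total_k = cu_k.back(), max_seq_k = 128;
size_t head_q = 32, head_k = 32, dim = 128;

Tensor<half> *Q = new Tensor<half>({total_q, head_q, dim}, "Tensor Q");
Tensor<half> *K = new Tensor<half>({total_k, head_k, dim}, "Tensor K");
Tensor<half> *V = new Tensor<half>({total_k, head_k, dim}, "Tensor V");
Tensor<half> *O = new Tensor<half>({total_q, head_q, dim}, "Tensor O");
Tensor<int> *cu_seq_q = new Tensor<int>({cu_q.size()}, "cu_seq_q");
Tensor<int> *cu_seq_k = new Tensor<int>({cu_k.size()}, "cu_seq_k");
std::copy(cu_q.begin(), cu_q.end(), cu_seq_q->getHostPtr());
std::copy(cu_k.begin(), cu_k.end(), cu_seq_k->getHostPtr());
// ... fill Q/K/V host buffers with half-precision inputs ...

mha_cpu(Q, K, V, O, cu_seq_q, cu_seq_k, max_seq_k, /*is_causal=*/true, /*is_alibi=*/false);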
2.2 Flash Attention
This article focuses only on the MHA calculation process of Flash Attention; for other details, please refer to the paper and source code.
Flash Attention divides the MHA computation into blocks along batch, head, and split-seq_q. When computing Q * K^T, the work within a block is partitioned across warps along the seq_k dimension of K^T, so each warp produces only a partial result of a row of the S block. Therefore, before computing the softmax of a block, the warps must first be synchronized. Likewise, when finally computing P * V, the split-K method is used: the partial results of each warp must be reduced (summed) before the block result of O is obtained, and the warps again need to be synchronized before the reduction.
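To illustrate the block-wise computation that both kernels build on, here is a minimal single-threaded C++ sketch of the online-softmax accumulation over seq_k tiles for one query row. It is a didactic simplification (scalar code, scores assumed already scaled, hypothetical tile size), not the kernel itself.
#include <algorithm>
#include <cmath>
#include <vector>

// Online softmax * V accumulation for a single query row, processing K/V in tiles of
// block_k columns. The running statistics (max m, normalizer l) and the accumulator are
// rescaled as each new tile arrives, so the full S row never has to be materialized.
std::vector<float> attention_row(const std::vector<float> &s_row,           // s_row[j] = q . k_j
                                 const std::vector<std::vector<float>> &v,  // v[j] is the j-th value row
                                 size_t block_k) {
    size_t seq_k = s_row.size(), dim = v[0].size();
    float m = -INFINITY, l = 0.0f;
    std::vector<float> acc(dim, 0.0f);
    for (size_t start = 0; start < seq_k; start += block_k) {
        size_t end = std::min(seq_k, start + block_k);
        float m_new = m;
        for (size_t j = start; j < end; ++j) m_new = std::max(m_new, s_row[j]);
        float rescale = std::exp(m - m_new);  // correct previously accumulated tiles to the new max
        l *= rescale;
        for (float &a : acc) a *= rescale;
        for (size_t j = start; j < end; ++j) {
            float p = std::exp(s_row[j] - m_new);
            l += p;
            for (size_t d = 0; d < dim; ++d) acc[d] += p * v[j][d];
        }
        m = m_new;
    }
    for (float &a : acc) a /= l;  // final normalization
    return acc;
}
In the v1 kernel, the seq_k tiles of one row are spread across different warps, which is why their partial statistics have to be combined with a synchronized reduction before a row of O is complete.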
2.3 Flash Attention v2
Flash Attention v2 also divides the MHA computation into blocks along batch, head, and split-seq_q. However, when computing Q * K^T, the work within a block is partitioned across warps along the seq_q dimension of Q, so each warp produces all block results for its own rows of S. Therefore, no warp synchronization is needed when computing the softmax of a block. Likewise, when finally computing P * V, each warp can directly compute its own part of O, without any cross-warp reduction or extra synchronization.
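The scheduling difference can be illustrated with a small host-side sketch that prints which warp owns which tile of the S block under each scheme. The tile and warp counts are hypothetical; the real kernels use Cutlass tile shapes.
#include <cstdio>

int main() {
    const int warps = 4, row_tiles = 4, col_tiles = 4;  // hypothetical block decomposition

    // Flash Attention v1: warps split the seq_k (column) dimension, so every row of the
    // S block is spread over all warps and needs a cross-warp reduction/synchronization.
    printf("v1 (split seq_k):\n");
    for (int r = 0; r < row_tiles; ++r)
        for (int c = 0; c < col_tiles; ++c)
            printf("  S tile (row %d, col %d) -> warp %d\n", r, c, c % warps);

    // Flash Attention v2: warps split the seq_q (row) dimension, so each warp owns whole
    // rows of the S block and can finish softmax and P * V for them independently.
    printf("v2 (split seq_q):\n");
    for (int r = 0; r < row_tiles; ++r)
        for (int c = 0; c < col_tiles; ++c)
            printf("  S tile (row %d, col %d) -> warp %d\n", r, c, r % warps);

    return 0;
}
With seq_q = 1 (the decoding case in section 3.3), the v2 scheme leaves most warps of a block idle, while the v1 scheme still spreads the single row's seq_k tiles across all warps, which matches the measured results below.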
3 Inference Performance
3.1 Test Conditions
The code is open sourced in flash_attention_inference, and the kernels come from flash-attention. The backward, dropout, bf16, and torch-dependent code that is irrelevant to inference has been removed, so it can be easily integrated into LLM inference scenarios. On top of flash attention, the code also fully supports GQA (Group Query Attention)/MQA (Multi Query Attention) inference scenarios, prefill/decoding hybrid inference scenarios, and ALiBi (Attention with Linear Biases) inference scenarios.
- MHA: O = Softmax(Q * K^T) * V
- CUDA: 11.8
- GPU: RTX3090
- Flash Attention: v1.0.9
- Flash Attention v2: v2.1.0
- Cutlass: v3.1.0
- Head Num: 32
- Head Dim: 128
3.2 Prefill Inference Performance
(1) Seq Len
For short sequences, the performance of the two is equivalent; for long sequences, Flash Attention v2 performs better, improving by about 60%. Flash Attention v2's advantage on long sequences mainly comes from eliminating the repeated warp synchronizations during block computation.
- Batch Size: 1
- Seq Q: Seq Len
- Seq K: Seq Len
(2) Batch Size
When the batch size is small, Flash Attention v2 performs better; when the batch size is large, the performance of the two is equivalent.
- Batch Size: Batch Size
- Seq Q: 128
- Seq K: 128
3.3 Decoding Inference Performance
(1) Seq Len
When the sequence is short, the performance of the two is equivalent; when the sequence is long, Flash Attention performs better, improving by about 100%. Flash Attention's advantage on long sequences mainly comes from partitioning warp work along the seq_k dimension, which increases the parallelism of the computation when seq_q is only 1.
- Batch Size: 1
- Seq Q: 1
- Seq K: Seq Len
(2) Batch Size
Regardless of batch size, Flash Attention performs better.
- Batch Size: Batch Size
- Seq Q: 1
- Seq K: 128
3.4 Hybrid Inference Performance
No matter how the ratio of prefill to decoding changes, the performance of Flash Attention and Flash Attention v2 is relatively close.
- Batch Size: 100
- Seq Q: 128 (Prefill) + 1 (Decoding)
- Seq K: 128
4 Other
4.1 GQA/MQA Inference Scenarios
All GQA/MQA inference scenarios are supported, and the code is updated in flash_attention_inference.
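The GQA/MQA support boils down to the query-head to KV-head mapping already used in mha_cpu above. The following minimal sketch shows the mapping on its own; the head counts are a hypothetical example.
#include <cstdio>

int main() {
    // GQA example: 32 query heads share 8 K/V heads (head_ratio = 4); MQA is the
    // special case head_k = 1, where every query head reads the same K/V head.
    const int head_q = 32, head_k = 8;
    const int head_ratio = head_q / head_k;  // head_q % head_k must be 0
    for (int h = 0; h < head_q; ++h) {
        int h_k = h / head_ratio;  // same mapping as in mha_cpu
        printf("query head %2d -> kv head %d\n", h, h_k);
    }
    return 0;
}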
4.2 Hybrid Inference Scenarios
All prefill and decoding hybrid inference scenarios are supported. The code is updated in flash_attention_inference, and the performance is shown in 3.4.
4.3 ALiBi Inference Scenarios
All ALiBi inference scenarios are supported, and the code is updated in flash_attention_inference.
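The ALiBi handling follows the same formula as the CPU reference above: a per-head slope of 1 / 2^(8 * (h + 1) / head_q) multiplies the negative distance between the key position and the query position. A minimal sketch with hypothetical sizes:
#include <cmath>
#include <cstdio>

// ALiBi bias added to S[sq][sk] before softmax, matching the mha_cpu reference:
// slope(h) * (sk - sq - row_shift), applied only for sk < sq + row_shift, where
// row_shift = seq_k - seq_q aligns the query positions with the end of the key sequence.
float alibi_bias(int h, int head_q, int sq, int sk, int row_shift) {
    if (sk >= sq + row_shift) return 0.0f;
    float h_slope = 1.0f / exp2f(8.0f * (h + 1) / head_q);
    return h_slope * (sk - sq - row_shift);
}

int main() {
    const int head_q = 32, seq_q = 4, seq_k = 8, row_shift = seq_k - seq_q;
    for (int sk = 0; sk < seq_k; ++sk)
        printf("head 0, last query row, key %d: bias = %f\n", sk,
               alibi_bias(0, head_q, seq_q - 1, sk, row_shift));
    return 0;
}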