Nvidia CUDA Core-CUDA HGEMV Optimization

How to extremely optimize CUDA HGEMV with CUDA Core?

Bruce-Lee-LY
Oct 27, 2023

1 Background

GEMV (General Matrix Vector Multiplication) is a special case of GEMM (General Matrix Multiplication) in which one operand is a vector, and optimizing it on Nvidia GPUs differs from optimizing GEMM. cuBLAS provides APIs such as cublasSgemv and cublasDgemv that compute FP32 and FP64 GEMV directly.

In inference optimization of deep learning models, especially LLMs (Large Language Models), optimizing HGEMV (Half-precision General Matrix Vector Multiplication) is becoming increasingly important. However, cuBLAS does not provide an API that computes HGEMV directly. It can only be computed indirectly through cublasGemmEx and related APIs, which call Tensor Cores.
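For reference, one plausible way to compute HGEMV with the layout used in this article (B column-major) through cublasGemmEx is sketched below. The article does not show the exact cuBLAS call behind its baseline, so the helper name hgemvCublas and the parameter choices are assumptions. Because cuBLAS is column-major, the 1 x N output is treated as an N x 1 column C = B^T A, with FP16 storage and FP32 compute; whether a Tensor Core kernel is picked is left to cuBLAS.

```cuda
#include <cassert>
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical helper: HGEMV C(1 x N) = A(1 x K) * B(K x N, column-major) via cublasGemmEx.
// cuBLAS is column-major, so the result is computed as the N x 1 column C = B^T * A:
//   op(B) = B^T is N x K (B is stored as K x N with leading dimension K),
//   A is a K x 1 column (ldb = K) and C is an N x 1 column (ldc = N).
// Storage is FP16 and compute is FP32 (CUBLAS_COMPUTE_32F), so alpha and beta are floats.
cublasStatus_t hgemvCublas(cublasHandle_t handle, const half *A, const half *B, half *C, int N, int K) {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                        /*m=*/N, /*n=*/1, /*k=*/K,
                        &alpha,
                        B, CUDA_R_16F, /*lda=*/K,
                        A, CUDA_R_16F, /*ldb=*/K,
                        &beta,
                        C, CUDA_R_16F, /*ldc=*/N,
                        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```

Compile with nvcc and link against cuBLAS (-lcublas).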

Tensor Cores consume and produce matrix tiles, so using them for HGEMV wastes hardware resources. For example, an MMA16816 (m16n8k16) instruction computes 16 rows of results at a time, while HGEMV has only 1 valid row, so Tensor Core utilization is only 1/16. On Nvidia GPUs, Tensor Core FP16 throughput is normally 2 to 16 times CUDA Core FP32 throughput. On the RTX3090, with FP32 multiply-accumulate, Tensor Core FP16 throughput is only 2 times CUDA Core FP32 throughput. In this case, computing HGEMV on CUDA Cores can bring gains in latency and hardware utilization while preserving accuracy.
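The utilization argument can be written out with the ratios quoted above. Let R_CC be CUDA Core FP32 throughput and r the ratio of Tensor Core FP16 throughput to CUDA Core FP32 throughput (r = 2 on the RTX3090 with FP32 accumulate, 2 to 16 in general). If only 1 of the 16 rows of an MMA16816 tile is useful, the Tensor Core throughput HGEMV can actually use is:

```latex
R_{\mathrm{TC}}^{\mathrm{eff}} = \frac{1}{16}\, r\, R_{\mathrm{CC}} \le R_{\mathrm{CC}} \quad (2 \le r \le 16),
\qquad
r = 2:\; R_{\mathrm{TC}}^{\mathrm{eff}} = \frac{2}{16}\, R_{\mathrm{CC}} = \frac{1}{8}\, R_{\mathrm{CC}}
```

This is a peak-throughput estimate only; it ignores memory bandwidth, which also limits HGEMV.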

2 Result

This article uses handwritten CUDA HGEMV kernels that run on CUDA Cores, tunes their performance, and compares them with cuBLAS running on Tensor Cores. To ensure accuracy, multiplication and accumulation are performed in FP32. After exploring various parallel task designs, the kernels currently reach more than 1.5 times the performance of cuBLAS for dimensions from 1 to 4096. The code is open source in cuda_hgemv.

2.1 Test Conditions

  • HGEMV: C (1 * N, Half, Row Major) = A (1 * K, Half, Row Major) * B (K * N, Half, Col Major)
  • CUDA: 11.8
  • GPU: RTX3090
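A host-side reference for exactly this layout is handy for checking any of the kernels below. The sketch assumes nothing beyond the layout listed above; the name hgemvReference is hypothetical, and it accumulates in FP32 like the kernels in this article.

```cuda
#include <cassert>
#include <cstddef>
#include <cuda_fp16.h>

// Hypothetical host reference for the test layout:
//   A: 1 x K (half), B: K x N column-major (half), C: 1 x N (half).
// Column n of B occupies B[n * K] .. B[n * K + K - 1].
// Products are accumulated in FP32 and rounded to half once at the end.
void hgemvReference(const half *A, const half *B, half *C, size_t N, size_t K) {
    for (size_t n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (size_t k = 0; k < K; ++k) {
            acc += __half2float(A[k]) * __half2float(B[n * K + k]);
        }
        C[n] = __float2half(acc);
    }
}
```

When comparing against a GPU kernel, a small tolerance is usually needed, since the summation order differs.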

2.2 Device Specifications

The device specifications of the RTX3090 are as follows.

(Figure: RTX3090 device specifications)

2.3 RTX3090

(Figure: HGEMV performance on RTX3090, K = 128)

3 Compute

Because of the nature of HGEMV, when K is not particularly large it is natural to compute one element of C at a time, that is, the complete inner product of the A vector with one column of the B matrix. Whether one thread or one warp computes each element of C depends mainly on the ratio of computation to memory access. A warp can also compute several elements of C to balance computation and memory access for better performance.
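A rough count suggests HGEMV is memory-bound whichever mapping is used. Assuming each element of A, B and C crosses global memory exactly once (2 bytes per half) and each multiply-add counts as 2 FLOPs, the arithmetic intensity I stays below 1 FLOP/byte:

```latex
\text{FLOPs} = 2NK, \qquad \text{Bytes} \ge 2\,(NK + K + N)
\;\Rightarrow\;
I = \frac{2NK}{2\,(NK + K + N)} = \frac{NK}{NK + K + N} < 1\ \text{FLOP/byte}
% e.g. K = 128, N = 4096: I = 524288 / 528512 \approx 0.99
```

So the thread-versus-warp choice is mostly about how efficiently each thread or warp issues its loads and how much work it has to hide memory latency.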

3.1 Thread Level

If one thread computes one element of C, the code is as follows. Each thread computes the complete inner product of the A vector with one column of the B matrix. The source code is in cuda_hgemv.

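#include <cuda_fp16.h>

#define WARP_SIZE 32
#define WARPS_PER_BLOCK 4
#define THREADS_PER_BLOCK 128 // WARP_SIZE * WARPS_PER_BLOCK

// Ceiling division. cuda_hgemv ships its own div_ceil; this stand-in is assumed equivalent.
__host__ __device__ inline size_t div_ceil(size_t a, size_t b) {
    return (a + b - 1) / b;
}

// One thread computes one element of C: the full inner product of A with column `col` of B.
__global__ void threadNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C,
                                  size_t N, size_t K) {
    const size_t col = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
    if (col >= N) {
        return;
    }

    float tmp = 0.0f;  // FP32 accumulation for accuracy
#pragma unroll
    for (size_t i = 0; i < K; ++i) {
        // B is column-major, so column `col` starts at B[col * K].
        tmp += __half2float(A[i]) * __half2float(B[i + col * K]);
    }
    C[col] = __float2half(tmp);
}

void threadNaive(half *A, half *B, half *C, size_t N, size_t K) {
    dim3 block(THREADS_PER_BLOCK);
    dim3 grid(div_ceil(N, THREADS_PER_BLOCK));

    threadNaiveKernel<<<grid, block>>>(A, B, C, N, K);
}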
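To make the memory side of this trade-off concrete, the host-only sketch below counts how many 32-byte memory sectors one warp's load of B touches per loop iteration, under the threadNaive mapping and under the warp1Naive mapping that follows. sectorsTouched, threadNaiveSectors and warp1NaiveSectors are hypothetical names. The model assumes B is 32-byte aligned (cudaMalloc guarantees at least that) and, for the warp case, that K is a multiple of 16. For K = 128, threadNaive's 32 lanes land in 32 different sectors to fetch 64 useful bytes, while warp1Naive's lanes read 64 contiguous bytes from 2 sectors. Later threadNaive iterations reuse the rest of those sectors if they are still cached, so this counts memory transactions per load, not DRAM traffic.

```cuda
#include <cassert>
#include <cstddef>
#include <set>

// Hypothetical analysis helper (not part of cuda_hgemv). Counts the distinct
// 32-byte sectors touched when the 32 lanes of one warp each load one 2-byte
// half of B in the same loop iteration. Assumes B is 32-byte aligned.
constexpr size_t kHalfBytes = 2;
constexpr size_t kSectorBytes = 32;

template <typename ByteOffsetFn>
size_t sectorsTouched(ByteOffsetFn byte_offset) {
    std::set<size_t> sectors;
    for (size_t lane = 0; lane < 32; ++lane) {
        sectors.insert(byte_offset(lane) / kSectorBytes);
    }
    return sectors.size();
}

// threadNaive, first warp of the grid, iteration i: lane l reads B[i + l * K].
size_t threadNaiveSectors(size_t K, size_t i) {
    return sectorsTouched([=](size_t lane) { return (i + lane * K) * kHalfBytes; });
}

// warp1Naive, warp owning column warp_col, iteration i: lane l reads B[i * 32 + l + warp_col * K].
size_t warp1NaiveSectors(size_t K, size_t warp_col, size_t i) {
    return sectorsTouched([=](size_t lane) { return (i * 32 + lane + warp_col * K) * kHalfBytes; });
}
```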

3.2 Warp Level

If one warp computes one element of C, the code is as follows. Unlike threadNaive, each thread in the warp computes only a partial inner product of the A vector with one column of the B matrix, so a reduce sum across the warp is needed after the loop. The source code is in cuda_hgemv.

#define WARP_SIZE 32
#define WARPS_PER_BLOCK 4
#define THREADS_PER_BLOCK 128 // WARP_SIZE * WARPS_PER_BLOCK

// One warp computes one element of C; each lane accumulates a strided partial inner product.
__global__ void warp1NaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t N,
                                 size_t K) {
    const size_t warp_id = threadIdx.x / WARP_SIZE;
    const size_t warp_col = blockIdx.x * WARPS_PER_BLOCK + warp_id;
    if (warp_col >= N) {  // uniform across the warp, so the whole warp exits together
        return;
    }

    const size_t K_iters = div_ceil(K, WARP_SIZE);
    const size_t lane_id = threadIdx.x % WARP_SIZE;

    float tmp = 0.0f;
#pragma unroll
    for (size_t i = 0; i < K_iters; ++i) {
        const size_t A_idx = i * WARP_SIZE + lane_id;
        if (A_idx < K) {  // guard the tail when K is not a multiple of WARP_SIZE
            const size_t B_idx = A_idx + warp_col * K;
            tmp += __half2float(A[A_idx]) * __half2float(B[B_idx]);
        }
    }

    // Butterfly reduce sum: after offsets 16, 8, 4, 2, 1 every lane holds the full sum.
    const unsigned int mask = 0xffffffff;
#pragma unroll
    for (size_t i = WARP_SIZE / 2; i >= 1; i /= 2) {
        tmp += __shfl_xor_sync(mask, tmp, i);
    }

    if (lane_id == 0) {
        C[warp_col] = __float2half(tmp);
    }
}

void warp1Naive(half *A, half *B, half *C, size_t N, size_t K) {
    dim3 block(THREADS_PER_BLOCK);
    dim3 grid(div_ceil(N, WARPS_PER_BLOCK));

    warp1NaiveKernel<<<grid, block>>>(A, B, C, N, K);
}
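The reduce sum above is a butterfly: with XOR offsets 16, 8, 4, 2, 1 every lane ends up holding the full warp sum, and the same pattern with offsets 8, 4, 2, 1 stays within each 16-lane half, which is what warp2Naive relies on. A sketch factoring this into a templated helper follows; groupReduceSum, groupReduceDemo and kGroupSize are hypothetical names, not taken from cuda_hgemv.

```cuda
#include <cassert>

// Hypothetical helper (not the cuda_hgemv source): butterfly reduce sum within
// aligned groups of kGroupSize consecutive lanes. Every offset is smaller than
// kGroupSize, so lane ^ offset never leaves the lane's own group; afterwards every
// lane of a group holds that group's sum. All 32 lanes of the warp must call it.
template <int kGroupSize>
__device__ __forceinline__ float groupReduceSum(float val) {
    static_assert(kGroupSize >= 1 && kGroupSize <= 32 && (kGroupSize & (kGroupSize - 1)) == 0,
                  "kGroupSize must be a power of two in [1, 32]");
#pragma unroll
    for (int offset = kGroupSize / 2; offset >= 1; offset /= 2) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }
    return val;
}

// Demo: launch one warp; lane l contributes l and out[l] receives its group's sum.
template <int kGroupSize>
__global__ void groupReduceDemo(float *out) {
    out[threadIdx.x] = groupReduceSum<kGroupSize>(static_cast<float>(threadIdx.x));
}
```

With this helper, the reduction in warp1Naive would read groupReduceSum<32>(tmp), and the one in warp2Naive groupReduceSum<16>(tmp).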

Next, we increase the work per warp so that one warp computes two elements of C. The code is as follows. Unlike warp1Naive, the first 16 threads and the last 16 threads of the warp each compute the complete inner product of the A vector with a different column of the B matrix, so after the loop a reduce sum is performed separately within the first 16 threads and within the last 16 threads. The source code is in cuda_hgemv.

#define WARP_SIZE 32
#define WARPS_PER_BLOCK 4
#define THREADS_PER_BLOCK 128 // WARP_SIZE * WARPS_PER_BLOCK

#define COLS_PER_WARP 2
#define COLS_PER_BLOCK 8 // COLS_PER_WARP * WARPS_PER_BLOCK
#define GROUP_SIZE 16 // WARP_SIZE / COLS_PER_WARP

// Each 16-lane group computes one element of C, so one warp computes two.
__global__ void warp2NaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t N,
                                 size_t K) {
    const size_t group_id = threadIdx.x / GROUP_SIZE;
    const size_t group_col = blockIdx.x * COLS_PER_BLOCK + group_id;
    // No early return: both 16-lane groups of a warp must reach the full-mask
    // shuffle below, even if only one of them maps to a valid column (odd N).
    const bool valid = group_col < N;

    const size_t K_iters = div_ceil(K, GROUP_SIZE);
    const size_t group_lane_id = threadIdx.x % GROUP_SIZE;

    float tmp = 0.0f;
    if (valid) {
#pragma unroll
        for (size_t i = 0; i < K_iters; ++i) {
            const size_t A_idx = i * GROUP_SIZE + group_lane_id;
            if (A_idx < K) {  // guard the tail when K is not a multiple of GROUP_SIZE
                const size_t B_idx = A_idx + group_col * K;
                tmp += __half2float(A[A_idx]) * __half2float(B[B_idx]);
            }
        }
    }

    // XOR offsets 8, 4, 2, 1 stay inside each 16-lane group, so the two groups reduce independently.
    constexpr unsigned int mask = 0xffffffff;
#pragma unroll
    for (size_t i = GROUP_SIZE / 2; i >= 1; i /= 2) {
        tmp += __shfl_xor_sync(mask, tmp, i);
    }

    if (valid && group_lane_id == 0) {
        C[group_col] = __float2half(tmp);
    }
}

void warp2Naive(half *A, half *B, half *C, size_t N, size_t K) {
    dim3 block(THREADS_PER_BLOCK);
    dim3 grid(div_ceil(N, COLS_PER_BLOCK));

    warp2NaiveKernel<<<grid, block>>>(A, B, C, N, K);
}

Similarly, the work per warp can be increased further so that one warp computes 4, 8, or 16 elements of C. The code is similar to warp2Naive, and the source code is in cuda_hgemv.
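A minimal sketch of such a generalization, assuming the same layout and 4 warps (128 threads) per block, is shown below; warpColsKernel, warpCols and the template parameters are hypothetical names, not the cuda_hgemv source. Two details are handled explicitly: lanes whose group maps past N do not return early, since every lane must take part in the full-mask __shfl_xor_sync, and the k < K loop condition covers K that is not a multiple of the group size.

```cuda
#include <cassert>
#include <cstddef>
#include <cuda_fp16.h>

// Hypothetical generalization of warp2Naive (not the cuda_hgemv source): each warp
// computes kColsPerWarp elements of C, one per group of kGroup = 32 / kColsPerWarp lanes.
template <int kColsPerWarp, int kWarpsPerBlock>
__global__ void warpColsKernel(const half *__restrict__ A, const half *__restrict__ B,
                               half *__restrict__ C, size_t N, size_t K) {
    constexpr int kGroup = 32 / kColsPerWarp;
    static_assert(kGroup * kColsPerWarp == 32, "kColsPerWarp must divide 32");
    const size_t group_id = threadIdx.x / kGroup;
    const size_t group_lane = threadIdx.x % kGroup;
    const size_t col = static_cast<size_t>(blockIdx.x) * kColsPerWarp * kWarpsPerBlock + group_id;
    // No early return: every lane must reach the full-mask shuffle below, even when
    // only some groups of the warp map to a column < N.
    const bool valid = col < N;

    float acc = 0.0f;
    if (valid) {
        // Strided partial inner product; k < K also covers K % kGroup != 0.
        for (size_t k = group_lane; k < K; k += kGroup) {
            acc += __half2float(A[k]) * __half2float(B[col * K + k]);
        }
    }

    // Offsets kGroup/2, ..., 1 stay inside each group, so groups reduce independently.
#pragma unroll
    for (int offset = kGroup / 2; offset >= 1; offset /= 2) {
        acc += __shfl_xor_sync(0xffffffffu, acc, offset);
    }

    if (valid && group_lane == 0) {
        C[col] = __float2half(acc);
    }
}

// Host launcher: 4 warps (128 threads) per block, as in the kernels above.
template <int kColsPerWarp>
void warpCols(const half *A, const half *B, half *C, size_t N, size_t K) {
    constexpr int kWarpsPerBlock = 4;
    constexpr size_t kColsPerBlock = static_cast<size_t>(kColsPerWarp) * kWarpsPerBlock;
    const dim3 block(32 * kWarpsPerBlock);
    const dim3 grid(static_cast<unsigned int>((N + kColsPerBlock - 1) / kColsPerBlock));
    warpColsKernel<kColsPerWarp, kWarpsPerBlock><<<grid, block>>>(A, B, C, N, K);
}
```

For example, warpCols<4>, warpCols<8> and warpCols<16> correspond to a warp computing 4, 8 and 16 elements of C; larger values leave fewer lanes per column to share the K loop.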

4 Memory Access

Since HGEMV computes one element of C at a time, that is, the complete inner product of the A vector with one column of the B matrix, the A vector is read repeatedly by every thread in a block, whereas each column of B is read only by the thread (or threads) computing that column. Loading the A vector into shared memory in advance therefore reduces repeated global memory accesses to A. However, it adds an extra copy of A from global to shared memory in every block, so the expected gain is correspondingly smaller.
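A quick count shows why the expected gain is limited. Writing T for the number of columns of C computed by one block (T = THREADS_PER_BLOCK = 128 in threadSmem below), every block still reads all K elements of A from global memory once, while the K x N elements of B are read exactly as before:

```latex
% T: columns of C per block (THREADS_PER_BLOCK = 128 in threadSmem)
\frac{\text{A elements read from global}}{\text{B elements read from global}}
= \frac{\lceil N / T \rceil \, K}{N K} \approx \frac{1}{T} = \frac{1}{128}
```

So in the shared-memory version A accounts for roughly 1/T of the global loads; staging it mainly removes the repeated per-thread loads of A, not the dominant B traffic.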

Taking threadNaive as an example, the A vector is loaded into shared memory in advance. The code is as follows, and the source code is in cuda_hgemv.

#define WARP_SIZE 32
#define WARPS_PER_BLOCK 4
#define THREADS_PER_BLOCK 128 // WARP_SIZE * WARPS_PER_BLOCK

__global__ void threadSmemKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t N,
                                 size_t K) {
    // The block stages the whole A vector (K halves) in dynamic shared memory.
    extern __shared__ half A_smem[];
    const size_t A_smem_iters = div_ceil(K, THREADS_PER_BLOCK);
#pragma unroll
    for (size_t i = 0; i < A_smem_iters; ++i) {
        const size_t idx = i * THREADS_PER_BLOCK + threadIdx.x;
        if (idx < K) {  // guard the tail when K is not a multiple of THREADS_PER_BLOCK
            A_smem[idx] = A[idx];
        }
    }

    // All threads, including those with col >= N, must reach this barrier.
    __syncthreads();

    const size_t col = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
    if (col >= N) {
        return;
    }

    float tmp = 0.0f;
#pragma unroll
    for (size_t i = 0; i < K; ++i) {
        tmp += __half2float(A_smem[i]) * __half2float(B[i + col * K]);
    }
    C[col] = __float2half(tmp);
}

size_t initThreadSmem(size_t K) {
    int dev_id = 0;
    HGEMV_CHECK_CUDART_ERROR(cudaGetDevice(&dev_id));

    cudaDeviceProp dev_prop;
    HGEMV_CHECK_CUDART_ERROR(cudaGetDeviceProperties(&dev_prop, dev_id));

    size_t smem_max_size = K * sizeof(half);
    HLOG("smem_max_size: %.0f KBytes (%zu bytes)", static_cast<double>(smem_max_size) / 1024, smem_max_size);

    // The per-block limit for dynamic shared memory is sharedMemPerBlockOptin,
    // which can be smaller than sharedMemPerMultiprocessor.
    HGEMV_CHECK_GT(dev_prop.sharedMemPerBlockOptin, smem_max_size);
    HGEMV_CHECK_CUDART_ERROR(
        cudaFuncSetAttribute(threadSmemKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max_size));

    return smem_max_size;
}

void threadSmem(half *A, half *B, half *C, size_t N, size_t K) {
    // Redo the setup whenever K changes; a single static initializer would keep the first K's size.
    static size_t cached_K = 0;
    static size_t smem_max_size = 0;
    if (K != cached_K) {
        smem_max_size = initThreadSmem(K);
        cached_K = K;
    }

    dim3 block(THREADS_PER_BLOCK);
    dim3 grid(div_ceil(N, THREADS_PER_BLOCK));

    threadSmemKernel<<<grid, block, smem_max_size>>>(A, B, C, N, K);
}
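threadSmem only works while K * sizeof(half) fits in the dynamic shared memory a single block may opt into. The hedged sketch below turns that limit into a bound on K using the runtime attribute cudaDevAttrMaxSharedMemoryPerBlockOptin; maxKForSmemA is a hypothetical name. It assumes the kernel uses no static shared memory, as is the case for threadSmemKernel, and kernels needing more than 48 KB of dynamic shared memory still need the cudaFuncSetAttribute call shown in initThreadSmem.

```cuda
#include <cassert>
#include <cstddef>
#include <cuda_fp16.h>

// Hypothetical helper (not from cuda_hgemv): the largest K for which the whole A
// vector (K halves) fits in the dynamic shared memory one block can opt into on
// device `dev`. Returns 0 if the attribute query fails.
size_t maxKForSmemA(int dev) {
    int optin_bytes = 0;
    if (cudaDeviceGetAttribute(&optin_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev) != cudaSuccess) {
        return 0;
    }
    return static_cast<size_t>(optin_bytes) / sizeof(half);
}
```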

5 Other

5.1 Optimization Method

This article mainly introduces the general approach to optimizing CUDA HGEMV, namely balancing computation against memory access.

5.2 Source Code

All optimization methods used in this article are open source in cuda_hgemv and may be analyzed together with the source code in a later article.
