1.2. Standard Matrix Multiplication
Computes the matrix product C = A × B. Given an input matrix A of shape (M, K) and a matrix B of shape (K, N), it returns an output matrix C of shape (M, N).
Progress
Optimization Timeline
Step 8 Details
Reasoning
We considered further optimizations, such as double buffering to overlap global-memory loads with computation. Weighing the implementation complexity of double buffering against the expected performance gain, we decided against it. The current kernel already outperforms the native implementation (reward > 1). It combines three techniques: 4×4 register tiling, vectorized (float4) loads from global into shared memory with an early branch that skips boundary checks on full tiles, and an unrolled inner loop. Further changes would add complexity without a guaranteed significant benefit, so we adopt this version as our final solution.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h> // For c10::cuda::getCurrentCUDAStream()
#include <cuda.h>
#include <cuda_runtime.h>
// Tiling parameters shared by the kernel and its host-side launcher.
#define TILE_SIZE 64                           // Edge length of the square output/shared-memory tile.
#define SUB_TILE 4                             // Each thread computes a SUB_TILE x SUB_TILE block of C.
#define THREADS_PER_DIM (TILE_SIZE / SUB_TILE) // Threads per block edge: 64 / 4 = 16.

// Compile-time checks for assumptions the tiling scheme relies on.
static_assert(TILE_SIZE % SUB_TILE == 0,
              "TILE_SIZE must be a multiple of SUB_TILE so threads cover the tile exactly");
static_assert(TILE_SIZE % 4 == 0,
              "float4 tile loads and the unrolled k-loop require TILE_SIZE to be a multiple of 4");
static_assert(THREADS_PER_DIM * THREADS_PER_DIM <= 1024,
              "block size exceeds the CUDA limit of 1024 threads per block");
static_assert(2 * TILE_SIZE * TILE_SIZE * sizeof(float) <= 48 * 1024,
              "two static float tiles must fit in the 48 KB static shared-memory limit");
// Tiled single-precision GEMM: C = A * B, where A is M x K, B is K x N and
// C is M x N. All three are dense, row-major float arrays.
//
// Launch contract:
//   blockDim = (THREADS_PER_DIM, THREADS_PER_DIM). Each thread owns one
//              SUB_TILE x SUB_TILE sub-block of the block's output tile, so
//              other block shapes would leave parts of C unwritten.
//   gridDim  = (ceil(N / TILE_SIZE), ceil(M / TILE_SIZE)). Block (x, y)
//              produces the TILE_SIZE x TILE_SIZE tile of C whose top-left
//              corner is row y * TILE_SIZE, column x * TILE_SIZE.
//   Shared memory: two static TILE_SIZE x TILE_SIZE float tiles (32 KB at
//              TILE_SIZE = 64). No dynamic shared memory is used.
//
// The K dimension is walked in TILE_SIZE-wide slabs:
//   1. For each slab, the block stages the matching A and B tiles in shared
//      memory.
//   2. Every thread then accumulates its sub-block of C in registers.
//
// Global->shared copies use float4 only when two conditions hold:
//   - the tile lies entirely in bounds, and
//   - the source addresses are 16-byte aligned.
// Otherwise a bounds-checked scalar path runs. It zero-fills out-of-range
// elements so they contribute nothing to the dot products.
__global__ void matmul_kernel(const float * __restrict__ A,
                              const float * __restrict__ B,
                              float * __restrict__ C,
                              int M, int N, int K) {
    // Tiles receive float4 stores, so force 16-byte alignment explicitly.
    __shared__ __align__(16) float As[TILE_SIZE][TILE_SIZE];
    __shared__ __align__(16) float Bs[TILE_SIZE][TILE_SIZE];

    // Block-level offsets into C.
    const int blockRow = blockIdx.y * TILE_SIZE;
    const int blockCol = blockIdx.x * TILE_SIZE;

    // Thread indices within the block (each in 0 .. THREADS_PER_DIM-1).
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    // Top-left corner, in C, of this thread's SUB_TILE x SUB_TILE sub-block.
    const int rowStart = blockRow + ty * SUB_TILE;
    const int colStart = blockCol + tx * SUB_TILE;

    // Register accumulator for the sub-block.
    float Csub[SUB_TILE][SUB_TILE];
    #pragma unroll
    for (int i = 0; i < SUB_TILE; i++) {
        #pragma unroll
        for (int j = 0; j < SUB_TILE; j++) {
            Csub[i][j] = 0.0f;
        }
    }

    const int numTiles = (K + TILE_SIZE - 1) / TILE_SIZE;
    const int totalThreads = blockDim.x * blockDim.y;
    const int linearThreadId = threadIdx.y * blockDim.x + threadIdx.x;
    // Each tile is copied as TILE_SIZE*TILE_SIZE/4 groups of 4 consecutive floats.
    const int totalFloat4 = (TILE_SIZE * TILE_SIZE) / 4;

    // A float4 global load must be 16-byte aligned. The column offset of every
    // 4-float group is a multiple of 4, because both t * TILE_SIZE and c are.
    // So a group's address is aligned exactly when two conditions hold:
    //   - the base pointer is 16-byte aligned, and
    //   - the row stride is a multiple of 4 floats (K for A, N for B).
    const bool alignedA = (K % 4 == 0) && (reinterpret_cast<size_t>(A) % 16 == 0);
    const bool alignedB = (N % 4 == 0) && (reinterpret_cast<size_t>(B) % 16 == 0);

    for (int t = 0; t < numTiles; t++) {
        const int tileK = t * TILE_SIZE;
        const bool kSlabFull = (tileK + TILE_SIZE) <= K;
        // Vector path only for fully in-bounds, properly aligned tiles.
        const bool vecA = alignedA && kSlabFull && (blockRow + TILE_SIZE) <= M;
        const bool vecB = alignedB && kSlabFull && (blockCol + TILE_SIZE) <= N;

        // --- Stage the A tile (rows blockRow.., columns tileK..) ---
        for (int i = linearThreadId; i < totalFloat4; i += totalThreads) {
            const int idx = i * 4;
            const int r = idx / TILE_SIZE;
            const int c = idx % TILE_SIZE;
            const int globalRow = blockRow + r;
            const int globalCol = tileK + c;
            if (vecA) {
                *reinterpret_cast<float4*>(&As[r][c]) =
                    *reinterpret_cast<const float4*>(A + static_cast<size_t>(globalRow) * K + globalCol);
            } else {
                #pragma unroll
                for (int v = 0; v < 4; v++) {
                    const int gcol = globalCol + v;
                    As[r][c + v] = (globalRow < M && gcol < K)
                                       ? A[static_cast<size_t>(globalRow) * K + gcol]
                                       : 0.0f;
                }
            }
        }

        // --- Stage the B tile (rows tileK.., columns blockCol..) ---
        for (int i = linearThreadId; i < totalFloat4; i += totalThreads) {
            const int idx = i * 4;
            const int r = idx / TILE_SIZE;
            const int c = idx % TILE_SIZE;
            const int globalRow = tileK + r;
            const int globalCol = blockCol + c;
            if (vecB) {
                *reinterpret_cast<float4*>(&Bs[r][c]) =
                    *reinterpret_cast<const float4*>(B + static_cast<size_t>(globalRow) * N + globalCol);
            } else {
                #pragma unroll
                for (int v = 0; v < 4; v++) {
                    const int gcol = globalCol + v;
                    Bs[r][c + v] = (globalRow < K && gcol < N)
                                       ? B[static_cast<size_t>(globalRow) * N + gcol]
                                       : 0.0f;
                }
            }
        }

        // Both tiles must be fully written before any thread reads them.
        __syncthreads();

        // --- Accumulate this slab's contribution (fully unrolled) ---
        #pragma unroll
        for (int k = 0; k < TILE_SIZE; k++) {
            float a[SUB_TILE];
            float b[SUB_TILE];
            #pragma unroll
            for (int s = 0; s < SUB_TILE; s++) {
                a[s] = As[ty * SUB_TILE + s][k];
                b[s] = Bs[k][tx * SUB_TILE + s];
            }
            #pragma unroll
            for (int i = 0; i < SUB_TILE; i++) {
                #pragma unroll
                for (int j = 0; j < SUB_TILE; j++) {
                    Csub[i][j] += a[i] * b[j];
                }
            }
        }

        // Keep the tiles intact until every thread has finished reading them.
        __syncthreads();
    }

    // --- Write the sub-block back to C, skipping out-of-range elements ---
    #pragma unroll
    for (int i = 0; i < SUB_TILE; i++) {
        const int globalRow = rowStart + i;
        if (globalRow >= M) {
            continue;
        }
        #pragma unroll
        for (int j = 0; j < SUB_TILE; j++) {
            const int globalCol = colStart + j;
            if (globalCol < N) {
                C[static_cast<size_t>(globalRow) * N + globalCol] = Csub[i][j];
            }
        }
    }
}
// Host entry point: returns A @ B for 2-D float32 CUDA tensors.
//
// A is (M, K) and B is (K, N); the result is a new contiguous (M, N) tensor
// on A's device. Non-contiguous inputs are copied to contiguous memory first,
// because the kernel reads raw row-major storage.
//
// The kernel is enqueued on PyTorch's current CUDA stream; the call does not
// synchronize. Invalid inputs and launch failures are reported through
// TORCH_CHECK, which raises a c10::Error (RuntimeError in Python).
torch::Tensor matmul(torch::Tensor A, torch::Tensor B) {
    TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
    TORCH_CHECK(A.device() == B.device(), "A and B must be on the same device");
    TORCH_CHECK(A.scalar_type() == torch::kFloat32 && B.scalar_type() == torch::kFloat32,
                "A and B must be float32 tensors");
    TORCH_CHECK(A.dim() == 2, "A must be 2D");
    TORCH_CHECK(B.dim() == 2, "B must be 2D");
    TORCH_CHECK(B.size(0) == A.size(1), "A.cols must equal B.rows");

    // The kernel takes 32-bit dimensions; reject sizes that would be truncated.
    const int M = static_cast<int>(A.size(0));
    const int K = static_cast<int>(A.size(1));
    const int N = static_cast<int>(B.size(1));
    TORCH_CHECK(M == A.size(0) && K == A.size(1) && N == B.size(1),
                "matrix dimensions must fit in a 32-bit int");

    c10::cuda::CUDAGuard device_guard(A.device());

    // The kernel indexes A and B as dense row-major buffers, so materialize
    // non-contiguous views (e.g. transposes) first.
    const torch::Tensor Ac = A.contiguous();
    const torch::Tensor Bc = B.contiguous();
    auto C = torch::empty({M, N}, A.options());

    // An empty output would give a zero-sized grid, which is an invalid launch
    // configuration. There is nothing to compute in that case.
    if (M == 0 || N == 0) {
        return C;
    }

    // Grid dimensions are computed in 64 bits to avoid overflow near INT_MAX.
    const int64_t gridX = (static_cast<int64_t>(N) + TILE_SIZE - 1) / TILE_SIZE;
    const int64_t gridY = (static_cast<int64_t>(M) + TILE_SIZE - 1) / TILE_SIZE;
    TORCH_CHECK(gridY <= 65535, "M is too large for the kernel's grid.y dimension");

    dim3 threads(THREADS_PER_DIM, THREADS_PER_DIM); // 16x16 threads per block.
    dim3 blocks(static_cast<unsigned int>(gridX), static_cast<unsigned int>(gridY));
    matmul_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        Ac.data_ptr<float>(),
        Bc.data_ptr<float>(),
        C.data_ptr<float>(),
        M, N, K
    );

    // Launch-configuration errors are only reported via cudaGetLastError().
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "matmul_kernel launch failed: ", cudaGetErrorString(err));
    return C;
}
// Python bindings: exposes the host launcher `matmul(A, B)` on the
// extension module built under TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "matmul",
        &matmul,
        "Optimized matrix multiplication with 4x4 register tiling, vectorized shared memory loads, and inner-loop unrolling");
}