2.59. Matmul with Swish Activation and Scaling
Performs a matrix multiplication on input data, applies an element-wise Swish activation by multiplying each element with its sigmoid value, and scales the resulting values by a constant factor.
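For reference, the fused operation is equivalent to the following minimal PyTorch sketch (assuming a float32 2-D input; torch.nn.functional.silu computes exactly the Swish a * sigmoid(a)):

import torch
import torch.nn.functional as F

def reference_forward(x, weight, bias, scaling_factor):
    a = F.linear(x, weight, bias)        # a = x @ weight.T + bias
    return scaling_factor * F.silu(a)    # Swish: a * sigmoid(a), then scale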
Step 6 Details
Reasoning
Our previous tiled GEMM fused kernel achieved a runtime of 0.163 sec versus a native runtime of 0.148 sec (reward ≈ 0.91). To further optimize and exceed native performance (reward > 1), we modify the tiling strategy to increase occupancy and reduce per-thread overhead. In particular, we increase the tile size along the batch (N) dimension, raising TILE_M from 16 to 32. Each block then processes more rows (samples) and fewer blocks are needed along the N dimension, which reduces kernel launch overhead and makes better use of shared memory. In addition, we use __ldg for bias loads to benefit from the read-only cache, and we add unrolling hints in the inner loop over the K dimension to minimize loop overhead. These micro-optimizations are standard and should yield a genuine speedup.

--- FIX ---

The output mismatch was most likely caused by an error in how the weight matrix tile was loaded. We are computing the GEMM Y = X * W^T, where X is [N, K] and W is [M, K], so that Y[i, j] = dot(X[i, :], W[j, :]). To perform the GEMM with tiling, we need to load a tile of X and a tile of B, where B = W^T. In the previous kernel, W's tile was mistakenly loaded into shared memory with swapped dimensions. To fix this, the shared memory for the weight tile must be declared as a TILE_K × TILE_N array (instead of TILE_N × TILE_K). Each thread then loads B[t, tx] = W[col, t] into shared memory (using threadIdx.y for the K dimension), so that the multiplication is later performed as sum += sX[threadIdx.y][t] * sB[t][threadIdx.x]; which applies the transposition the GEMM requires. This revision also handles bounds correctly when loading the weight tile, ensuring that only threads with threadIdx.y < TILE_K participate. Finally, we keep the __ldg bias load and our tiling and unrolling strategy. This change should fix the output mismatch.
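As a sanity check on that indexing, here is a minimal NumPy sketch of the corrected tiling (illustrative sizes only, not the kernel itself): each weight tile is taken in B = W^T layout, matching sB[threadIdx.y][threadIdx.x] = W[col * K + k_index_b] in the kernel below. The updated CUDA kernel code follows.

import numpy as np

TILE_K = 16
N, K, M = 8, 32, 8                     # assumed sizes; K is a multiple of TILE_K here
X = np.random.rand(N, K).astype(np.float32)
W = np.random.rand(M, K).astype(np.float32)

Y = np.zeros((N, M), dtype=np.float32)
for tile in range(0, K, TILE_K):
    sX = X[:, tile:tile + TILE_K]      # tile of X, shape [N, TILE_K]
    sB = W[:, tile:tile + TILE_K].T    # tile of B = W^T, shape [TILE_K, M]
    Y += sX @ sB                       # accumulates sum over t of sX[i, t] * sB[t, j]

assert np.allclose(Y, X @ W.T, atol=1e-5)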
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
// New tile dimensions: Increase TILE_M to 32 for better occupancy.
#define TILE_M 32
#define TILE_N 16
#define TILE_K 16
// This kernel performs a tiled GEMM to compute:
// Y = X * W^T, where X is [N, K] and W is [M, K].
// Here, B = W^T has shape [K, M].
// After accumulating the dot product, bias is added (using __ldg),
// then the Swish activation is applied: out = scaling_factor * (a * sigmoid(a)),
// where a = sum + bias.
// Grid: gridDim.x = ceil(N / TILE_M), gridDim.y = ceil(M / TILE_N)
// Block: blockDim.x = TILE_N, blockDim.y = TILE_M.
extern "C" {
__global__ void fused_linear_swish_gemm_kernel(const float* __restrict__ X,
const float* __restrict__ W,
const float* __restrict__ bias,
float scaling_factor,
float* __restrict__ Y,
int N, // rows in X (batch size)
int K, // in_features
int M) // out_features (rows in W)
{
// Compute global row index in X and Y.
int row = blockIdx.x * TILE_M + threadIdx.y;
// Compute global column index in Y (and corresponding row of W).
int col = blockIdx.y * TILE_N + threadIdx.x;
float sum = 0.0f;
// Number of tiles along the K dimension.
int num_tiles = (K + TILE_K - 1) / TILE_K;
// Shared memory for tiles of X and B (B = W^T).
__shared__ float sX[TILE_M][TILE_K];
// Declare sB with dimensions [TILE_K][TILE_N], where sB will hold a tile of B.
__shared__ float sB[TILE_K][TILE_N];
    for (int tile = 0; tile < num_tiles; tile++) {
        // Load tile of X into shared memory (TILE_N == TILE_K, so threadIdx.x spans the K tile).
        int k_index = tile * TILE_K + threadIdx.x;
        if (row < N && k_index < K)
            sX[threadIdx.y][threadIdx.x] = X[row * K + k_index];
        else
            sX[threadIdx.y][threadIdx.x] = 0.0f;
        // Load tile of B = W^T into shared memory:
        // B[t][col] = W[col][t], so we want sB[t][threadIdx.x] = B[t][col].
        // Only the first TILE_K threads along y participate.
        if (threadIdx.y < TILE_K) {
            int k_index_b = tile * TILE_K + threadIdx.y;
            if (col < M && k_index_b < K)
                sB[threadIdx.y][threadIdx.x] = W[col * K + k_index_b];
            else
                sB[threadIdx.y][threadIdx.x] = 0.0f;
        }
        __syncthreads();
        // Multiply the two tiles together.
        #pragma unroll
        for (int t = 0; t < TILE_K; t++) {
            sum += sX[threadIdx.y][t] * sB[t][threadIdx.x];
        }
        __syncthreads();
    }
    // Write result if indices are in bounds.
    if (row < N && col < M) {
        // Use __ldg to load the bias through the read-only cache.
        float a = sum + __ldg(&bias[col]);
        float sig = 1.0f / (1.0f + __expf(-a));
        float swish = a * sig;
        Y[row * M + col] = scaling_factor * swish;
    }
}
} // extern "C"
torch::Tensor forward(torch::Tensor input,
                      torch::Tensor weight,
                      torch::Tensor bias,
                      float scaling_factor) {
    c10::cuda::CUDAGuard device_guard(input.device());
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && bias.is_cuda(), "All tensors must be CUDA tensors");
    TORCH_CHECK(input.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(), "All tensors must be contiguous");
    TORCH_CHECK(input.dim() == 2, "Input must be a 2D tensor");
    TORCH_CHECK(weight.dim() == 2, "Weight must be a 2D tensor");
    TORCH_CHECK(bias.dim() == 1, "Bias must be a 1D tensor");
    int N = input.size(0);  // Number of rows in X (batch size)
    int K = input.size(1);  // in_features
    int M = weight.size(0); // out_features
    TORCH_CHECK(weight.size(1) == K, "Weight's second dimension must equal in_features");
    TORCH_CHECK(bias.size(0) == M, "Bias vector size must equal out_features");
    auto output = torch::empty({N, M}, input.options());
    // Define grid and block dimensions (TILE_N x TILE_M = 512 threads per block).
    int grid_rows = (N + TILE_M - 1) / TILE_M;
    int grid_cols = (M + TILE_N - 1) / TILE_N;
    dim3 grid_dim(grid_rows, grid_cols);
    dim3 block_dim(TILE_N, TILE_M);
    // The tiles sX and sB are statically allocated inside the kernel, so no
    // dynamic shared memory needs to be requested at launch.
    fused_linear_swish_gemm_kernel<<<grid_dim, block_dim, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        scaling_factor,
        output.data_ptr<float>(),
        N, K, M
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused linear, Swish activation, and scaling using optimized tiled GEMM");
}
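A hypothetical way to build and sanity-check the extension against the PyTorch reference (the file name fused_linear_swish.cu and the tensor sizes are assumptions for illustration):

import torch
from torch.utils.cpp_extension import load

ext = load(name="fused_linear_swish", sources=["fused_linear_swish.cu"])

x = torch.randn(128, 512, device="cuda")
w = torch.randn(256, 512, device="cuda")
b = torch.randn(256, device="cuda")

out = ext.forward(x, w, b, 2.0)
ref = 2.0 * torch.nn.functional.silu(torch.nn.functional.linear(x, w, b))
print(torch.allclose(out, ref, atol=1e-4))  # loose tolerance: the kernel uses __expf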