2.59. Matmul with Swish Activation and Scaling

Performs a matrix multiplication on the input, applies an element-wise Swish activation (each element is multiplied by its own sigmoid), and scales the result by a constant factor.
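
For reference, the unfused computation can be expressed directly in PyTorch. The sketch below assumes float32 tensors and a weight of shape [out_features, in_features], matching the kernel that follows; it only illustrates the computation being fused.

import torch
import torch.nn.functional as F

def reference_forward(x, weight, bias, scaling_factor):
    # Linear layer: y = x @ weight.T + bias
    y = F.linear(x, weight, bias)
    # Swish: multiply each element by its own sigmoid
    y = y * torch.sigmoid(y)
    # Scale by a constant factor
    return scaling_factor * y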

Progress: Step 6 of 6
Reward: 1.41x

Optimization Timeline

Step 6 Details

Reasoning

Our previous tiled GEMM fused kernel achieved a runtime of 0.163 s versus a native runtime of 0.148 s (reward ≈ 0.91). To exceed native performance (reward > 1), we modify the tiling strategy to increase occupancy and reduce per-thread overhead. Specifically, we increase the tile size along the batch (N) dimension from TILE_M = 16 to TILE_M = 32, so each block processes more rows (samples) and fewer blocks are needed along the N dimension, reducing launch overhead and making better use of shared memory. In addition, we use __ldg for bias loads to benefit from the read-only cache, and we add an unrolling hint to the inner loop over the K dimension to minimize loop overhead. These are standard micro-optimizations and should yield a genuine speedup.

--- FIX ---

The output mismatch was most likely caused by how the weight-matrix tile was loaded. We compute a GEMM for Y = X * W^T, where X is [N, K] and W is [M, K], so Y[i, j] = dot(X[i, :], W[j, :]). To tile this GEMM we load a tile of X and a tile of B, where B = W^T. The previous kernel mistakenly loaded W's tile into shared memory with swapped dimensions. The fix is to declare the shared memory for the weight tile as a TILE_K × TILE_N array (instead of TILE_N × TILE_K) and have each thread load B[t][col] = W[col][t] into shared memory (using threadIdx.y for the K dimension), so that the multiplication becomes sum += sX[threadIdx.y][t] * sB[t][threadIdx.x];. This applies the transposition the GEMM requires. This revision also handles bounds correctly when loading the weight tile, ensuring that only threads with threadIdx.y < TILE_K participate. Finally, we keep the __ldg bias load and the existing tiling and unrolling strategy. This change should fix the output mismatch. The updated CUDA kernel code follows.
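
The corrected tile orientation can be sanity-checked on the host. The following NumPy sketch mirrors the indexing described above (sX as a TILE_M × TILE_K tile of X, sB as a TILE_K × TILE_N tile of B = W^T with sB[t][j] = W[col_j][t]); the function name and zero-padding scheme are illustrative only, not part of the kernel.

import numpy as np

TILE_M, TILE_N, TILE_K = 32, 16, 16

def tiled_gemm_xwt(X, W):
    # Computes Y = X @ W.T with the same tile orientation as the kernel.
    N, K = X.shape
    M = W.shape[0]
    Y = np.zeros((N, M), dtype=X.dtype)
    for i0 in range(0, N, TILE_M):
        for j0 in range(0, M, TILE_N):
            acc = np.zeros((TILE_M, TILE_N), dtype=X.dtype)
            for k0 in range(0, K, TILE_K):
                sX = np.zeros((TILE_M, TILE_K), dtype=X.dtype)
                sB = np.zeros((TILE_K, TILE_N), dtype=X.dtype)
                xi = X[i0:i0+TILE_M, k0:k0+TILE_K]
                sX[:xi.shape[0], :xi.shape[1]] = xi
                # sB[t][j] = W[j0+j][k0+t], i.e. a transposed tile of W
                wj = W[j0:j0+TILE_N, k0:k0+TILE_K].T
                sB[:wj.shape[0], :wj.shape[1]] = wj
                acc += sX @ sB  # sum over t of sX[y][t] * sB[t][x]
            yi = Y[i0:i0+TILE_M, j0:j0+TILE_N]
            yi[:] = acc[:yi.shape[0], :yi.shape[1]]
    return Y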

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>

// New tile dimensions: Increase TILE_M to 32 for better occupancy.
#define TILE_M 32
#define TILE_N 16
#define TILE_K 16

// This kernel performs a tiled GEMM to compute:
//  Y = X * W^T, where X is [N, K] and W is [M, K].
// Here, B = W^T has shape [K, M].
// After accumulating the dot product, bias is added (using __ldg),
// then the Swish activation is applied: out = scaling_factor * (a * sigmoid(a)),
// where a = sum + bias.
// Grid: gridDim.x = ceil(N / TILE_M), gridDim.y = ceil(M / TILE_N) 
// Block: blockDim.x = TILE_N, blockDim.y = TILE_M.
extern "C" {

__global__ void fused_linear_swish_gemm_kernel(const float* __restrict__ X,
                                               const float* __restrict__ W,
                                               const float* __restrict__ bias,
                                               float scaling_factor,
                                               float* __restrict__ Y,
                                               int N,    // rows in X (batch size)
                                               int K,    // in_features
                                               int M)    // out_features (rows in W)
{
    // Compute global row index in X and Y.
    int row = blockIdx.x * TILE_M + threadIdx.y;
    // Compute global column index in Y (and corresponding row of W).
    int col = blockIdx.y * TILE_N + threadIdx.x;

    float sum = 0.0f;

    // Number of tiles along the K dimension.
    int num_tiles = (K + TILE_K - 1) / TILE_K;
    
    // Shared memory for tiles of X and B (B = W^T).
    __shared__ float sX[TILE_M][TILE_K];
    // Declare sB with dimensions [TILE_K][TILE_N], where sB will hold a tile of B.
    __shared__ float sB[TILE_K][TILE_N];
    
    for (int tile = 0; tile < num_tiles; tile++) {
        // Load tile from X into shared memory.
        int k_index = tile * TILE_K + threadIdx.x;
        if (row < N && k_index < K)
            sX[threadIdx.y][threadIdx.x] = X[row * K + k_index];
        else
            sX[threadIdx.y][threadIdx.x] = 0.0f;
        
        // Load tile from B = W^T into shared memory.
        // B[t][col] = W[col][t]. We want sB[t][threadIdx.x] = B[t][col].
        // Only first TILE_K threads along y participate.
        if (threadIdx.y < TILE_K) {
            int k_index_b = tile * TILE_K + threadIdx.y;
            if (col < M && k_index_b < K)
                sB[threadIdx.y][threadIdx.x] = W[col * K + k_index_b];
            else
                sB[threadIdx.y][threadIdx.x] = 0.0f;
        }
        
        __syncthreads();
        
        // Multiply the two tiles together.
        #pragma unroll
        for (int t = 0; t < TILE_K; t++) {
            sum += sX[threadIdx.y][t] * sB[t][threadIdx.x];
        }
        __syncthreads();
    }
    // Write result if indices are in bounds.
    if (row < N && col < M) {
        // Use __ldg to load bias from read-only cache.
        float a = sum + __ldg(&bias[col]);
        float sig = 1.0f / (1.0f + __expf(-a));
        float swish = a * sig;
        Y[row * M + col] = scaling_factor * swish;
    }
}

} // extern "C"

torch::Tensor forward(torch::Tensor input,
                      torch::Tensor weight,
                      torch::Tensor bias,
                      float scaling_factor) {
    c10::cuda::CUDAGuard device_guard(input.device());
    
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && bias.is_cuda(), "All tensors must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32");
    TORCH_CHECK(input.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(), "All tensors must be contiguous");
    TORCH_CHECK(input.dim() == 2, "Input must be a 2D tensor");
    TORCH_CHECK(weight.dim() == 2, "Weight must be a 2D tensor");
    TORCH_CHECK(bias.dim() == 1, "Bias must be a 1D tensor");
    
    int N = input.size(0);    // Number of rows in X (batch size)
    int K = input.size(1);    // in_features
    int M = weight.size(0);   // out_features
    
    TORCH_CHECK(weight.size(1) == K, "Weight's second dimension must equal in_features");
    TORCH_CHECK(bias.size(0) == M, "Bias vector size must equal out_features");

    auto output = torch::empty({N, M}, input.options());
    
    // Define grid and block dimensions.
    int grid_rows = (N + TILE_M - 1) / TILE_M;
    int grid_cols = (M + TILE_N - 1) / TILE_N;
    dim3 grid_dim(grid_rows, grid_cols);
    dim3 block_dim(TILE_N, TILE_M);
    
    // The tiles sX and sB are statically allocated inside the kernel,
    // so no dynamic shared memory is requested at launch.
    fused_linear_swish_gemm_kernel<<<grid_dim, block_dim, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        scaling_factor,
        output.data_ptr<float>(),
        N, K, M
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused linear, Swish activation, and scaling using optimized tiled GEMM");
}