2.12. Gemm Multiply LeakyReLU Operation

Performs a matrix multiplication (with a bias add) on the input data, scales the result by a constant multiplier, and then applies the LeakyReLU activation with a configurable negative slope.

Progress

Step 4 of 4
Reward
1.46x

Optimization Timeline

Step 4 Details

Reasoning

We can further optimize our tiled kernel by using the read-only `__ldg` intrinsic when loading values from global memory into shared memory. This routes those loads through the read-only data cache, which can reduce load latency. We simply add `__ldg()` calls where the kernel fills shared memory; the rest of the tiled kernel remains unchanged. The Python module also remains the same; it just loads the updated CUDA extension. The change is intended to reduce memory latency and increase throughput without altering the algorithm. However, because the input pointers are already declared `const __restrict__`, the compiler may already emit read-only loads on its own, so the actual benefit should be confirmed by profiling.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>

// Edge length of the square output tile handled by one thread block; the host
// launcher uses a TILE_SIZE x TILE_SIZE thread block. Can be overridden at
// compile time (e.g. -DTILE_SIZE=32) but must keep TILE_SIZE * TILE_SIZE within
// CUDA's 1024-threads-per-block limit.
#ifndef TILE_SIZE
#define TILE_SIZE 16
#endif
static_assert(TILE_SIZE > 0 && TILE_SIZE * TILE_SIZE <= 1024,
              "TILE_SIZE * TILE_SIZE must be a valid CUDA block size (<= 1024 threads)");

// Fused kernel: tiled GEMM + bias + scalar multiplier + LeakyReLU.
//
// Computes y[b, j] = LeakyReLU((sum_k x[b, k] * W[j, k] + bias[j]) * multiplier)
// where LeakyReLU(v) = v >= 0 ? v : v * negative_slope.
//   x:    [batch_size, in_features]    dense row-major
//   W:    [out_features, in_features]  dense row-major (nn.Linear weight layout)
//   bias: [out_features]
//   y:    [batch_size, out_features]   dense row-major
//
// Launch contract:
//   blockDim = (TILE_SIZE, TILE_SIZE)
//   gridDim  = (ceil(out_features / TILE_SIZE), ceil(batch_size / TILE_SIZE))
//   Static shared memory: TILE_SIZE*TILE_SIZE + TILE_SIZE*(TILE_SIZE+1) floats.
// Threads outside the output range still take part in the tile loads (storing
// zeros) so that every thread reaches both __syncthreads() barriers.
// Read-only inputs are loaded with __ldg (requires SM35+).
__global__ void fused_forward_tiled_kernel(const float* __restrict__ x,
                                           const float* __restrict__ W,
                                           const float* __restrict__ bias,
                                           float multiplier,
                                           float negative_slope,
                                           float* __restrict__ y,
                                           int batch_size,
                                           int in_features,
                                           int out_features) {
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int row = blockIdx.y * TILE_SIZE + ty;  // batch index of this thread's output
    const int col = blockIdx.x * TILE_SIZE + tx;  // output-feature index of this thread's output

    // Output feature whose W row this thread helps load. Indexing W rows by ty
    // (not tx) lets consecutive threadIdx.x read consecutive in_features
    // elements, so the global loads of W are coalesced.
    const int w_row = blockIdx.x * TILE_SIZE + ty;

    __shared__ float shared_x[TILE_SIZE][TILE_SIZE];
    // The W tile is stored transposed; the +1 padding reduces bank conflicts on
    // the column-wise shared-memory writes.
    __shared__ float shared_W[TILE_SIZE][TILE_SIZE + 1];

    // 64-bit row offsets: row * in_features can exceed INT_MAX for large tensors.
    const int64_t x_row_offset = static_cast<int64_t>(row) * in_features;
    const int64_t w_row_offset = static_cast<int64_t>(w_row) * in_features;

    float sum = 0.0f;
    const int num_tiles = (in_features + TILE_SIZE - 1) / TILE_SIZE;
    for (int t = 0; t < num_tiles; ++t) {
        const int k = t * TILE_SIZE + tx;  // in_features index loaded by this thread

        // shared_x[ty][tx] = x[row, k], zero-padded outside the matrix.
        shared_x[ty][tx] = (row < batch_size && k < in_features)
                               ? __ldg(&x[x_row_offset + k])
                               : 0.0f;

        // shared_W[tx][ty] = W[w_row, k], zero-padded outside the matrix.
        // After the barrier, shared_W[i][tx] == W[col, t * TILE_SIZE + i].
        shared_W[tx][ty] = (w_row < out_features && k < in_features)
                               ? __ldg(&W[w_row_offset + k])
                               : 0.0f;

        __syncthreads();

        #pragma unroll
        for (int i = 0; i < TILE_SIZE; ++i) {
            sum += shared_x[ty][i] * shared_W[i][tx];
        }

        // Keep the next iteration from overwriting tiles still being read.
        __syncthreads();
    }

    if (row < batch_size && col < out_features) {
        sum += __ldg(&bias[col]);
        sum *= multiplier;
        const float result = (sum >= 0.0f) ? sum : sum * negative_slope;
        y[static_cast<int64_t>(row) * out_features + col] = result;
    }
}

// Host launcher for fused_forward_tiled_kernel.
//
// Returns LeakyReLU((x @ W^T + bias) * multiplier) with the given negative slope.
//   x:    [batch_size, in_features]    float32, CUDA
//   W:    [out_features, in_features]  float32, same device as x
//   bias: [out_features]               float32, same device as x
// Returns a new [batch_size, out_features] float32 tensor on x's device.
// Non-contiguous inputs are made contiguous first, because the kernel indexes
// dense row-major storage. The kernel is enqueued on the current CUDA stream,
// and this call does not synchronize.
// Throws c10::Error (RuntimeError in Python) on invalid inputs or launch failure.
torch::Tensor fused_forward_tiled(torch::Tensor x,
                                  torch::Tensor W,
                                  torch::Tensor bias,
                                  float multiplier,
                                  float negative_slope) {
    // The kernel reads raw float pointers, so a wrong device, dtype, rank or
    // shape would otherwise silently produce garbage or an illegal access.
    TORCH_CHECK(x.is_cuda(), "fused_forward_tiled: x must be a CUDA tensor");
    TORCH_CHECK(W.device() == x.device() && bias.device() == x.device(),
                "fused_forward_tiled: x, W and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                    W.scalar_type() == torch::kFloat32 &&
                    bias.scalar_type() == torch::kFloat32,
                "fused_forward_tiled: x, W and bias must be float32");
    TORCH_CHECK(x.dim() == 2, "fused_forward_tiled: x must be 2-D [batch_size, in_features]");
    TORCH_CHECK(W.dim() == 2, "fused_forward_tiled: W must be 2-D [out_features, in_features]");
    TORCH_CHECK(bias.dim() == 1, "fused_forward_tiled: bias must be 1-D [out_features]");
    TORCH_CHECK(W.size(1) == x.size(1),
                "fused_forward_tiled: W.size(1) (", W.size(1),
                ") must equal x.size(1) (", x.size(1), ")");
    TORCH_CHECK(bias.size(0) == W.size(0),
                "fused_forward_tiled: bias.size(0) (", bias.size(0),
                ") must equal W.size(0) (", W.size(0), ")");

    c10::cuda::CUDAGuard device_guard(x.device());

    // No-op for tensors that are already contiguous; fixes strided views.
    auto x_c = x.contiguous();
    auto W_c = W.contiguous();
    auto bias_c = bias.contiguous();

    const int64_t batch_size = x_c.size(0);
    const int64_t in_features = x_c.size(1);
    const int64_t out_features = W_c.size(0);

    // The kernel takes int sizes; refuse shapes that would be truncated.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(batch_size <= kIntMax && in_features <= kIntMax && out_features <= kIntMax,
                "fused_forward_tiled: tensor dimensions must fit in a 32-bit int");

    auto y = torch::empty({batch_size, out_features}, x_c.options());

    // A grid with a zero dimension is an invalid launch configuration, and
    // there is nothing to compute anyway.
    if (y.numel() == 0) {
        return y;
    }

    const dim3 block(TILE_SIZE, TILE_SIZE);
    const dim3 grid(static_cast<unsigned int>((out_features + TILE_SIZE - 1) / TILE_SIZE),
                    static_cast<unsigned int>((batch_size + TILE_SIZE - 1) / TILE_SIZE));
    // Batch rows are mapped to gridDim.y, which CUDA caps at 65535 blocks.
    TORCH_CHECK(grid.y <= 65535u,
                "fused_forward_tiled: batch_size too large for this kernel (max ",
                65535 * TILE_SIZE, ")");

    fused_forward_tiled_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x_c.data_ptr<float>(),
        W_c.data_ptr<float>(),
        bias_c.data_ptr<float>(),
        multiplier,
        negative_slope,
        y.data_ptr<float>(),
        static_cast<int>(batch_size),
        static_cast<int>(in_features),
        static_cast<int>(out_features));
    // Surface launch-configuration errors here instead of at some later CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return y;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Exposes the host launcher as the module's single Python entry point,
    // called positionally as (x, W, bias, multiplier, negative_slope).
    static const char kFusedForwardTiledDoc[] =
        "Fused GEMM, multiplier, and LeakyReLU using tiled shared memory and __ldg";
    m.def("fused_forward_tiled",
          &fused_forward_tiled,
          kFusedForwardTiledDoc);
}