2.39. Gemm Scale BatchNorm Operation

Calculates a matrix multiplication on the input, scales the result element-wise with a learnable parameter, and applies one-dimensional batch normalization to the output.
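
For reference, the unfused operation corresponds to a small PyTorch module along the lines of the sketch below; the class and attribute names (GemmScaleBN, scale) are illustrative assumptions rather than the benchmark's original code.

import torch
import torch.nn as nn

class GemmScaleBN(nn.Module):
    """Reference (unfused) version: linear layer, per-feature scale, then BatchNorm1d."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = nn.Parameter(torch.ones(out_features))
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, x):
        y = self.linear(x)   # GEMM + bias
        y = y * self.scale   # element-wise learnable scaling
        return self.bn(y)    # one-dimensional batch normalization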

Progress: Step 2 of 2
Reward: 1.00x

Optimization Timeline

Step 2 Details

Reasoning

We observed that our fused kernel ran in nearly the same time as the native implementation, which suggests there is room for further optimization. One way to boost performance is to vectorize the dot-product computation when the input feature dimension (I) is divisible by 4. This cuts the number of inner-loop iterations by a factor of four and improves effective memory throughput by letting each thread issue 128-bit (float4) loads instead of scalar ones. In the updated kernel we add a parameter "use_vectorized" that is set to 1 if I % 4 == 0 and 0 otherwise. When it is enabled, the kernel reinterprets the pointers to x_row and W_row as float4 pointers and runs a vectorized inner loop (with #pragma unroll as an unrolling hint). These changes help by reducing the per-thread instruction count and making better use of the available memory bandwidth. The host function computes use_vectorized and passes it to the kernel. The Python module remains largely unchanged except that it calls the updated extension. This approach is valid as long as I (the number of input features) is divisible by 4; otherwise the kernel falls back to the scalar code path.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Fused kernel to perform:
// 1. GEMM:    acc = sum_{k=0}^{I-1} x[i,k] * W[j,k] + linear_bias[j]
// 2. Scaling: y = acc * scale[j]
// 3. BN (inference): output = bn_weight[j] * (y - bn_running_mean[j]) / sqrt(bn_running_var[j] + eps) + bn_bias[j]
__global__ void fused_gemm_scale_bn_kernel(
    const float* __restrict__ x,           // (B, I)
    const float* __restrict__ W,           // (O, I) row-major
    const float* __restrict__ linear_bias, // (O,)
    const float* __restrict__ scale,       // (O,)
    const float* __restrict__ bn_running_mean, // (O,)
    const float* __restrict__ bn_running_var,  // (O,)
    const float* __restrict__ bn_weight,       // (O,)
    const float* __restrict__ bn_bias,         // (O,)
    float eps,
    int B, int I, int O,
    int use_vectorized,                    // 1 if I % 4 == 0, else 0
    float* __restrict__ output             // (B, O)
) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = B * O;
    if (index < total) {
        int i = index / O;  // batch index
        int j = index % O;  // output feature index
        float acc = 0.0f;
        const float* x_row = x + i * I;
        const float* W_row = W + j * I;
        if (use_vectorized) {
            int vecCount = I / 4;
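            // Reinterpreting the rows as float4 requires 16-byte alignment; this holds
            // because I % 4 == 0 and both tensors are contiguous, so each row starts a
            // multiple of 16 bytes from the (aligned) base pointer.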
            const float4* x_row_vec = reinterpret_cast<const float4*>(x_row);
            const float4* W_row_vec = reinterpret_cast<const float4*>(W_row);
            float sum = 0.0f;
            #pragma unroll
            for (int k = 0; k < vecCount; k++) {
                float4 xv = x_row_vec[k];
                float4 wv = W_row_vec[k];
                sum += xv.x * wv.x + xv.y * wv.y + xv.z * wv.z + xv.w * wv.w;
            }
            acc = sum;
        } else {
            for (int k = 0; k < I; k++) {
                acc += x_row[k] * W_row[k];
            }
        }
        acc += linear_bias[j];
        float y_val = acc * scale[j];
        float denom = sqrtf(bn_running_var[j] + eps);
        float normalized = (y_val - bn_running_mean[j]) / denom;
        float out_val = bn_weight[j] * normalized + bn_bias[j];
        output[index] = out_val;
    }
}

torch::Tensor forward(torch::Tensor x,
                      torch::Tensor W,
                      torch::Tensor linear_bias,
                      torch::Tensor scale,
                      torch::Tensor bn_running_mean,
                      torch::Tensor bn_running_var,
                      torch::Tensor bn_weight,
                      torch::Tensor bn_bias,
                      float eps,
                      int B, int I, int O) {
    c10::cuda::CUDAGuard device_guard(x.device());
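    // Note: the kernel assumes contiguous float32 CUDA tensors; no layout or dtype checks are performed here.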
    
    auto output = torch::empty({B, O}, x.options());
    
    int total = B * O;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    int use_vectorized = (I % 4 == 0) ? 1 : 0;
    
    fused_gemm_scale_bn_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        W.data_ptr<float>(),
        linear_bias.data_ptr<float>(),
        scale.data_ptr<float>(),
        bn_running_mean.data_ptr<float>(),
        bn_running_var.data_ptr<float>(),
        bn_weight.data_ptr<float>(),
        bn_bias.data_ptr<float>(),
        eps,
        B, I, O,
        use_vectorized,
        output.data_ptr<float>()
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused GEMM, scaling and BatchNorm (inference) CUDA kernel with vectorized dot product");
}
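
As noted in the reasoning above, the Python module is not reproduced in this step. A minimal sketch of how such a module might load and call this extension is shown below; the extension name, source file name, and module/parameter names (fused_gemm_scale_bn, FusedGemmScaleBN, scale) are illustrative assumptions, and the fused path only covers inference-mode batch normalization using the running statistics.

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Hypothetical JIT build; the actual extension name and source path may differ.
fused = load(name="fused_gemm_scale_bn", sources=["fused_gemm_scale_bn.cu"])

class FusedGemmScaleBN(nn.Module):
    def __init__(self, in_features, out_features, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.scale = nn.Parameter(torch.ones(out_features))
        self.bn = nn.BatchNorm1d(out_features, eps=eps)

    def forward(self, x):
        B, I = x.shape
        O = self.weight.shape[0]
        # The fused kernel implements inference-mode BN with the running statistics.
        return fused.forward(x.contiguous(), self.weight.contiguous(), self.bias,
                             self.scale, self.bn.running_mean, self.bn.running_var,
                             self.bn.weight, self.bn.bias, self.bn.eps, B, I, O)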