2.98. Matmul AvgPool GELU Scale Max

Calculates a linear transformation on the input tensor, then downsamples the resulting features via average pooling. Afterwards, it applies the GELU activation using a tanh-based approximation, scales the activated values by a constant factor, and finally computes the maximum value across the feature dimension.

Progress

Step 2 of 2
Reward
2.97x

Optimization Timeline

Step 2 Details

Reasoning

The current fused kernel computing "Matmul_AvgPool_GELU_Scale_Max" runs in 0.277 seconds, which is slower than the reference 0.228 seconds. One likely bottleneck is that each thread re-loads the entire input row from global memory for every dot product. Since every thread in the block accesses the same input vector (of length in_features), we can load the entire x_row into shared memory once and then have every thread use that shared copy. Moreover, we can vectorize the dot product computation when in_features is divisible by 4 (and the weight rows are 16-byte aligned, as float4 loads require), reducing loop iterations by a factor of 4. These improvements minimize global memory transactions and improve memory bandwidth usage. Plan: 1. At the beginning of the kernel, load the input sample's x_row into shared memory using a loop that covers all in_features. 2. Use the shared copy (x_sh) for the dot product computation, with vectorized float4 loads when possible. 3. Store each thread's dot product result in another part of shared memory. 4. Continue with average pooling, activation, scaling, and max reduction as before. 5. Update the required dynamic shared memory size to (in_features + out_features) floats, i.e. (in_features + out_features) * sizeof(float) bytes. These changes should speed up the kernel by reducing repeated global memory loads and vectorizing inner loops. Let's now provide the updated Python module and CUDA kernel accordingly.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>

// Fused kernel that performs the following for each sample:
// 1. Computes V = matmul(x, W^T) + bias for each output feature.
// 2. Averages V over non-overlapping groups of size pool_kernel_size. There are
//    floor(out_features / pool_kernel_size) groups; a trailing partial group is
//    dropped, matching AvgPool1d(kernel_size=stride=pool_kernel_size).
// 3. Applies the tanh-approximation GELU to each pooled value and scales it.
// 4. Outputs the maximum among the processed pooled values.
//
// Launch configuration:
// - grid:  one block per sample (blockIdx.x selects the row of x / output).
// - block: any number of threads. Every phase uses block-stride loops, so
//          blockDim.x does not have to cover out_features.
// - dynamic shared memory: (in_features + out_features) * sizeof(float) bytes.
//
// Preconditions (not checked here): x, weight and bias are contiguous row-major
// float32 buffers with the shapes noted next to the parameters.
// If pool_kernel_size < 1 or out_features < pool_kernel_size (no full group),
// the sample's output is set to NaN instead of reading undefined data.
//
// Optimizations:
// - The input row (x_row) is staged in shared memory once for all threads.
// - The dot product uses float4 loads when in_features is a multiple of 4 AND
//   the weight pointer is 16-byte aligned (a requirement of float4 access).
extern "C" __global__ void fused_kernel(const float* __restrict__ x,      // [batch, in_features]
                                        const float* __restrict__ weight, // [out_features, in_features]
                                        const float* __restrict__ bias,   // [out_features]
                                        float* __restrict__ output,       // [batch]
                                        int in_features,
                                        int out_features,
                                        int pool_kernel_size,
                                        float scale_factor) {
    const int sample = blockIdx.x;
    const int tid = threadIdx.x;
    const int nthreads = blockDim.x;

    // Both operands are kernel arguments, so this branch is uniform across the
    // block and the early return cannot strand any thread at a __syncthreads().
    if (pool_kernel_size < 1 || out_features < pool_kernel_size) {
        if (tid == 0) {
            output[sample] = nanf("");
        }
        return;
    }
    const int num_groups = out_features / pool_kernel_size;

    // size_t arithmetic so that sample * in_features cannot overflow int.
    const float* x_row = x + static_cast<size_t>(sample) * in_features;

    // Layout: in_features floats holding a copy of x_row, followed by
    // out_features floats holding the per-feature dot products.
    extern __shared__ float sdata[];
    float* x_sh = sdata;                 // size: in_features
    float* dot_sh = sdata + in_features; // size: out_features

    // Stage the input row in shared memory.
    for (int i = tid; i < in_features; i += nthreads) {
        x_sh[i] = x_row[i];
    }
    __syncthreads();

    // Phase 1: linear transformation, block-stride over output features.
    // With a 16-byte-aligned base and in_features % 4 == 0, every weight row
    // starts on a 16-byte boundary, so float4 loads are legal.
    const bool use_vec4 = ((in_features & 3) == 0) &&
                          ((reinterpret_cast<size_t>(weight) & 15) == 0);
    for (int o = tid; o < out_features; o += nthreads) {
        const float* weight_row = weight + static_cast<size_t>(o) * in_features;
        float sum = 0.0f;
        if (use_vec4) {
            const int vec_count = in_features / 4;
            const float4* x_vec = reinterpret_cast<const float4*>(x_sh);
            const float4* w_vec = reinterpret_cast<const float4*>(weight_row);
            for (int k = 0; k < vec_count; k++) {
                const float4 x_val = x_vec[k];
                const float4 w_val = w_vec[k];
                sum += x_val.x * w_val.x + x_val.y * w_val.y + x_val.z * w_val.z + x_val.w * w_val.w;
            }
        } else {
            for (int k = 0; k < in_features; k++) {
                sum += x_sh[k] * weight_row[k];
            }
        }
        sum += bias[o];
        dot_sh[o] = sum;
    }
    __syncthreads();

    // Phases 2 + 3: average pooling, GELU, scaling.
    // The result for group g is stored back into dot_sh[g]. That slot may
    // still be an input of group g / pool_kernel_size (e.g. dot_sh[1] feeds
    // group 0), so writing it before all reads finish is a data race.
    // Groups are therefore handled in chunks of nthreads:
    //   1. every thread reads its inputs;
    //   2. the block synchronizes;
    //   3. only then are results written.
    // Chunk writes land in [base, base + nthreads). Later chunks read only
    // indices >= (base + nthreads) * pool_kernel_size, so the two ranges never
    // overlap. The loop bound is uniform, so the barrier is reached by every thread.
    for (int base = 0; base < num_groups; base += nthreads) {
        const int g = base + tid;
        float transformed = 0.0f;
        if (g < num_groups) {
            const int start = g * pool_kernel_size;
            float group_sum = 0.0f;
            for (int j = 0; j < pool_kernel_size; j++) {
                group_sum += dot_sh[start + j];
            }
            const float a = group_sum / pool_kernel_size;
            const float a_cubed = a * a * a;
            // 0.7978845608 = sqrt(2 / pi) in the tanh GELU approximation.
            const float gelu_val = 0.5f * a * (1.0f + tanhf(0.7978845608f * (a + 0.044715f * a_cubed)));
            transformed = gelu_val * scale_factor;
        }
        __syncthreads();  // all reads of this chunk complete before any overwrite
        if (g < num_groups) {
            dot_sh[g] = transformed;
        }
    }
    __syncthreads();

    // Phase 4: maximum over the transformed pooled values (num_groups >= 1 here).
    if (tid == 0) {
        float max_val = dot_sh[0];
        for (int i = 1; i < num_groups; i++) {
            if (dot_sh[i] > max_val) {
                max_val = dot_sh[i];
            }
        }
        output[sample] = max_val;
    }
}

// Host entry point: validates the inputs and launches fused_kernel with one
// block per sample and one thread per output feature.
//
// Args:
//   x:                (batch, in_features) float32 CUDA tensor.
//   weight:           (out_features, in_features) float32 CUDA tensor.
//   bias:             (out_features,) float32 CUDA tensor.
//   in_features:      must equal x.size(1) and weight.size(1).
//   out_features:     must equal weight.size(0) and bias.numel(); at most 1024,
//                     because the block has one thread per output feature and
//                     CUDA limits a block to 1024 threads.
//   pool_kernel_size: pooling window and stride; must be in [1, out_features].
//   scale_factor:     multiplier applied after GELU.
// The feature counts are checked against the tensor shapes rather than trusted,
// because the kernel indexes raw pointers with them.
//
// Returns: a (batch,) float32 tensor on x's device.
// Raises:  c10::Error (via TORCH_CHECK / C10_CUDA_CHECK) on invalid inputs or
//          on a failed launch.
torch::Tensor forward(torch::Tensor x,
                      torch::Tensor weight,
                      torch::Tensor bias,
                      int in_features,
                      int out_features,
                      int pool_kernel_size,
                      float scale_factor) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "forward: x, weight and bias must be CUDA tensors");
    TORCH_CHECK(x.device() == weight.device() && x.device() == bias.device(),
                "forward: x, weight and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat &&
                    weight.scalar_type() == torch::kFloat &&
                    bias.scalar_type() == torch::kFloat,
                "forward: x, weight and bias must be float32");
    TORCH_CHECK(x.dim() == 2 && x.size(1) == in_features,
                "forward: x must have shape (batch, in_features)");
    TORCH_CHECK(weight.dim() == 2 && weight.size(0) == out_features &&
                    weight.size(1) == in_features,
                "forward: weight must have shape (out_features, in_features)");
    TORCH_CHECK(bias.numel() == out_features,
                "forward: bias must have out_features elements");
    TORCH_CHECK(pool_kernel_size >= 1 && pool_kernel_size <= out_features,
                "forward: pool_kernel_size must be in [1, out_features]");
    TORCH_CHECK(out_features <= 1024,
                "forward: out_features must be <= 1024 (one thread per output feature)");

    // Use CUDAGuard to ensure correct device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes dense row-major buffers, so make the inputs contiguous
    // (this is a no-op for tensors that already are).
    auto x_c = x.contiguous();
    auto w_c = weight.contiguous();
    auto b_c = bias.contiguous();

    const int64_t batch_size = x_c.size(0);
    auto output = torch::empty({batch_size}, x_c.options());
    if (batch_size == 0) {
        return output;  // a launch with zero blocks is an invalid configuration
    }
    TORCH_CHECK(batch_size <= 2147483647,
                "forward: batch size exceeds the gridDim.x limit (2^31 - 1)");

    const int threads = out_features;  // >= 1, guaranteed by the checks above
    const int blocks = static_cast<int>(batch_size);

    // Shared memory: (in_features + out_features) floats.
    const size_t shared_mem_size =
        (static_cast<size_t>(in_features) + static_cast<size_t>(out_features)) * sizeof(float);

    // Dynamic shared memory above 48 KiB must be requested explicitly; this
    // call reports an error if the device cannot provide that much.
    if (shared_mem_size > 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(fused_kernel,
                                            cudaFuncAttributeMaxDynamicSharedMemorySize,
                                            static_cast<int>(shared_mem_size)));
    }

    fused_kernel<<<blocks, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x_c.data_ptr<float>(),
        w_c.data_ptr<float>(),
        b_c.data_ptr<float>(),
        output.data_ptr<float>(),
        in_features,
        out_features,
        pool_kernel_size,
        scale_factor
    );
    // Kernel launches do not return errors directly; surface configuration errors here.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    // Register the host entry point under the Python name `forward`.
    module.def("forward",
               forward,
               "Fused Matmul, AvgPool, GELU, Scale and Max (CUDA)");
}