2.51. Gemm-Subtract-GlobalAvgPool-LogSumExp-GELU-ResidualAdd

Computes the output by first transforming the input tensor with a general matrix multiplication (GEMM) and then subtracting a learnable parameter. It then aggregates the features with global average pooling, applies a log-sum-exp reduction and the Gaussian Error Linear Unit (GELU) activation, and finally adds the original input back as a residual connection.
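To check the fused kernel, it helps to have the unfused pipeline written out. The following is a minimal CPU sketch for a single sample, for illustration only. `reference_forward` is a hypothetical name. `weight` is assumed to be row-major with shape [out_features][in_features], which is the layout the CUDA code assumes in `forward`. GELU uses the exact erf form rather than the kernel's tanh approximation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical unfused CPU reference for ONE sample, for illustration only.
// Assumes weight is row-major with shape [out_features][in_features].
std::vector<float> reference_forward(const std::vector<float>& x,         // [F]
                                     const std::vector<float>& weight,    // [D * F]
                                     const std::vector<float>& bias,      // [D]
                                     const std::vector<float>& subtract)  // [D]
{
    const std::size_t F = x.size();
    const std::size_t D = bias.size();
    assert(D > 0 && weight.size() == D * F && subtract.size() == D);

    // GEMM + subtract, accumulated directly into the global average pool.
    double pooled = 0.0;
    for (std::size_t j = 0; j < D; ++j) {
        double y = bias[j];
        for (std::size_t i = 0; i < F; ++i) {
            y += static_cast<double>(weight[j * F + i]) * x[i];
        }
        pooled += y - subtract[j];
    }
    pooled /= static_cast<double>(D);

    // Log-sum-exp over a single element, in the usual max-shifted form;
    // with one element it returns pooled unchanged.
    const double m = pooled;
    const double lse = m + std::log(std::exp(pooled - m));

    // Exact (erf-based) GELU.
    const double g = 0.5 * lse * (1.0 + std::erf(lse / std::sqrt(2.0)));

    // Residual add: broadcast the scalar over the original input.
    std::vector<float> out(F);
    for (std::size_t i = 0; i < F; ++i) {
        out[i] = x[i] + static_cast<float>(g);
    }
    return out;
}
```

You could validate the extension by comparing its output with this reference on a few samples. The tolerance should be loose enough to absorb the difference between the exact and tanh-approximated GELU.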

Progress: Step 2 of 2
Reward: 3.21x

Step 2 Details

Reasoning

We can further improve the fused kernel by optimizing its reduction step. The previous kernel reduced partial sums with a shared-memory loop. This version uses warp-level shuffle primitives (__shfl_down_sync) instead, which cuts shared-memory traffic and synchronization.

The fusion rests on a simple identity. For one sample x, the GEMM followed by the subtraction produces y_j = x · W_j + b_j - s_j for each of the D = out_features outputs, where W_j is row j of weight. Averaging over j gives the pooled value p = (x · S + C) / D, with S = sum(weight, dim=0) (a vector of length in_features) and C = sum(bias - subtract) (a scalar). S and C depend only on the parameters, so the host computes them once per call. Pooling leaves a single value per sample, so the log-sum-exp reduces over one element. It is therefore the identity (log(exp(p)) = p), and the kernel skips it.

Each block processes one sample. Every thread computes a partial dot product of the sample with S, and each warp reduces its partial sums with __shfl_down_sync. Lane 0 of each warp then writes the warp's sum to shared memory. A second warp-level reduction across the warp leaders is not used. Instead, thread 0 adds up the per-warp sums in a simple loop, computes p = (total + C) / D, applies GELU, and stores the result in shared memory. After a barrier, every thread reads that value and performs the residual add, out[i] = x[i] + GELU(p), broadcasting the scalar over the sample. The whole block needs only two __syncthreads() barriers and a 32-float shared array.

The kernel uses the tanh approximation of GELU, while torch.nn.functional.gelu defaults to the exact erf-based form, so outputs may differ slightly from the PyTorch reference. The Python module remains nearly unchanged apart from loading this optimized CUDA extension.
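The identity behind the fusion can be written out as follows. Here W_j denotes row j of weight (shape [out_features, in_features]), b = bias, s = subtract, F = in_features, D = out_features, and S and C are the two quantities precomputed on the host:

```latex
\begin{aligned}
y_j &= \mathbf{x}\cdot\mathbf{W}_j + b_j - s_j, \qquad j = 1,\dots,D \\
p &= \frac{1}{D}\sum_{j=1}^{D} y_j
   = \frac{1}{D}\Big(\mathbf{x}\cdot\underbrace{\textstyle\sum_{j}\mathbf{W}_j}_{\mathbf{S}}
   + \underbrace{\textstyle\sum_{j}(b_j - s_j)}_{C}\Big)
   = \frac{\mathbf{x}\cdot\mathbf{S} + C}{D} \\
\operatorname{LSE}(p) &= \log e^{p} = p \\
\text{out}_i &= x_i + \operatorname{GELU}(p), \qquad i = 1,\dots,F
\end{aligned}
```

The unfused pipeline costs O(F·D) per sample for the GEMM. After the rewrite, the kernel needs only an O(F) dot product per sample. There is also an O(F·D) reduction of weight into S, which forward performs once per call and shares across the whole batch.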

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
#include <math.h>

// Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// where 0.7978845608f = sqrt(2/pi).
__device__ inline float gelu(float x) {
    float x_cube = x * x * x;
    float tanh_arg = 0.7978845608f * (x + 0.044715f * x_cube);
    return 0.5f * x * (1.0f + tanhf(tanh_arg));
}

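// Fused kernel with a warp-level reduction.
// Each block processes one sample of x (shape [in_features]) and writes one row of out.
// S: precomputed vector S = sum(weight, dim=0) of size in_features.
// C: scalar computed as sum(bias - subtract).
// F: in_features, D: out_features.
// Requires blockDim.x to be a multiple of 32 (full-warp shuffle mask) and at most 1024
// (warp_sums holds one entry per warp).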
__global__ void fused_kernel(const float* __restrict__ x,
                             float* __restrict__ out,
                             const float* __restrict__ S,
                             float C,
                             int F,    // in_features
                             int D)    // out_features (for pooling averaging)
{
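    int sample = blockIdx.x;
    // Compute the row offset in size_t so sample * F cannot overflow int.
    const float* x_sample = x + static_cast<size_t>(sample) * F;
    float* out_sample = out + static_cast<size_t>(sample) * F;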
    
    float partial_sum = 0.0f;
    // Each thread processes multiple elements.
    for (int i = threadIdx.x; i < F; i += blockDim.x) {
        partial_sum += x_sample[i] * S[i];
    }
    
    // Warp-level tree reduction with __shfl_down_sync; afterwards lane 0 of each warp holds the warp's sum.
    unsigned int mask = 0xffffffff;
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        partial_sum += __shfl_down_sync(mask, partial_sum, offset);
    }
    
    // Each warp's lane 0 stores its result into shared memory.
    __shared__ float warp_sums[32];  // support up to 32 warps per block.
    int warp_id = threadIdx.x / warpSize;
    int lane = threadIdx.x % warpSize;
    if (lane == 0) {
        warp_sums[warp_id] = partial_sum;
    }
    __syncthreads();
    
    // Final reduction: thread 0 sums the results from each warp.
    float total = 0.0f;
    if (threadIdx.x == 0) {
        int num_warps = (blockDim.x + warpSize - 1) / warpSize;
        for (int i = 0; i < num_warps; i++) {
            total += warp_sums[i];
        }
        float p = (total + C) / ((float) D);
        float gelu_val = gelu(p);
        // Store the GELU result to shared memory for broadcast.
        warp_sums[0] = gelu_val;
    }
    __syncthreads();
    float gelu_val = warp_sums[0];
    
    // Perform residual addition: output = x + gelu_val (broadcasted).
    for (int i = threadIdx.x; i < F; i += blockDim.x) {
        out_sample[i] = x_sample[i] + gelu_val;
    }
}

torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, torch::Tensor subtract, int in_features, int out_features) {
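    // Validate inputs before touching raw pointers.
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda() && subtract.is_cuda(),
                "all inputs must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32,
                "x and weight must be float32");
    TORCH_CHECK(x.dim() == 2 && x.size(1) == in_features, "x must have shape [batch, in_features]");
    TORCH_CHECK(weight.dim() == 2 && weight.size(0) == out_features && weight.size(1) == in_features,
                "weight must have shape [out_features, in_features]");
    // Use CUDAGuard to ensure proper device.
    c10::cuda::CUDAGuard device_guard(x.device());
    x = x.contiguous();  // the kernel indexes x as a dense row-major matrix
    int batch = x.size(0);
    if (batch == 0) {
        return torch::empty_like(x);  // a zero-sized grid is an invalid launch configuration
    }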
    
    // Compute S = sum(weight, dim=0). weight shape: [out_features, in_features].
    auto S_tensor = weight.sum(0);
    // Compute C = sum(bias - subtract). Both bias and subtract are of shape [out_features].
    // .item<float>() copies the scalar to the host, which synchronizes with the device.
    auto diff_tensor = bias - subtract;
    float C = diff_tensor.sum().item<float>();
    
    // Allocate output tensor.
    auto out = torch::empty_like(x);
    
    int threads = 256;
    dim3 grid(batch);
    dim3 block(threads);
    
    // Launch the optimized fused kernel.
    fused_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        out.data_ptr<float>(),
        S_tensor.data_ptr<float>(),
        C,
        in_features,
        out_features
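    );
    // Surface launch-configuration errors immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_kernel launch failed: ", cudaGetErrorString(err));

    return out;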
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused Gemm, subtract, pooling, logsumexp, GELU, and residual add with warp-level optimization");
}
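The two-level reduction in fused_kernel can be traced on the CPU. The following plain C++ sketch emulates it for one block (one sample). It models the data flow only and is not a replacement for the kernel. `shfl_down` and `block_dot` are hypothetical helper names. The emulation assumes the full-warp mask and a blockDim.x that is a multiple of 32, as with the 256-thread launch in forward.

```cpp
#include <array>
#include <cassert>
#include <vector>

constexpr int kWarpSize = 32;

// Emulates one __shfl_down_sync(0xffffffff, v, offset) step across a warp:
// lane l receives the value held by lane l + offset; lanes whose source lane
// would fall past the end of the warp receive their own value unchanged.
std::array<float, kWarpSize> shfl_down(const std::array<float, kWarpSize>& v, int offset) {
    std::array<float, kWarpSize> r{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        const int src = lane + offset;
        r[lane] = (src < kWarpSize) ? v[src] : v[lane];
    }
    return r;
}

// Emulates the reduction in fused_kernel for one block (one sample):
// strided per-thread partial sums, a shuffle-down tree inside each warp,
// then thread 0's serial loop over the per-warp results.
float block_dot(const std::vector<float>& x, const std::vector<float>& S, int block_dim) {
    assert(x.size() == S.size());
    assert(block_dim > 0 && block_dim % kWarpSize == 0 && block_dim <= 1024);
    const int F = static_cast<int>(x.size());
    const int num_warps = block_dim / kWarpSize;

    float total = 0.0f;  // thread 0's sum over warp_sums[0..num_warps)
    for (int w = 0; w < num_warps; ++w) {
        // Per-thread partial dot products, same striding as the kernel.
        std::array<float, kWarpSize> partial{};
        for (int lane = 0; lane < kWarpSize; ++lane) {
            for (int i = w * kWarpSize + lane; i < F; i += block_dim) {
                partial[lane] += x[i] * S[i];
            }
        }
        // offset = 16, 8, 4, 2, 1; all lanes shuffle simultaneously.
        for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
            const std::array<float, kWarpSize> recv = shfl_down(partial, offset);
            for (int lane = 0; lane < kWarpSize; ++lane) {
                partial[lane] += recv[lane];
            }
        }
        total += partial[0];  // lane 0 holds the warp sum (warp_sums[w])
    }
    return total;
}
```

After the shuffle loop, only lane 0's value is meaningful. Once their source lane is out of range, the upper lanes add their own values back in, which is why the kernel writes to warp_sums only when lane == 0. The kernel's full 0xffffffff mask also assumes every warp is complete, so blockDim.x must be a multiple of 32. The emulation asserts the same condition.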