2.67. Convolution, GELU Activation, and Global Average Pooling

Computes a 2D convolution on an input tensor, applies the Gaussian Error Linear Unit (GELU) using its tanh-based approximation to introduce non-linearity, and then performs global average pooling to aggregate spatial features into a single value per channel.
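For reference, a minimal PyTorch sketch of the operation being described (class name and convolution parameters are illustrative placeholders, not taken from the original module):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvGeluPool(nn.Module):
    """Reference pipeline: Conv2d -> tanh-approximate GELU -> global average pool."""
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        x = self.conv(x)                    # [B, C, H, W]
        x = F.gelu(x, approximate="tanh")   # tanh-based GELU approximation
        return x.mean(dim=(2, 3))           # global average pool -> [B, C]
```

The fused kernel below replaces the last two steps (GELU and pooling); the convolution itself stays in PyTorch.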

Progress: Step 2 of 2
Reward: 1.48x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize our fused GELU + global average pooling kernel by replacing the shared memory reduction with a warp-level reduction that leverages __shfl_down_sync. This avoids excessive synchronization and shared memory traffic when summing partial results within each block. Each thread computes a partial sum of the GELU-activated values over its assigned spatial elements, then reduces that sum within its warp using shuffle intrinsics. The warp leader writes its result into shared memory, and a final reduction over the per-warp sums yields the result for each (batch, channel) pair. This revised reduction should reduce latency and improve throughput. The Python module remains largely unchanged since the interface is identical.

--- FIX ---

The previous version used the built-in variable warpSize in host code when computing the shared memory size. warpSize is only defined in device code, so referencing it on the host side causes a compilation error. To fix this, we define a constant (WARP_SIZE = 32) and use it in both host and device code. This avoids the compilation error and makes explicit that our warp size is 32. The rest of the kernel logic remains unchanged, and the Python module stays the same.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Define warp size constant to use in host and device code.
#define WARP_SIZE 32

// Fused GELU + Global Average Pooling Kernel
// Each block processes one (batch, channel) pair.
// Input: tensor of shape [B, C, H, W]
// Output: tensor of shape [B, C]
// For each block, we compute: output[b,c] = (1/(H*W)) * sum_{i=0}^{H*W-1} gelu(input[b,c,i])
// Where gelu(x) = 0.5 * x * (1 + tanhf(0.79788456f * (x + 0.044715f * x^3)))
__global__ void fused_gelu_pool_kernel(const float* __restrict__ input,
                                         float* __restrict__ output,
                                         int C, int H, int W) {
    // Determine (batch, channel) index from blockIdx.x.
    int idx = blockIdx.x;  // ranges from 0 to (B * C - 1)
    int b = idx / C;       // batch index
    int c = idx % C;       // channel index

    int n = H * W; // total number of spatial elements per channel
    // Pointer to the beginning of (b, c)-th channel
    const float* in_ptr = input + b * C * n + c * n;

    float partial_sum = 0.0f;
    // Each thread processes a subset of the spatial elements.
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        float val = in_ptr[i];
        // Compute approximate GELU activation.
        float gelu_val = 0.5f * val * (1.0f + tanhf(0.79788456f * (val + 0.044715f * val * val * val)));
        partial_sum += gelu_val;
    }

    // Warp-level reduction using shuffle intrinsics.
    unsigned int mask = 0xffffffff;
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        partial_sum += __shfl_down_sync(mask, partial_sum, offset);
    }

    // Use external shared memory to store per-warp partial sums.
    extern __shared__ float sdata[];
    int lane = threadIdx.x & (WARP_SIZE - 1);
    int warpId = threadIdx.x / WARP_SIZE;
    if (lane == 0) {
        sdata[warpId] = partial_sum;
    }
    __syncthreads();

    // The first warp reduces the per-warp sums stored in shared memory.
    int numWarps = blockDim.x / WARP_SIZE;
    if (threadIdx.x < WARP_SIZE) {
        // Lanes beyond numWarps contribute zero so that all 32 lanes participate
        // in the shuffle and the full mask remains valid.
        float warp_sum = (threadIdx.x < numWarps) ? sdata[threadIdx.x] : 0.0f;
        for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
            warp_sum += __shfl_down_sync(mask, warp_sum, offset);
        }
        // The first thread writes the final computed average to the output.
        if (threadIdx.x == 0) {
            output[idx] = warp_sum / float(n);
        }
    }
}

torch::Tensor fused_gelu_pool(torch::Tensor input) {
    // Verify that input is a 4-dimensional, contiguous float32 CUDA tensor: [B, C, H, W]
    TORCH_CHECK(input.dim() == 4, "Input tensor must be 4-dimensional (B, C, H, W)");
    TORCH_CHECK(input.is_cuda(), "Input tensor must reside on a CUDA device");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input tensor must be float32");
    TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");

    int B = input.size(0);
    int C = input.size(1);
    int H = input.size(2);
    int W = input.size(3);

    // Create output tensor of shape [B, C]
    auto output = torch::empty({B, C}, input.options());

    // Calculate number of blocks: one block per (batch, channel) pair.
    int blocks = B * C;
    int threads = 256;
    // Calculate shared memory size using the defined warp size constant.
    size_t shared_mem_size = (threads / WARP_SIZE) * sizeof(float);

    // Ensure the correct CUDA device is used.
    c10::cuda::CUDAGuard device_guard(input.device());

    // Launch the kernel on the current CUDA stream.
    fused_gelu_pool_kernel<<<blocks, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        C, H, W
    );

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_gelu_pool", &fused_gelu_pool, "Fused GELU Activation and Global Average Pooling (optimized with warp-level reduction)");
}
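For completeness, a minimal sketch of how the compiled extension might be used from Python (the torch.utils.cpp_extension.load call and the source file name fused_gelu_pool.cu are assumptions for illustration, not part of the original module):

```python
import torch
from torch.utils.cpp_extension import load

# Hypothetical build step; the extension name and source file are placeholders.
fused = load(name="fused_gelu_pool", sources=["fused_gelu_pool.cu"])

conv = torch.nn.Conv2d(3, 16, kernel_size=3).cuda()
x = torch.randn(8, 3, 32, 32, device="cuda")

y = conv(x)                                   # [8, 16, 30, 30]
out = fused.fused_gelu_pool(y.contiguous())   # [8, 16]

# Sanity check against the PyTorch reference path.
ref = torch.nn.functional.gelu(y, approximate="tanh").mean(dim=(2, 3))
print(torch.allclose(out, ref, atol=1e-5))
```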