1.42. 2D Max Pooling

Performs a 2D max pooling operation that calculates the maximum value within each pooling window on a 4D input tensor. The computation uses adjustable parameters including kernel size, stride, padding, and dilation to control the pooling process.

Progress

Step 6 of 6
Reward
1.69x

Optimization Timeline

Step 6 Details

Reasoning

We can further improve the kernel by using the __ldg intrinsic for all global memory reads to take advantage of the read-only data cache. In the generic branch, we also add #pragma unroll hints to the inner loops, since the pooling window is generally small. These changes aim to reduce instruction overhead and memory latency. Note, however, that nvcc only fully unrolls a loop marked with a bare #pragma unroll when its trip count is a compile-time constant, and these loop bounds are runtime values. The Python module stays the same.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <climits>
#include <cstdint>
#include <float.h>
#include <math.h>

// Number of pooling windows along one spatial axis (ceil_mode = false):
//   floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
// The division rounds toward negative infinity, as PyTorch's pooling output-shape
// computation does. Plain C++ '/' truncates toward zero. That would turn a window
// that does not fit the padded input (negative numerator with |numerator| < stride)
// into a bogus output size of 1 instead of 0.
// Requires stride >= 1. A result <= 0 means the dilated window never fits.
__host__ __device__ inline int compute_out_dim(int size, int padding, int dilation, int kernel_size, int stride) {
    const int numerator = size + 2 * padding - dilation * (kernel_size - 1) - 1;
    int quotient = numerator / stride;
    // Truncation rounded a negative, inexact quotient up; step down to the floor.
    if (numerator < 0 && numerator % stride != 0) {
        --quotient;
    }
    return quotient + 1;
}

// NaN-propagating max used by maxpool2d_kernel. torch.nn.functional.max_pool2d
// lets a NaN anywhere in the window reach the output. A plain `val > cur`
// comparison would silently discard it, because every comparison with NaN is false.
// Once `cur` is NaN it stays NaN, since `val > NaN` is false for every val.
__device__ __forceinline__ float maxpool2d_nan_max(float cur, float val) {
    return (val > cur || isnan(val)) ? val : cur;
}

// Computes the half-open range [*begin, *end) of tap indices t in [0, kernel_size)
// whose coordinate `start + t * dilation` lies in [0, extent).
// Requires dilation >= 1. Yields an empty range (*begin >= *end) when no tap is in bounds.
__device__ __forceinline__ void maxpool2d_tap_range(int start, int extent, int kernel_size,
                                                    int dilation, int* begin, int* end) {
    // Smallest t with start + t*dilation >= 0, i.e. ceil(-start / dilation).
    *begin = start < 0 ? (-start + dilation - 1) / dilation : 0;
    // Largest in-bounds offset from `start`. Negative means start is already past the end.
    const int last_offset = extent - 1 - start;
    if (last_offset < 0) {
        *end = 0;
    } else {
        const int count = last_offset / dilation + 1;
        *end = count < kernel_size ? count : kernel_size;
    }
}

// 2-D max pooling over a dense NCHW float tensor.
//
// input   contiguous [batch, channels, H, W] array, read through __ldg.
// output  contiguous [batch, channels, outH, outW] array; every element is written.
// Launch: 1-D grid of 1-D blocks, no shared memory. The body is a grid-stride
//         loop, so any non-empty grid covers all output elements.
// Preconditions: kernel_size >= 1, stride >= 1, dilation >= 1, padding >= 0.
//         The tap-range computation divides by `dilation`.
//
// Semantics follow torch.nn.functional.max_pool2d with ceil_mode=False:
//   * taps that land in the implicit padding are skipped;
//   * the running max starts at -inf, so a window of -inf values, or one with no
//     in-bounds tap, yields -inf;
//   * NaN propagates.
// Element offsets are computed in 64-bit so large tensors cannot overflow them.
__global__ void maxpool2d_kernel(const float* __restrict__ input,
                                 float* __restrict__ output,
                                 int batch,
                                 int channels,
                                 int H, int W,
                                 int outH, int outW,
                                 int kernel_size,
                                 int stride,
                                 int padding,
                                 int dilation) {
    const int64_t total = static_cast<int64_t>(batch) * channels * outH * outW;
    const int64_t plane_size = static_cast<int64_t>(H) * W;
    const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < total;
         index += grid_stride) {
        // Decompose the flat output index. `nc` equals n * channels + c, which is
        // the index of the (n, c) spatial plane in both input and output.
        const int w_out = static_cast<int>(index % outW);
        const int64_t nh = index / outW;
        const int h_out = static_cast<int>(nh % outH);
        const int64_t nc = nh / outH;

        const float* __restrict__ plane = input + nc * plane_size;

        // Top-left tap of this window in input coordinates (may be negative: padding).
        const int h_start = h_out * stride - padding;
        const int w_start = w_out * stride - padding;

        float max_val = -INFINITY;

        if (kernel_size == 2) {
            // Specialized 2x2 window: each of the four taps is bounds-checked.
            const int h1 = h_start + dilation;
            const int w1 = w_start + dilation;
            const bool h0_ok = h_start >= 0 && h_start < H;
            const bool h1_ok = h1 >= 0 && h1 < H;
            const bool w0_ok = w_start >= 0 && w_start < W;
            const bool w1_ok = w1 >= 0 && w1 < W;
            const int64_t row0 = static_cast<int64_t>(h_start) * W;
            const int64_t row1 = static_cast<int64_t>(h1) * W;
            if (h0_ok && w0_ok) max_val = maxpool2d_nan_max(max_val, __ldg(plane + row0 + w_start));
            if (h0_ok && w1_ok) max_val = maxpool2d_nan_max(max_val, __ldg(plane + row0 + w1));
            if (h1_ok && w0_ok) max_val = maxpool2d_nan_max(max_val, __ldg(plane + row1 + w_start));
            if (h1_ok && w1_ok) max_val = maxpool2d_nan_max(max_val, __ldg(plane + row1 + w1));
        } else {
            // Clip the tap ranges to in-bounds taps up front, so the loops below
            // need no per-element bounds checks.
            int i_begin, i_end, j_begin, j_end;
            maxpool2d_tap_range(h_start, H, kernel_size, dilation, &i_begin, &i_end);
            maxpool2d_tap_range(w_start, W, kernel_size, dilation, &j_begin, &j_end);

            // No `#pragma unroll`: these trip counts are runtime values. A bare
            // `#pragma unroll` on such a loop tells nvcc not to unroll it at all.
            for (int i = i_begin; i < i_end; ++i) {
                const float* __restrict__ row =
                    plane + static_cast<int64_t>(h_start + i * dilation) * W;
                for (int j = j_begin; j < j_end; ++j) {
                    max_val = maxpool2d_nan_max(max_val, __ldg(row + w_start + j * dilation));
                }
            }
        }

        output[index] = max_val;
    }
}

// Host entry point: 2-D max pooling of a 4-D CUDA float tensor.
//
// Args:
//   x            input of shape [batch, channels, H, W]. Must be a float32 CUDA
//                tensor with non-empty channel/spatial dims. Non-contiguous inputs
//                are made contiguous first.
//   kernel_size  square pooling window size (>= 1).
//   stride       step between windows (>= 1).
//   padding      implicit padding on every side, 0 <= padding <= kernel_size / 2.
//                This is the same limit torch.nn.functional.max_pool2d enforces.
//   dilation     spacing between window taps (>= 1).
// Returns a new tensor of shape [batch, channels, outH, outW] on x's device. The
// kernel is enqueued on the current CUDA stream.
// Throws c10::Error (via TORCH_CHECK) on invalid arguments or a failed launch.
torch::Tensor forward(torch::Tensor x,
                      int kernel_size,
                      int stride,
                      int padding,
                      int dilation) {
    TORCH_CHECK(x.is_cuda(), "maxpool2d forward: x must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 4,
                "maxpool2d forward: x must be 4-D [batch, channels, H, W], got ", x.dim(), " dims");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32,
                "maxpool2d forward: x must be float32, got ", x.scalar_type());
    TORCH_CHECK(kernel_size > 0 && stride > 0 && dilation > 0,
                "maxpool2d forward: kernel_size, stride and dilation must be positive, got kernel_size=",
                kernel_size, ", stride=", stride, ", dilation=", dilation);
    TORCH_CHECK(padding >= 0 && padding <= kernel_size / 2,
                "maxpool2d forward: padding must be in [0, kernel_size / 2], got padding=",
                padding, ", kernel_size=", kernel_size);

    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel addresses the input as a dense NCHW array. Strided or channels-last
    // inputs would otherwise be read at the wrong offsets.
    auto input = x.contiguous();

    const int batch = static_cast<int>(input.size(0));
    const int channels = static_cast<int>(input.size(1));
    const int H = static_cast<int>(input.size(2));
    const int W = static_cast<int>(input.size(3));
    TORCH_CHECK(channels > 0 && H > 0 && W > 0,
                "maxpool2d forward: channel and spatial dims must be non-empty, got [",
                batch, ", ", channels, ", ", H, ", ", W, "]");

    // The dilated window must fit inside the padded input at least once.
    // Otherwise there is no valid output position.
    const int64_t window_extent = static_cast<int64_t>(dilation) * (kernel_size - 1) + 1;
    TORCH_CHECK(H + 2LL * padding >= window_extent && W + 2LL * padding >= window_extent,
                "maxpool2d forward: padded input ", H, "x", W, " (padding=", padding,
                ") is smaller than the dilated window extent ", window_extent);

    const int outH = compute_out_dim(H, padding, dilation, kernel_size, stride);
    const int outW = compute_out_dim(W, padding, dilation, kernel_size, stride);

    auto output = torch::empty({batch, channels, outH, outW}, input.options());

    const int64_t total = output.numel();
    if (total == 0) {
        // Empty batch: nothing to compute. A launch with zero blocks would be an
        // invalid configuration.
        return output;
    }
    // Conservatively keep element counts within 32-bit range: the kernel takes its
    // extents as `int`, and the block count below is narrowed to `int`.
    TORCH_CHECK(total <= INT_MAX && input.numel() <= INT_MAX,
                "maxpool2d forward: tensors with more than INT_MAX elements are not supported");

    const int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);

    maxpool2d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        batch, channels,
        H, W,
        outH, outW,
        kernel_size,
        stride,
        padding,
        dilation
    );
    // Surface launch errors here rather than at some unrelated later CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: registers the host entry point `forward` on the extension
// module, with the given docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext_module) {
    ext_module.def(
        "forward",
        &forward,
        "Fused MaxPool2D custom CUDA kernel with specialization for kernel_size==2 and optimized generic branch");
}