2.85. Conv2d GroupNorm Scale MaxPool Clamp

Performs a 2D convolution on the input, applies group normalization to standardize the resulting features, scales the normalized output with a learnable parameter, downsamples the features via max pooling, and finally clamps the resulting values within a specified range.

Progress

Step 2 of 2
Reward
1.22x

Optimization Timeline

Step 2 Details

Reasoning

We can squeeze further performance improvements into our fused max pooling and clamping kernel. In the original kernel, every thread computes the maximum value in its pooling window using nested loops. Since the most common pooling kernel size is 2, we can write a specialized version for pool_size==2 that avoids loops and uses vectorized memory loads with __ldg. We also add "#pragma unroll" directives in the generic kernel to aid the compiler when pool_size is small. In the host (forward) function, we check if the pool size equals 2 and dispatch the specialized kernel; otherwise, we use the generic one. These improvements reduce the instruction count per thread and improve memory access performance, offering a genuine speedup.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <float.h>

// Generic kernel for any pool size.
// Uses loop unrolling directives to help the compiler when pool_size is small.
__global__ void maxpool_clamp_kernel(const float* __restrict__ input,
                                     float* __restrict__ output,
                                     int B, int C, int H, int W,
                                     int pool_size,
                                     float clamp_min, float clamp_max) {
    int outH = H / pool_size;
    int outW = W / pool_size;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = B * C * outH * outW;
    if (index < total) {
        int w = index % outW;
        int h = (index / outW) % outH;
        int temp = index / (outW * outH);
        int c = temp % C;
        int b = temp / C;
        int in_row_start = h * pool_size;
        int in_col_start = w * pool_size;
        
        float max_val = -FLT_MAX;
        for (int i = 0; i < pool_size; i++) {
            #pragma unroll
            for (int j = 0; j < pool_size; j++) {
                int in_row = in_row_start + i;
                int in_col = in_col_start + j;
                int in_index = ((b * C + c) * H + in_row) * W + in_col;
                float val = __ldg(&input[in_index]);
                if (val > max_val) {
                    max_val = val;
                }
            }
        }
        // Clamp the result.
        max_val = fmaxf(max_val, clamp_min);
        max_val = fminf(max_val, clamp_max);
        output[index] = max_val;
    }
}

// Specialized kernel for pool_size == 2.
__global__ void maxpool_clamp_kernel_2(const float* __restrict__ input,
                                       float* __restrict__ output,
                                       int B, int C, int H, int W,
                                       float clamp_min, float clamp_max) {
    int outH = H / 2;
    int outW = W / 2;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = B * C * outH * outW;
    if (index < total) {
        int w = index % outW;
        int h = (index / outW) % outH;
        int temp = index / (outW * outH);
        int c = temp % C;
        int b = temp / C;
        int base = ((b * C + c) * H + h * 2) * W + w * 2;
        float v0 = __ldg(&input[base]);
        float v1 = __ldg(&input[base + 1]);
        float v2 = __ldg(&input[base + W]);
        float v3 = __ldg(&input[base + W + 1]);
        float max_val = fmaxf(fmaxf(v0, v1), fmaxf(v2, v3));
        max_val = fmaxf(max_val, clamp_min);
        max_val = fminf(max_val, clamp_max);
        output[index] = max_val;
    }
}

torch::Tensor forward(torch::Tensor input, int pool_size, float clamp_min, float clamp_max) {
    c10::cuda::CUDAGuard device_guard(input.device());
    
    // Input shape: (B, C, H, W)
    int B = input.size(0);
    int C = input.size(1);
    int H = input.size(2);
    int W = input.size(3);

    // Ensure that H and W are divisible by pool_size.
    TORCH_CHECK(H % pool_size == 0, "Input height must be divisible by pool_size");
    TORCH_CHECK(W % pool_size == 0, "Input width must be divisible by pool_size");
    
    int outH = H / pool_size;
    int outW = W / pool_size;
    auto output = torch::empty({B, C, outH, outW}, input.options());
    
    int total = B * C * outH * outW;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    if (pool_size == 2) {
        maxpool_clamp_kernel_2<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(),
            output.data_ptr<float>(),
            B, C, H, W,
            clamp_min, clamp_max
        );
    } else {
        maxpool_clamp_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(),
            output.data_ptr<float>(),
            B, C, H, W,
            pool_size,
            clamp_min, clamp_max
        );
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused max pooling and clamping kernel (with specialized version for pool_size==2)");
}