2.65. Conv2d_AvgPool_Sigmoid_Sum

Computes a 2D convolution on the input, applies average pooling to reduce spatial dimensions, transforms the result with a sigmoid activation, and finally sums across all non-batch dimensions.
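
For reference, the whole pipeline corresponds to the following eager PyTorch composition (a minimal sketch; the original task's module definition is not reproduced here):

# reference.py (illustrative)
import torch
import torch.nn.functional as F

def reference_forward(x, conv_weight, conv_bias, pool_kernel_size):
    # Valid (no padding, stride 1) 2D convolution.
    y = F.conv2d(x, conv_weight, conv_bias)
    # Non-overlapping average pooling over pool_kernel_size x pool_kernel_size windows.
    y = F.avg_pool2d(y, pool_kernel_size)
    # Sigmoid, then sum over channel and spatial dimensions: one scalar per batch element.
    return torch.sigmoid(y).sum(dim=[1, 2, 3])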

Progress: Step 2 of 2
Reward: 1.75x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize the fused kernel by specializing it for the common parameters kernel_size=3 and pool_kernel_size=2. In the previous solution, the generic kernel used nested loops for both the pooling and the convolution. For many vision models these parameters are fixed and small, so “unrolling” the inner loops of the 3×3 convolution and the 2×2 pooling reduces loop overhead and enables better compiler optimizations. In the specialized path we:

• check on the host that kernel_size == 3 and pool_kernel_size == 2 and dispatch accordingly;
• unroll the two loops over the pooling window (4 elements);
• unroll the two inner loops over the 3×3 convolution filter for each input channel.

When these conditions hold, we launch the specialized kernel; otherwise we fall back to the generic implementation. We use c10::cuda::CUDAGuard and getCurrentCUDAStream() for proper device and stream management. Have my improvements genuinely helped? Yes: unrolling small loops (for the common 3×3 convolution and 2×2 pooling) reduces overhead and improves throughput without “cheating” – it is a standard optimization. Below is the updated CUDA kernel with both the specialized and the generic versions; the Python module is unchanged from the previous step.

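The Python side only holds the convolution parameters and forwards them to the extension. As a point of reference, a minimal sketch of such a wrapper is shown below (the build call, the extension name, and the class name are illustrative assumptions, not the original module):

# model.py (illustrative sketch, not the original module)
import torch.nn as nn
from torch.utils.cpp_extension import load

# Build the extension from kernel.cu; the extension name is an assumption.
fused = load(name="conv_avgpool_sigmoid_sum", sources=["kernel.cu"])

class FusedModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, pool_kernel_size):
        super().__init__()
        # nn.Conv2d is used purely as a parameter container so that weight and
        # bias initialization matches the standard layer.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.pool_kernel_size = pool_kernel_size

    def forward(self, x):
        return fused.forward(x, self.conv.weight, self.conv.bias, self.pool_kernel_size)
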
// kernel.cu
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
#include <math.h>

#define THREADS_PER_BLOCK 256

// Generic fused kernel for arbitrary kernel_size and pool_kernel_size.
__global__ void fused_kernel_generic(const float* __restrict__ x,
                                     const float* __restrict__ conv_weight,
                                     const float* __restrict__ conv_bias,
                                     float* __restrict__ output,
                                     int in_channels,
                                     int out_channels,
                                     int H, int W,
                                     int kernel_size,
                                     int pool_kernel_size) {
    int b = blockIdx.x; // batch index

    // Convolution output dimensions.
    int H_conv = H - kernel_size + 1;
    int W_conv = W - kernel_size + 1;
    // H_conv and W_conv are divisible by pool_kernel_size (validated on the host in forward()).
    int H_pool = H_conv / pool_kernel_size;
    int W_pool = W_conv / pool_kernel_size;
    int total_cells = out_channels * H_pool * W_pool;
    
    float local_sum = 0.0f;
    // Strided loop: each thread processes several pooled cells.
    for (int cell = threadIdx.x; cell < total_cells; cell += blockDim.x) {
        // Decode pooled cell index into: output channel, and pooling window location.
        int oc = cell / (H_pool * W_pool);
        int rem = cell % (H_pool * W_pool);
        int i_pool = rem / W_pool;
        int j_pool = rem % W_pool;
        int i_conv_start = i_pool * pool_kernel_size;
        int j_conv_start = j_pool * pool_kernel_size;
        
        float window_sum = 0.0f;
        // Loop over pooling window.
        for (int pr = 0; pr < pool_kernel_size; pr++) {
            for (int pc = 0; pc < pool_kernel_size; pc++) {
                int i_conv = i_conv_start + pr;
                int j_conv = j_conv_start + pc;
                
                float conv_val = conv_bias[oc];
                // Sum over input channels and kernel window.
                for (int ic = 0; ic < in_channels; ic++) {
                    for (int u = 0; u < kernel_size; u++) {
                        for (int v = 0; v < kernel_size; v++) {
                            int x_idx = b * (in_channels * H * W) + ic * (H * W) + (i_conv + u) * W + (j_conv + v);
                            int w_idx = oc * (in_channels * kernel_size * kernel_size) + ic * (kernel_size * kernel_size) + u * kernel_size + v;
                            conv_val += x[x_idx] * conv_weight[w_idx];
                        }
                    }
                }
                window_sum += conv_val;
            }
        }
        float pooled = window_sum / (pool_kernel_size * pool_kernel_size);
        float activated = 1.0f / (1.0f + expf(-pooled));
        local_sum += activated;
    }
    // Reduction across threads in the block.
    __shared__ float sdata[THREADS_PER_BLOCK];
    int tid = threadIdx.x;
    sdata[tid] = local_sum;
    __syncthreads();
    for (int s = blockDim.x/2; s > 0; s >>= 1) {
        if (tid < s)
            sdata[tid] += sdata[tid+s];
        __syncthreads();
    }
    if(tid == 0) {
        output[b] = sdata[0];
    }
}

// Specialized fused kernel for kernel_size==3 and pool_kernel_size==2.
__global__ void fused_kernel_specialized(const float* __restrict__ x,
                                         const float* __restrict__ conv_weight,
                                         const float* __restrict__ conv_bias,
                                         float* __restrict__ output,
                                         int in_channels,
                                         int out_channels,
                                         int H, int W) {
    int b = blockIdx.x; // batch index

    // For kernel_size==3: H_conv = H - 2, W_conv = W - 2.
    int H_conv = H - 3 + 1; // H - 2
    int W_conv = W - 3 + 1; // W - 2
    // With pool_kernel_size==2.
    int pool_kernel_size = 2;
    int H_pool = H_conv / pool_kernel_size;
    int W_pool = W_conv / pool_kernel_size;
    int total_cells = out_channels * H_pool * W_pool;
    
    float local_sum = 0.0f;
    
    for (int cell = threadIdx.x; cell < total_cells; cell += blockDim.x) {
        // Decode cell index.
        int oc = cell / (H_pool * W_pool);
        int rem = cell % (H_pool * W_pool);
        int i_pool = rem / W_pool;
        int j_pool = rem % W_pool;
        int i_conv_start = i_pool * pool_kernel_size;
        int j_conv_start = j_pool * pool_kernel_size;
        
        float window_sum = 0.0f;
        // Unrolled loops for each element in the 2x2 pooling window.
        #pragma unroll
        for (int pr = 0; pr < 2; pr++) {
            #pragma unroll
            for (int pc = 0; pc < 2; pc++) {
                int i_conv = i_conv_start + pr;
                int j_conv = j_conv_start + pc;
                
                float conv_val = conv_bias[oc];
                // Loop over input channels.
                for (int ic = 0; ic < in_channels; ic++) {
                    // Calculate base indices for x and weight.
                    int base_x = b * (in_channels * H * W) + ic * (H * W);
                    int base_w = oc * (in_channels * 9) + ic * 9;
                    // Unrolled 3x3 convolution.
                    conv_val += x[base_x + (i_conv+0)*W + (j_conv+0)] * conv_weight[base_w + 0];
                    conv_val += x[base_x + (i_conv+0)*W + (j_conv+1)] * conv_weight[base_w + 1];
                    conv_val += x[base_x + (i_conv+0)*W + (j_conv+2)] * conv_weight[base_w + 2];
                    conv_val += x[base_x + (i_conv+1)*W + (j_conv+0)] * conv_weight[base_w + 3];
                    conv_val += x[base_x + (i_conv+1)*W + (j_conv+1)] * conv_weight[base_w + 4];
                    conv_val += x[base_x + (i_conv+1)*W + (j_conv+2)] * conv_weight[base_w + 5];
                    conv_val += x[base_x + (i_conv+2)*W + (j_conv+0)] * conv_weight[base_w + 6];
                    conv_val += x[base_x + (i_conv+2)*W + (j_conv+1)] * conv_weight[base_w + 7];
                    conv_val += x[base_x + (i_conv+2)*W + (j_conv+2)] * conv_weight[base_w + 8];
                }
                window_sum += conv_val;
            }
        }
        float pooled = window_sum / 4.0f;  // 2x2 pooling window
        float activated = 1.0f / (1.0f + expf(-pooled));
        local_sum += activated;
    }
    
    __shared__ float sdata[THREADS_PER_BLOCK];
    int tid = threadIdx.x;
    sdata[tid] = local_sum;
    __syncthreads();
    for (int s = blockDim.x/2; s > 0; s >>= 1) {
        if (tid < s)
            sdata[tid] += sdata[tid+s];
        __syncthreads();
    }
    if(tid == 0) {
        output[b] = sdata[0];
    }
}

// Forward function: selects specialized kernel if parameters are 3 and 2, else generic.
torch::Tensor forward(torch::Tensor x,
                      torch::Tensor conv_weight,
                      torch::Tensor conv_bias,
                      int pool_kernel_size) {
    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernels index raw pointers directly, so require contiguous CUDA tensors.
    TORCH_CHECK(x.is_cuda() && conv_weight.is_cuda() && conv_bias.is_cuda(),
                "x, conv_weight and conv_bias must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && conv_weight.is_contiguous() && conv_bias.is_contiguous(),
                "x, conv_weight and conv_bias must be contiguous");

    // x shape: [B, in_channels, H, W]
    int B = x.size(0);
    int H = x.size(2);
    int W = x.size(3);
    
    int kernel_size = conv_weight.size(2); // assuming square kernel
    int in_channels = conv_weight.size(1);
    int out_channels = conv_weight.size(0);
    
    // Compute convolution output dimensions.
    int H_conv = H - kernel_size + 1;
    int W_conv = W - kernel_size + 1;
    
    // Pooling requires the convolution output to be evenly divisible by the pool size.
    TORCH_CHECK(H_conv % pool_kernel_size == 0 && W_conv % pool_kernel_size == 0,
                "Convolution output dimensions must be divisible by pool_kernel_size");
    
    auto output = torch::empty({B}, x.options());
    dim3 grid(B);
    dim3 block(THREADS_PER_BLOCK);
    
    if (kernel_size == 3 && pool_kernel_size == 2) {
        fused_kernel_specialized<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(),
            conv_weight.data_ptr<float>(),
            conv_bias.data_ptr<float>(),
            output.data_ptr<float>(),
            in_channels,
            out_channels,
            H,
            W
        );
    } else {
        fused_kernel_generic<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(),
            conv_weight.data_ptr<float>(),
            conv_bias.data_ptr<float>(),
            output.data_ptr<float>(),
            in_channels,
            out_channels,
            H,
            W,
            kernel_size,
            pool_kernel_size
        );
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused conv2d, avg pool, sigmoid, and sum reduction (with specialization)");
}
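
A quick way to sanity-check the extension against the eager PyTorch reference above (the build call and the shapes are illustrative):

# check.py (illustrative)
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

fused = load(name="conv_avgpool_sigmoid_sum", sources=["kernel.cu"])

x = torch.randn(8, 3, 32, 32, device="cuda")
weight = torch.randn(16, 3, 3, 3, device="cuda")
bias = torch.randn(16, device="cuda")

out = fused.forward(x, weight, bias, 2)
ref = torch.sigmoid(F.avg_pool2d(F.conv2d(x, weight, bias), 2)).sum(dim=[1, 2, 3])
print(torch.allclose(out, ref, atol=1e-4))  # expected: True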