1.46. 3D Average Pooling

Performs 3D average pooling on an input tensor by computing the mean over sliding windows, with configurable kernel size, stride, and padding to control the reduction of spatial dimensions.

Progress

Step 3 of 3
Reward
1.99x

Optimization Timeline

Step 3 Details

Reasoning

We can further optimize our custom 3D average pooling kernel by unrolling the innermost loops when kernel_size equals 3 – a common case – to reduce loop overhead and redundant arithmetic. In addition, we compute a base index into the input tensor once per output element, so the index multiplications are not repeated for every window element. We keep bounds-checking for each pooling-window element so that out-of-bounds (padding) positions contribute 0 to the sum. The sum is still divided by the full window volume (kernel_size³), which matches PyTorch's count_include_pad=True behavior. These changes reduce arithmetic operations and improve register reuse, which accelerates the kernel without changing the mathematical result.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

// Sums the in-bounds elements of one cubic pooling window inside a single
// (n, c) slice of a dense D x H x W volume.
//
// When KS > 0, the window edge is the compile-time constant KS, and the three
// loops are fully unrolled (used for the common kernel_size == 3 case). When
// KS == 0, the runtime `kernel_size` is used instead. Out-of-bounds (padding)
// positions are skipped, i.e. they contribute 0. Elements are accumulated in
// kd -> kh -> kw order.
template <int KS>
static __device__ __forceinline__ float avg_pool3d_window_sum(
    const float* __restrict__ slice,
    const int D, const int H, const int W,
    const int d_start, const int h_start, const int w_start,
    const int kernel_size) {
    const int k = (KS > 0) ? KS : kernel_size;
    const int64_t plane = static_cast<int64_t>(H) * W;
    float sum = 0.0f;
#pragma unroll
    for (int kd = 0; kd < k; ++kd) {
        const int d_in = d_start + kd;
        if (d_in < 0 || d_in >= D) continue;
        const int64_t d_off = d_in * plane;
#pragma unroll
        for (int kh = 0; kh < k; ++kh) {
            const int h_in = h_start + kh;
            if (h_in < 0 || h_in >= H) continue;
            // Offset of row (d_in, h_in) inside the slice, computed once per row.
            const int64_t row = d_off + static_cast<int64_t>(h_in) * W;
#pragma unroll
            for (int kw = 0; kw < k; ++kw) {
                const int w_in = w_start + kw;
                if (w_in >= 0 && w_in < W) sum += slice[row + w_in];
            }
        }
    }
    return sum;
}

// 3D average pooling kernel.
//
// Layout: `input` is indexed as a dense NCDHW float buffer of shape
// [N, C, D, H, W]. `output` is written as a dense buffer of shape
// [N, C, D_out, H_out, W_out]. The caller is responsible for passing
// contiguous memory.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Each thread processes
// output elements idx, idx + gridDim.x * blockDim.x, ... (grid-stride loop),
// so any grid size covers the whole output.
//
// Semantics: the window starts at out_pos * stride - padding in each spatial
// dim. Out-of-bounds positions contribute 0, and the sum is always divided by
// kernel_size^3, so padding counts toward the divisor (like PyTorch's
// count_include_pad=True).
//
// All flat indices are 64-bit, so outputs and inputs with 2^31 or more
// elements are addressed correctly.
__global__ void avg_pool3d_kernel(const float* __restrict__ input,
                                  float* __restrict__ output,
                                  const int N, const int C,
                                  const int D, const int H, const int W,
                                  const int D_out, const int H_out, const int W_out,
                                  const int kernel_size, const int stride, const int padding) {
    const int64_t total = static_cast<int64_t>(N) * C * D_out * H_out * W_out;
    const int64_t slice_size = static_cast<int64_t>(D) * H * W;
    const int64_t grid_step = static_cast<int64_t>(gridDim.x) * blockDim.x;
    // Divisor is the full window volume, padding included.
    const float pool_size = static_cast<float>(kernel_size * kernel_size * kernel_size);

    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total; idx += grid_step) {
        // Unflatten idx into (nc, d_out, h_out, w_out), where nc = n * C + c.
        int64_t rest = idx;
        const int w_o = static_cast<int>(rest % W_out);
        rest /= W_out;
        const int h_o = static_cast<int>(rest % H_out);
        rest /= H_out;
        const int d_o = static_cast<int>(rest % D_out);
        rest /= D_out;
        const int64_t nc = rest;

        // Start of this (n, c) volume in the input buffer, computed once per output.
        const float* slice = input + nc * slice_size;

        // First (possibly negative, i.e. in the padding) input coordinate of the window.
        const int d_start = d_o * stride - padding;
        const int h_start = h_o * stride - padding;
        const int w_start = w_o * stride - padding;

        // kernel_size is uniform across the launch, so this branch does not diverge.
        const float sum = (kernel_size == 3)
            ? avg_pool3d_window_sum<3>(slice, D, H, W, d_start, h_start, w_start, kernel_size)
            : avg_pool3d_window_sum<0>(slice, D, H, W, d_start, h_start, w_start, kernel_size);

        output[idx] = sum / pool_size;
    }
}

// Host entry point: 3D average pooling of a 5-D float32 CUDA tensor.
//
// x           : input of shape [N, C, D, H, W]; must be a float32 CUDA tensor.
//               It is made contiguous first because the kernel indexes the raw
//               buffer as dense NCDHW.
// kernel_size : cubic window edge, must be > 0.
// stride      : window step in every spatial dim, must be > 0.
// padding     : implicit zero padding on every side, must be >= 0.
//
// Returns a new tensor of shape [N, C, D_out, H_out, W_out] with
// X_out = floor((X + 2*padding - kernel_size) / stride) + 1. Padded positions
// count toward the divisor (kernel_size^3).
//
// Throws (via TORCH_CHECK) on invalid arguments, on a window that does not fit
// the padded input, or if the kernel launch fails. The kernel is enqueued on
// the current CUDA stream and is not synchronized here.
torch::Tensor forward(torch::Tensor x, int kernel_size, int stride, int padding) {
    TORCH_CHECK(x.is_cuda(), "avg_pool3d: input must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 5,
                "avg_pool3d: expected a 5-D [N, C, D, H, W] input, got ", x.dim(), " dims");
    TORCH_CHECK(x.scalar_type() == torch::kFloat,
                "avg_pool3d: only float32 input is supported");
    TORCH_CHECK(kernel_size > 0, "avg_pool3d: kernel_size must be positive, got ", kernel_size);
    TORCH_CHECK(stride > 0, "avg_pool3d: stride must be positive, got ", stride);
    TORCH_CHECK(padding >= 0, "avg_pool3d: padding must be non-negative, got ", padding);

    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel assumes a dense NCDHW buffer; strided views would be misread.
    x = x.contiguous();

    const int64_t N = x.size(0);
    const int64_t C = x.size(1);
    const int64_t D = x.size(2);
    const int64_t H = x.size(3);
    const int64_t W = x.size(4);

    // Floor-mode output size. Reject windows larger than the padded input;
    // C++ division truncates toward zero, so a negative numerator would
    // otherwise yield a bogus size of 1.
    auto out_dim = [&](int64_t in, const char* name) -> int64_t {
        TORCH_CHECK(in + 2 * static_cast<int64_t>(padding) >= kernel_size,
                    "avg_pool3d: kernel_size ", kernel_size, " exceeds padded input ",
                    name, " (", in, " + 2*", padding, ")");
        return (in + 2 * static_cast<int64_t>(padding) - kernel_size) / stride + 1;
    };
    const int64_t D_out = out_dim(D, "depth");
    const int64_t H_out = out_dim(H, "height");
    const int64_t W_out = out_dim(W, "width");

    auto output = torch::empty({N, C, D_out, H_out, W_out}, x.options());

    const int64_t total = output.numel();
    if (total == 0) {
        // Nothing to compute; a zero-block launch would be an invalid configuration.
        return output;
    }

    constexpr int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    TORCH_CHECK(blocks <= 2147483647LL,
                "avg_pool3d: output with ", total, " elements is too large for a 1-D launch");

    avg_pool3d_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                        c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(N), static_cast<int>(C),
        static_cast<int>(D), static_cast<int>(H), static_cast<int>(W),
        static_cast<int>(D_out), static_cast<int>(H_out), static_cast<int>(W_out),
        kernel_size, stride, padding
    );
    // Launch errors (bad configuration, etc.) are only reported through this call.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "avg_pool3d_kernel launch failed: ", cudaGetErrorString(launch_err));
    return output;
}

// Python binding: registers forward() as this extension module's "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("forward",
            &forward,
            "Custom 3D Average Pooling (CUDA) with loop unrolling for kernel_size==3");
}