1.44. 1D Average Pooling

Calculates a one-dimensional average pooling operation by computing the mean of values within a sliding window, using configurable kernel size, stride, and padding to aggregate portions of the input sequence.

Progress

Step 3 of 3
Reward
3.08x

Optimization Timeline

Step 3 Details

Reasoning

We can further optimize the 1D average pooling kernel by reducing division overhead and utilizing the read-only __ldg intrinsic in the valid-window branch. In most cases the pooling window is fully valid (i.e. completely within the input bounds), so we can precompute the reciprocal of kernel_size once and use it to avoid an expensive division for each output element. In the valid branch, we also use __ldg to fetch the input values from read-only cache. For boundary cases we still need to check each element. These optimizations reduce per-thread overhead and improve memory throughput. This approach genuinely speeds up the operation by avoiding redundant computations and utilizing hardware-friendly memory access, rather than just working around the operation.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Optimized CUDA kernel for 1D average pooling.
__global__ void avg_pool1d_kernel(const float* __restrict__ input,
                                  float* __restrict__ output,
                                  int batch_size,
                                  int channels,
                                  int input_length,
                                  int output_length,
                                  int kernel_size,
                                  int stride,
                                  int padding) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch_size * channels * output_length;
    if (index < total) {
        // Map linear index to (batch, channel, out_idx)
        int tmp = index;
        int out_idx = tmp % output_length;
        tmp /= output_length;
        int channel = tmp % channels;
        int batch = tmp / channels;

        // Compute starting index for the pooling window.
        int start = out_idx * stride - padding;
        int input_base = batch * channels * input_length + channel * input_length;
        float sum = 0.0f;
        float invKernel = 1.0f / kernel_size;

        // If the entire pooling window is within bounds, use optimized load with __ldg.
        if (start >= 0 && (start + kernel_size) <= input_length) {
            #pragma unroll
            for (int k = 0; k < kernel_size; k++) {
                sum += __ldg(&input[input_base + start + k]);
            }
        } else {
            // Otherwise, check each element for boundary conditions.
            for (int k = 0; k < kernel_size; k++) {
                int in_idx = start + k;
                if (in_idx >= 0 && in_idx < input_length) {
                    sum += input[input_base + in_idx];
                }
            }
        }
        output[index] = sum * invKernel;
    }
}

torch::Tensor forward(torch::Tensor x, int kernel_size, int stride, int padding) {
    // Enforce correct CUDA device usage.
    c10::cuda::CUDAGuard device_guard(x.device());

    // Get dimensions: x shape is [batch_size, channels, input_length]
    int batch_size = x.size(0);
    int channels = x.size(1);
    int input_length = x.size(2);

    // Calculate output length: floor((input_length + 2*padding - kernel_size) / stride) + 1.
    int output_length = (input_length + 2 * padding - kernel_size) / stride + 1;

    // Allocate output tensor.
    auto output = torch::empty({batch_size, channels, output_length}, x.options());

    int total = batch_size * channels * output_length;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    avg_pool1d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        channels,
        input_length,
        output_length,
        kernel_size,
        stride,
        padding
    );
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "1D average pooling (optimized CUDA kernel with reciprocal division and __ldg)");
}