1.42. 2D Max Pooling
Performs a 2D max pooling operation that calculates the maximum value within each pooling window on a 4D input tensor. The computation uses adjustable parameters including kernel size, stride, padding, and dilation to control the pooling process.
Progress
Step 6 of 6
Reward
1.69x
Optimization Timeline
Step 6 Details
Reasoning
We can further improve the kernel by using the `__ldg` intrinsic for all global memory reads to take advantage of the read-only data cache. In the generic branch, we also add `#pragma unroll` hints to the pooling-window loops, since the pooling window is generally small. These changes are intended to reduce instruction overhead and memory latency. The Python module stays the same.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>
// Output extent of one pooled dimension, following the usual pooling formula
//   (size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
// evaluated with C integer division (which truncates toward zero).
// Callable from both host and device code; `stride` is used as a divisor.
__host__ __device__ inline int compute_out_dim(int size, int padding, int dilation, int kernel_size, int stride) {
    // Input extent once padding has been added on both sides.
    const int padded_size = size + 2 * padding;
    // Distance between the first and last tap of one dilated window.
    const int window_reach = dilation * (kernel_size - 1);
    // Largest start offset at which a whole window still fits.
    const int last_start = padded_size - window_reach - 1;
    return last_start / stride + 1;
}
// Computes the half-open range [first, last) of tap indices t in
// [0, kernel_size) for which origin + t * dilation lies inside [0, extent).
// `dilation` is used as a divisor and must be positive.
__device__ __forceinline__ void maxpool2d_valid_taps(int origin, int extent,
                                                     int kernel_size, int dilation,
                                                     int& first, int& last) {
    // Smallest t with origin + t * dilation >= 0 (ceil division when origin < 0).
    first = (origin < 0) ? (-origin + dilation - 1) / dilation : 0;
    // Largest valid t is floor((extent - 1 - origin) / dilation); if that is
    // negative the window starts past the end of the axis and the range is empty.
    const int room = extent - 1 - origin;
    if (room < 0) {
        last = 0;
    } else {
        const int limit = room / dilation + 1;
        last = (limit < kernel_size) ? limit : kernel_size;
    }
}

// 2D max-pooling forward kernel over a float buffer.
//
// Launch layout: 1D grid of 1D blocks. The thread with flat index
//   blockIdx.x * blockDim.x + threadIdx.x
// writes exactly one output element when that index is below
// batch * channels * outH * outW; surplus threads do nothing.
// No shared memory is used.
//
// Parameters:
//   input   - read-only buffer, indexed as ((n * channels + c) * H + h) * W + w.
//   output  - written at the flat thread index, i.e. decoded as
//             ((n * channels + c) * outH + h_out) * outW + w_out.
//   batch, channels, H, W, outH, outW - extents used for the indexing above.
//   kernel_size, stride, padding, dilation - pooling parameters, applied
//             identically to the H and W axes. The generic path divides by
//             `dilation`, so it must be positive.
//
// Behavior notes:
//   * Taps that fall outside [0, H) x [0, W) are skipped. The running maximum
//     starts at -infinity, so a window whose taps are all -inf (or that has no
//     in-bounds tap at all) yields -inf rather than the finite -FLT_MAX.
//   * Comparisons use `>`, so a NaN tap never replaces the running maximum;
//     only the first (h0, w0) tap of the kernel_size == 2 path is stored
//     unconditionally.
//   * Flat offsets are computed in 64-bit arithmetic so that inputs holding
//     more than 2^31 - 1 elements are addressed without integer overflow.
//   * All input reads go through __ldg (read-only cache load).
__global__ void maxpool2d_kernel(const float* __restrict__ input,
                                 float* __restrict__ output,
                                 int batch,
                                 int channels,
                                 int H, int W,
                                 int outH, int outW,
                                 int kernel_size,
                                 int stride,
                                 int padding,
                                 int dilation) {
    const long long index = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const long long total = static_cast<long long>(batch) * channels * outH * outW;
    if (index >= total) {
        return;
    }

    // Decode the flat index into (n, c, h_out, w_out).
    const int w_out = static_cast<int>(index % outW);
    long long rest = index / outW;
    const int h_out = static_cast<int>(rest % outH);
    rest /= outH;
    const int c = static_cast<int>(rest % channels);
    const int n = static_cast<int>(rest / channels);

    // Top-left corner of the pooling window in input coordinates (may be negative).
    const int h_start = h_out * stride - padding;
    const int w_start = w_out * stride - padding;

    // Start of the (n, c) plane; 64-bit so large tensors do not overflow.
    const float* plane = input + (static_cast<long long>(n) * channels + c) * H * W;

    float max_val = -INFINITY;

    if (kernel_size == 2) {
        // Specialized path: the four taps of a 2x2 window, checked individually.
        const int h0 = h_start;
        const int h1 = h_start + dilation;
        const int w0 = w_start;
        const int w1 = w_start + dilation;
        const bool h0_ok = (h0 >= 0 && h0 < H);
        const bool h1_ok = (h1 >= 0 && h1 < H);
        const bool w0_ok = (w0 >= 0 && w0 < W);
        const bool w1_ok = (w1 >= 0 && w1 < W);

        if (h0_ok && w0_ok) {
            // First tap is stored directly (no comparison needed yet).
            max_val = __ldg(plane + static_cast<long long>(h0) * W + w0);
        }
        if (h0_ok && w1_ok) {
            const float val = __ldg(plane + static_cast<long long>(h0) * W + w1);
            max_val = (val > max_val ? val : max_val);
        }
        if (h1_ok && w0_ok) {
            const float val = __ldg(plane + static_cast<long long>(h1) * W + w0);
            max_val = (val > max_val ? val : max_val);
        }
        if (h1_ok && w1_ok) {
            const float val = __ldg(plane + static_cast<long long>(h1) * W + w1);
            max_val = (val > max_val ? val : max_val);
        }
    } else {
        // Generic path: clamp the tap ranges once so the loops need no
        // per-element boundary checks.
        int i_first, i_last, j_first, j_last;
        maxpool2d_valid_taps(h_start, H, kernel_size, dilation, i_first, i_last);
        maxpool2d_valid_taps(w_start, W, kernel_size, dilation, j_first, j_last);

        #pragma unroll
        for (int i = i_first; i < i_last; ++i) {
            const float* row = plane + static_cast<long long>(h_start + i * dilation) * W;
            #pragma unroll
            for (int j = j_first; j < j_last; ++j) {
                const float val = __ldg(row + (w_start + j * dilation));
                if (val > max_val) {
                    max_val = val;
                }
            }
        }
    }

    output[index] = max_val;
}
// Host entry point: 2D max pooling of a CUDA float32 tensor.
//
// Parameters:
//   x           - 4D tensor laid out as [batch, channels, H, W]. Must live on a
//                 CUDA device and hold float32 data; a non-contiguous tensor is
//                 copied to a contiguous one before the kernel reads it.
//   kernel_size - side length of the square pooling window (> 0).
//   stride      - step between windows (> 0; it is used as a divisor).
//   padding     - implicit padding on each side of H and W (>= 0).
//   dilation    - spacing between window taps (> 0; the kernel divides by it).
//
// Returns a new tensor of shape [batch, channels, outH, outW] with the same
// options as `x`, where outH/outW come from compute_out_dim.
//
// Raises (via TORCH_CHECK) when the arguments are invalid, when the dilated
// window does not fit into the padded input, when the output has more elements
// than a 32-bit launch index can cover, or when the kernel launch is rejected.
torch::Tensor forward(torch::Tensor x,
                      int kernel_size,
                      int stride,
                      int padding,
                      int dilation) {
    constexpr int64_t kMaxInt32 = 2147483647;

    TORCH_CHECK(x.is_cuda(), "forward: input must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 4,
                "forward: expected a 4D [batch, channels, H, W] input, got ", x.dim(), "D");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32,
                "forward: only float32 input is supported");
    TORCH_CHECK(kernel_size > 0, "forward: kernel_size must be positive, got ", kernel_size);
    TORCH_CHECK(stride > 0, "forward: stride must be positive, got ", stride);
    TORCH_CHECK(dilation > 0, "forward: dilation must be positive, got ", dilation);
    TORCH_CHECK(padding >= 0, "forward: padding must be non-negative, got ", padding);
    for (int d = 0; d < 4; ++d) {
        TORCH_CHECK(x.size(d) <= kMaxInt32,
                    "forward: input dimension ", d, " is too large: ", x.size(d));
    }

    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes the buffer as dense NCHW, so strided views must be packed.
    x = x.contiguous();

    const int batch = static_cast<int>(x.size(0));
    const int channels = static_cast<int>(x.size(1));
    const int H = static_cast<int>(x.size(2));
    const int W = static_cast<int>(x.size(3));

    // Truncating integer division would report an output size of 1 for some
    // windows that do not fit at all, so reject those before computing sizes.
    const int64_t window = static_cast<int64_t>(dilation) * (kernel_size - 1) + 1;
    const int64_t padded_h = static_cast<int64_t>(H) + 2 * static_cast<int64_t>(padding);
    const int64_t padded_w = static_cast<int64_t>(W) + 2 * static_cast<int64_t>(padding);
    TORCH_CHECK(padded_h >= window && padded_w >= window,
                "forward: dilated window of size ", window,
                " does not fit into padded input of size ", padded_h, "x", padded_w);

    const int outH = compute_out_dim(H, padding, dilation, kernel_size, stride);
    const int outW = compute_out_dim(W, padding, dilation, kernel_size, stride);
    auto output = torch::empty({batch, channels, outH, outW}, x.options());

    const int64_t total = output.numel();
    if (total == 0) {
        // Nothing to compute; a zero-block launch would be an invalid configuration.
        return output;
    }
    TORCH_CHECK(total <= kMaxInt32,
                "forward: output with ", total, " elements is too large for this launch");

    const int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);
    maxpool2d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        batch, channels,
        H, W,
        outH, outW,
        kernel_size,
        stride,
        padding,
        dilation
    );

    // Launch-configuration errors are only reported through the runtime's
    // last-error state, so query it right after the launch.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "forward: maxpool2d_kernel launch failed: ", cudaGetErrorString(launch_status));

    return output;
}
// Python binding: registers the host-side `forward` function under the name
// "forward" on the extension module identified by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// The third argument is the docstring attached to the Python-visible function.
m.def("forward", &forward, "Fused MaxPool2D custom CUDA kernel with specialization for kernel_size==2 and optimized generic branch");
}