2.65. Conv2d_AvgPool_Sigmoid_Sum
Computes a 2D convolution on the input, applies average pooling to reduce spatial dimensions, transforms the result with a sigmoid activation, and finally sums across all non-batch dimensions.
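For reference, the same computation can be expressed with standard PyTorch operators. The sketch below is a minimal unfused equivalent (the class and variable names are illustrative and not taken from the original model code):
# reference.py (illustrative unfused version)
import torch
import torch.nn as nn

class ConvAvgPoolSigmoidSumRef(nn.Module):
    """Unfused reference: Conv2d -> AvgPool2d -> sigmoid -> sum over non-batch dims."""
    def __init__(self, in_channels, out_channels, kernel_size=3, pool_kernel_size=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)  # valid padding, stride 1
        self.pool = nn.AvgPool2d(pool_kernel_size)                     # non-overlapping windows

    def forward(self, x):
        y = self.pool(self.conv(x))                  # [B, out_channels, H_pool, W_pool]
        return torch.sigmoid(y).sum(dim=[1, 2, 3])   # one scalar per batch element

# e.g. ConvAvgPoolSigmoidSumRef(3, 16)(torch.randn(8, 3, 32, 32)) has shape [8]
Each batch element reduces to a single scalar, which is exactly what the fused kernel below produces directly.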
Progress
Optimization Timeline
Step 2 Details
Reasoning
We can further optimize the fused kernel by specializing it for the common parameters kernel_size=3 and pool_kernel_size=2. The previous generic kernel used nested loops for both the pooling window and the convolution filter; for many vision models these parameters are fixed and small, so unrolling the inner loops of a 3×3 convolution and a 2×2 pooling window reduces loop overhead and gives the compiler more room to optimize. In the specialized path we:
• check on the host that kernel_size == 3 and pool_kernel_size == 2;
• unroll the two loops over the 2×2 pooling window (4 elements);
• unroll the two inner loops over the 3×3 convolution filter for each input channel.
When these conditions hold we dispatch the specialized kernel; otherwise we fall back to the generic implementation. We use c10::cuda::CUDAGuard and getCurrentCUDAStream() for proper device and stream management.
Have my improvements genuinely helped? Yes. Unrolling small loops for the common 3×3 convolution and 2×2 pooling reduces per-iteration overhead and improves throughput without cheating; it is a standard optimization. The Python module is unchanged; the updated CUDA kernel with both the specialized and generic versions follows below.
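The unchanged Python module itself is not reproduced in this step. As an illustrative stand-in only (the extension name, build call, and class layout below are assumptions, not the original file), a thin wrapper around the extension might look like this:
# model.py (illustrative sketch, not the original module)
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Hypothetical build step; the real extension name and build configuration may differ.
fused_ext = load(name="conv_avgpool_sigmoid_sum", sources=["kernel.cu"])

class Model(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, pool_kernel_size):
        super().__init__()
        # nn.Conv2d is used only as a container for the weight and bias parameters;
        # the actual convolution runs inside the fused CUDA kernel.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.pool_kernel_size = pool_kernel_size

    def forward(self, x):
        return fused_ext.forward(x.contiguous(),
                                 self.conv.weight.contiguous(),
                                 self.conv.bias.contiguous(),
                                 self.pool_kernel_size)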
// kernel.cu
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>

#define THREADS_PER_BLOCK 256
// Generic fused kernel for arbitrary kernel_size and pool_kernel_size.
__global__ void fused_kernel_generic(const float* __restrict__ x,
                                     const float* __restrict__ conv_weight,
                                     const float* __restrict__ conv_bias,
                                     float* __restrict__ output,
                                     int in_channels,
                                     int out_channels,
                                     int H, int W,
                                     int kernel_size,
                                     int pool_kernel_size) {
    int b = blockIdx.x;  // batch index
    // Convolution output dimensions.
    int H_conv = H - kernel_size + 1;
    int W_conv = W - kernel_size + 1;
    // Assume H_conv and W_conv are divisible by pool_kernel_size.
    int H_pool = H_conv / pool_kernel_size;
    int W_pool = W_conv / pool_kernel_size;
    int total_cells = out_channels * H_pool * W_pool;
    float local_sum = 0.0f;
    // Strided loop: each thread processes several pooled cells.
    for (int cell = threadIdx.x; cell < total_cells; cell += blockDim.x) {
        // Decode pooled cell index into output channel and pooling window location.
        int oc = cell / (H_pool * W_pool);
        int rem = cell % (H_pool * W_pool);
        int i_pool = rem / W_pool;
        int j_pool = rem % W_pool;
        int i_conv_start = i_pool * pool_kernel_size;
        int j_conv_start = j_pool * pool_kernel_size;
        float window_sum = 0.0f;
        // Loop over pooling window.
        for (int pr = 0; pr < pool_kernel_size; pr++) {
            for (int pc = 0; pc < pool_kernel_size; pc++) {
                int i_conv = i_conv_start + pr;
                int j_conv = j_conv_start + pc;
                float conv_val = conv_bias[oc];
                // Sum over input channels and kernel window.
                for (int ic = 0; ic < in_channels; ic++) {
                    for (int u = 0; u < kernel_size; u++) {
                        for (int v = 0; v < kernel_size; v++) {
                            int x_idx = b * (in_channels * H * W) + ic * (H * W) + (i_conv + u) * W + (j_conv + v);
                            int w_idx = oc * (in_channels * kernel_size * kernel_size) + ic * (kernel_size * kernel_size) + u * kernel_size + v;
                            conv_val += x[x_idx] * conv_weight[w_idx];
                        }
                    }
                }
                window_sum += conv_val;
            }
        }
        float pooled = window_sum / (pool_kernel_size * pool_kernel_size);
        float activated = 1.0f / (1.0f + expf(-pooled));
        local_sum += activated;
    }
    // Reduction across threads in the block.
    __shared__ float sdata[THREADS_PER_BLOCK];
    int tid = threadIdx.x;
    sdata[tid] = local_sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) {
        output[b] = sdata[0];
    }
}
// Specialized fused kernel for kernel_size==3 and pool_kernel_size==2.
__global__ void fused_kernel_specialized(const float* __restrict__ x,
                                         const float* __restrict__ conv_weight,
                                         const float* __restrict__ conv_bias,
                                         float* __restrict__ output,
                                         int in_channels,
                                         int out_channels,
                                         int H, int W) {
    int b = blockIdx.x;  // batch index
    // For kernel_size==3: H_conv = H - 2, W_conv = W - 2.
    int H_conv = H - 3 + 1;  // H - 2
    int W_conv = W - 3 + 1;  // W - 2
    // With pool_kernel_size==2.
    int pool_kernel_size = 2;
    int H_pool = H_conv / pool_kernel_size;
    int W_pool = W_conv / pool_kernel_size;
    int total_cells = out_channels * H_pool * W_pool;
    float local_sum = 0.0f;
    for (int cell = threadIdx.x; cell < total_cells; cell += blockDim.x) {
        // Decode cell index.
        int oc = cell / (H_pool * W_pool);
        int rem = cell % (H_pool * W_pool);
        int i_pool = rem / W_pool;
        int j_pool = rem % W_pool;
        int i_conv_start = i_pool * pool_kernel_size;
        int j_conv_start = j_pool * pool_kernel_size;
        float window_sum = 0.0f;
        // Unrolled loops for each element in the 2x2 pooling window.
        #pragma unroll
        for (int pr = 0; pr < 2; pr++) {
            #pragma unroll
            for (int pc = 0; pc < 2; pc++) {
                int i_conv = i_conv_start + pr;
                int j_conv = j_conv_start + pc;
                float conv_val = conv_bias[oc];
                // Loop over input channels.
                for (int ic = 0; ic < in_channels; ic++) {
                    // Calculate base indices for x and weight.
                    int base_x = b * (in_channels * H * W) + ic * (H * W);
                    int base_w = oc * (in_channels * 9) + ic * 9;
                    // Unrolled 3x3 convolution.
                    conv_val += x[base_x + (i_conv + 0) * W + (j_conv + 0)] * conv_weight[base_w + 0];
                    conv_val += x[base_x + (i_conv + 0) * W + (j_conv + 1)] * conv_weight[base_w + 1];
                    conv_val += x[base_x + (i_conv + 0) * W + (j_conv + 2)] * conv_weight[base_w + 2];
                    conv_val += x[base_x + (i_conv + 1) * W + (j_conv + 0)] * conv_weight[base_w + 3];
                    conv_val += x[base_x + (i_conv + 1) * W + (j_conv + 1)] * conv_weight[base_w + 4];
                    conv_val += x[base_x + (i_conv + 1) * W + (j_conv + 2)] * conv_weight[base_w + 5];
                    conv_val += x[base_x + (i_conv + 2) * W + (j_conv + 0)] * conv_weight[base_w + 6];
                    conv_val += x[base_x + (i_conv + 2) * W + (j_conv + 1)] * conv_weight[base_w + 7];
                    conv_val += x[base_x + (i_conv + 2) * W + (j_conv + 2)] * conv_weight[base_w + 8];
                }
                window_sum += conv_val;
            }
        }
        float pooled = window_sum / 4.0f;  // 2x2 pooling window
        float activated = 1.0f / (1.0f + expf(-pooled));
        local_sum += activated;
    }
    __shared__ float sdata[THREADS_PER_BLOCK];
    int tid = threadIdx.x;
    sdata[tid] = local_sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) {
        output[b] = sdata[0];
    }
}
// Forward function: selects the specialized kernel when kernel_size==3 and pool_kernel_size==2, else the generic one.
torch::Tensor forward(torch::Tensor x,
                      torch::Tensor conv_weight,
                      torch::Tensor conv_bias,
                      int pool_kernel_size) {
    c10::cuda::CUDAGuard device_guard(x.device());
    // x shape: [B, in_channels, H, W]
    int B = x.size(0);
    int H = x.size(2);
    int W = x.size(3);
    int kernel_size = conv_weight.size(2);  // assuming square kernel
    int in_channels = conv_weight.size(1);
    int out_channels = conv_weight.size(0);
    // Compute convolution output dimensions.
    int H_conv = H - kernel_size + 1;
    int W_conv = W - kernel_size + 1;
    // Check divisibility for pooling: both kernels assume the conv output tiles exactly.
    TORCH_CHECK(H_conv % pool_kernel_size == 0 && W_conv % pool_kernel_size == 0,
                "Convolution output dimensions must be divisible by pool_kernel_size");
    auto output = torch::empty({B}, x.options());
    dim3 grid(B);
    dim3 block(THREADS_PER_BLOCK);
    if (kernel_size == 3 && pool_kernel_size == 2) {
        fused_kernel_specialized<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(),
            conv_weight.data_ptr<float>(),
            conv_bias.data_ptr<float>(),
            output.data_ptr<float>(),
            in_channels,
            out_channels,
            H,
            W
        );
    } else {
        fused_kernel_generic<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(),
            conv_weight.data_ptr<float>(),
            conv_bias.data_ptr<float>(),
            output.data_ptr<float>(),
            in_channels,
            out_channels,
            H,
            W,
            kernel_size,
            pool_kernel_size
        );
    }
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused conv2d, avg pool, sigmoid, and sum reduction (with specialization)");
}