2.8. 3D Convolution with Pooling and Bias Summation
Performs a sequence of operations on volumetric data: a 3D convolution, division of the convolution output by a constant, 3D max pooling, global average pooling to collapse the spatial dimensions, addition of a bias term, and finally a sum along a specified dimension.
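For reference, everything after the convolution can be written as a plain host-side loop. The sketch below is a hypothetical C++ helper (`reference_forward` is our name, not part of the original code). It assumes a contiguous float input of shape (B, C, D, H, W), one bias value per channel, and non-overlapping pooling windows whose trailing remainder is dropped, as floor-mode 3D max pooling does. It returns one value per sample, shape (B); the original model's output carries extra size-1 axes, (B, 1, 1, 1).

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the post-convolution pipeline:
// divide -> max pool (non-overlapping windows) -> global average pool
// -> add per-channel bias -> sum over channels. Returns shape (B).
std::vector<float> reference_forward(const std::vector<float>& x,
                                     int B, int C, int D, int H, int W,
                                     float divisor,
                                     int pd, int ph, int pw,
                                     const std::vector<float>& bias) {
  assert(pd > 0 && ph > 0 && pw > 0);
  assert(static_cast<int>(bias.size()) == C);
  // Floor division drops a trailing remainder that does not fill a window.
  const int oD = D / pd, oH = H / ph, oW = W / pw;
  const int windows = oD * oH * oW;
  assert(windows > 0);
  std::vector<float> out(B, 0.0f);
  for (int b = 0; b < B; ++b) {
    for (int c = 0; c < C; ++c) {
      // Offset of the (b, c) slice in the contiguous input.
      const std::size_t base =
          (static_cast<std::size_t>(b) * C + c) * D * H * W;
      float window_sum = 0.0f;
      for (int i = 0; i < oD; ++i)
        for (int j = 0; j < oH; ++j)
          for (int k = 0; k < oW; ++k) {
            float m = -FLT_MAX;
            for (int dd = 0; dd < pd; ++dd)
              for (int hh = 0; hh < ph; ++hh)
                for (int ww = 0; ww < pw; ++ww) {
                  const std::size_t idx =
                      base + (static_cast<std::size_t>(i * pd + dd) * H +
                              (j * ph + hh)) * W + (k * pw + ww);
                  m = std::max(m, x[idx] / divisor);  // divide, then take max
                }
            window_sum += m;
          }
      // Global average of the pooled values, plus bias, summed over channels.
      out[b] += window_sum / windows + bias[c];
    }
  }
  return out;
}
```

Such a reference is mainly useful for checking the fused kernels on small inputs; it makes no attempt at performance.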
Step 1 Details
Reasoning
We design a custom fused CUDA operator to replace several chained operations in the original model's forward pass. The original model applies a 3D convolution (which we keep unchanged, using PyTorch's highly optimized operator), divides its output by a constant, applies 3D max pooling with a fixed window (pool_size), performs global average pooling (averaging the max-pooled output over all spatial locations), adds a bias term (broadcast along the spatial dimensions), and finally sums over a specific dimension (here, the channel dimension, dim 1). Our fused operator takes the convolution output tensor of shape (B, C, D, H, W), the constant divisor, the three pooling sizes (pool_d, pool_h, pool_w), and the bias tensor.

The first kernel fuses the element-wise division, max pooling, global average pooling, and bias addition. For each (b, c) slice, it partitions the spatial dimensions into non-overlapping windows of size pool_d × pool_h × pool_w, computes the maximum of each window after dividing by the divisor, and averages these maxima (i.e., global average pooling over the pooled output). It then adds the bias for channel c, producing an intermediate per-channel tensor of shape (B, C). The second kernel sums the per-channel values of each batch sample (dimension 1) to produce the final output of shape (B). Both kernels use shared memory within each block to perform parallel reductions. We use c10::cuda::CUDAGuard to ensure the kernels run on the input's device, and we compile with the extra CUDA flags -O3 and --use_fast_math. The fused operator is exposed through a PyTorch extension and called from the forward pass of our Python nn.Module, CUDAModel, right after the convolution.

--- FIX ---
The error indicates that our fused operator returns a tensor of shape [B] (e.g., [128]), while the expected output shape is [B, 1, 1, 1]. The original model keeps the collapsed spatial dimensions as size-1 axes, (B, C, 1, 1, 1), so summing over dim 1 removes only the channel axis.
Rather than modifying the CUDA kernel, we resolve this in the Python layer by reshaping the output. In CUDAModel.forward, after calling fused_ops.fused_forward, we reshape the result with view (or three successive unsqueeze calls) so that the final tensor has shape [batch_size, 1, 1, 1], matching the output of the original model.
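Both kernels reduce within a block in two phases: each thread first accumulates a strided subset of the work, then the per-thread partial sums in shared memory are combined by halving the active range at each step. The sketch below emulates that pattern on the CPU; `block_reduce_emulated` is a hypothetical helper, not part of the extension, and its inner loop over `tid` stands in for the threads of one block.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU emulation of the two-phase block reduction used in the kernels.
// Assumes num_threads is a power of two, which the halving loop requires.
float block_reduce_emulated(const std::vector<float>& values, int num_threads) {
  assert(num_threads > 0 && (num_threads & (num_threads - 1)) == 0);
  std::vector<float> sdata(num_threads, 0.0f);  // stands in for shared memory

  // Phase 1: "thread" tid handles indices tid, tid + num_threads, ...
  for (int tid = 0; tid < num_threads; ++tid)
    for (std::size_t idx = static_cast<std::size_t>(tid); idx < values.size();
         idx += static_cast<std::size_t>(num_threads))
      sdata[tid] += values[idx];

  // Phase 2: tree reduction; each outer iteration corresponds to one
  // __syncthreads()-separated step on the GPU.
  for (int s = num_threads >> 1; s > 0; s >>= 1)
    for (int tid = 0; tid < s; ++tid)
      sdata[tid] += sdata[tid + s];

  return sdata[0];
}
```

Within one halving step the threads with `tid < s` write only slots below `s` and read only slots in `[s, 2s)`, so running them sequentially gives the same result as running them in parallel. The power-of-two requirement matters: with, say, 6 threads the steps are s = 3 and s = 1, and slot 2 is never folded into slot 0. The 256-thread blocks chosen in `fused_forward` satisfy it.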
#include <torch/extension.h>
#include <pybind11/pybind11.h>
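#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK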
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>
namespace py = pybind11;
// Kernel 1: for each (b, c) slice, fuse the division by `divisor` with max pooling over
// non-overlapping windows, then average the pooled values and add the channel's bias.
__global__ void fused_pool_kernel(
const float* __restrict__ input, // input tensor of shape (B, C, D, H, W)
float* __restrict__ channel_out, // output tensor of shape (B, C)
int B, int C, int D, int H, int W,
int pool_d, int pool_h, int pool_w,
float divisor,
const float* __restrict__ bias) {
// Each block processes one (b, c) pair.
int b = blockIdx.x;
int c = blockIdx.y;
// Compute the output spatial dimensions for pooling.
int out_D = D / pool_d;
int out_H = H / pool_h;
int out_W = W / pool_w;
int num_windows = out_D * out_H * out_W;
// Each thread in the block will handle multiple pooling windows.
int tid = threadIdx.x;
int num_threads = blockDim.x;
float partial_sum = 0.0f;
// Iterate over pooling windows in a strided manner.
for (int idx = tid; idx < num_windows; idx += num_threads) {
int i = idx / (out_H * out_W);
int rem = idx % (out_H * out_W);
int j = rem / out_W;
int k = rem % out_W;
int d_start = i * pool_d;
int h_start = j * pool_h;
int w_start = k * pool_w;
float max_val = -FLT_MAX;
// Loop over the pooling window.
for (int dd = 0; dd < pool_d; dd++) {
for (int hh = 0; hh < pool_h; hh++) {
for (int ww = 0; ww < pool_w; ww++) {
int d_idx = d_start + dd;
int h_idx = h_start + hh;
int w_idx = w_start + ww;
// Flat index into the contiguous input tensor, layout (B, C, D, H, W).
// Computed in 64-bit so that large tensors do not overflow int.
int64_t index = (((static_cast<int64_t>(b) * C + c) * D + d_idx) * H + h_idx) * W + w_idx;
// Apply division and compute max.
float val = input[index] / divisor;
if (val > max_val)
max_val = val;
}
}
}
partial_sum += max_val;
}
// Reduction: sum partial sums across threads in the block.
extern __shared__ float sdata[];
sdata[tid] = partial_sum;
__syncthreads();
// Tree reduction in shared memory; requires blockDim.x to be a power of two.
for (unsigned int s = num_threads >> 1; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
// The first thread writes the averaged pooled value plus the bias.
if (tid == 0) {
float avg = sdata[0] / num_windows;
channel_out[b * C + c] = avg + bias[c];
}
}
// Kernel 2: For each batch sample, sum the per-channel values.
__global__ void sum_channels_kernel(
const float* __restrict__ channel_out, // tensor of shape (B, C)
float* __restrict__ output, // tensor of shape (B)
int B, int C) {
int b = blockIdx.x; // one block per sample.
int tid = threadIdx.x;
float sum_val = 0.0f;
// Sum over channels; since C is usually small, a simple loop is fine.
for (int c = tid; c < C; c += blockDim.x) {
sum_val += channel_out[b * C + c];
}
// Tree reduction in shared memory; requires blockDim.x to be a power of two.
extern __shared__ float sdata[];
sdata[tid] = sum_val;
__syncthreads();
for (unsigned int s = blockDim.x >> 1; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0) {
output[b] = sdata[0];
}
}
// fused_forward: the main function that launches the two kernels.
// It takes the input tensor from conv3d (shape: B, C, D, H, W),
// the divisor, pool dimensions, and the bias (as a flattened 1D tensor of length C).
// Returns a tensor of shape (B): the final result after summing over channels.
// The Python wrapper (CUDAModel.forward) reshapes it to (B, 1, 1, 1).
torch::Tensor fused_forward(
torch::Tensor input,
float divisor,
int pool_d, int pool_h, int pool_w,
torch::Tensor bias) {
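// Validate inputs: the kernels assume float32 CUDA tensors.
TORCH_CHECK(input.is_cuda() && bias.is_cuda(), "fused_forward: input and bias must be CUDA tensors");
TORCH_CHECK(input.dim() == 5, "fused_forward: input must have shape (B, C, D, H, W)");
TORCH_CHECK(input.scalar_type() == torch::kFloat32 && bias.scalar_type() == torch::kFloat32,
"fused_forward: only float32 tensors are supported");
// Ensure the operation runs on the correct device.
const c10::cuda::CUDAGuard device_guard(input.device());
// The raw-pointer indexing in the kernels requires contiguous memory.
input = input.contiguous();
bias = bias.contiguous();
// Get input dimensions.
int B = input.size(0);
int C = input.size(1);
int D = input.size(2);
int H = input.size(3);
int W = input.size(4);
TORCH_CHECK(bias.numel() == C, "fused_forward: bias must hold exactly one value per channel");
TORCH_CHECK(pool_d > 0 && pool_h > 0 && pool_w > 0 && D >= pool_d && H >= pool_h && W >= pool_w,
"fused_forward: each pooling window must be positive and fit inside the input");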
// Allocate temporary tensor for per-channel outputs.
auto channel_out = torch::empty({B, C}, input.options());
// Allocate output tensor for final results (sum over channels).
auto output = torch::empty({B}, input.options());
// Launch Kernel 1: fused division, max pooling, global averaging and bias addition.
// Grid dimensions: (B, C), one block per (b,c).
dim3 grid_pool(B, C);
// 256 threads per block; the shared-memory reduction requires a power of two.
int block_size_pool = 256;
size_t shared_mem_pool = block_size_pool * sizeof(float);
fused_pool_kernel<<<grid_pool, block_size_pool, shared_mem_pool, c10::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<float>(),
channel_out.data_ptr<float>(),
B, C, D, H, W,
pool_d, pool_h, pool_w,
divisor,
bias.data_ptr<float>()
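);
// Report launch errors (e.g., an invalid grid configuration) immediately.
C10_CUDA_KERNEL_LAUNCH_CHECK();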
// Launch Kernel 2: Sum over channels for each batch sample.
dim3 grid_sum(B);
int block_size_sum = 256;
size_t shared_mem_sum = block_size_sum * sizeof(float);
sum_channels_kernel<<<grid_sum, block_size_sum, shared_mem_sum, c10::cuda::getCurrentCUDAStream()>>>(
channel_out.data_ptr<float>(),
output.data_ptr<float>(),
B, C
);
// Report launch errors immediately.
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_forward", &fused_forward, "Fused division, max pooling, global average pooling, bias addition, and channel summation");
}
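The window indexing in `fused_pool_kernel` (turning a flat window id `idx` into window coordinates `(i, j, k)` and then into input offsets) can be sanity-checked on the host. The sketch below is a hypothetical helper, `visit_counts`, that replays that arithmetic for a single (b, c) slice and counts how often each input element is read; it assumes the same non-overlapping windows as the kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Replays fused_pool_kernel's window indexing for one (b, c) slice and
// returns, for each element of the D x H x W slice, how often it is read.
std::vector<int> visit_counts(int D, int H, int W, int pd, int ph, int pw) {
  assert(pd > 0 && ph > 0 && pw > 0);
  const int out_D = D / pd, out_H = H / ph, out_W = W / pw;
  const int num_windows = out_D * out_H * out_W;
  std::vector<int> counts(static_cast<std::size_t>(D) * H * W, 0);
  for (int idx = 0; idx < num_windows; ++idx) {
    // Same decomposition as the kernel: idx -> (i, j, k) window coordinates.
    const int i = idx / (out_H * out_W);
    const int rem = idx % (out_H * out_W);
    const int j = rem / out_W;
    const int k = rem % out_W;
    for (int dd = 0; dd < pd; ++dd)
      for (int hh = 0; hh < ph; ++hh)
        for (int ww = 0; ww < pw; ++ww) {
          const int d = i * pd + dd, h = j * ph + hh, w = k * pw + ww;
          ++counts[(static_cast<std::size_t>(d) * H + h) * W + w];
        }
  }
  return counts;
}
```

For a 4 × 4 × 4 slice with 2 × 2 × 2 windows every element is read exactly once; for a 3 × 2 × 2 slice the last depth plane lies outside the single window and is never read, which matches the floor behavior of `D / pool_d` in the kernel.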