2.72. 3D Transposed Convolution, BatchNorm, and AvgPool Chain
Applies a 3D transposed convolution with configurable kernel parameters to the input tensor, normalizes the result with batch normalization, and then reduces the spatial dimensions with two successive average-pooling operations.
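For reference, the chain can be expressed with the libtorch C++ frontend as in the sketch below; the channel counts, kernel size, stride, and padding are illustrative placeholders, not the model's actual hyperparameters.
#include <torch/torch.h>
#include <iostream>
// Minimal sketch of the reference chain (hypothetical hyperparameters).
int main() {
  auto conv = torch::nn::ConvTranspose3d(
      torch::nn::ConvTranspose3dOptions(/*in=*/16, /*out=*/32, /*kernel=*/3).stride(2).padding(1));
  auto bn   = torch::nn::BatchNorm3d(32);
  auto pool = torch::nn::AvgPool3d(torch::nn::AvgPool3dOptions(2)); // kernel=2, stride defaults to kernel
  auto x = torch::randn({4, 16, 8, 8, 8});
  auto y = pool->forward(pool->forward(bn->forward(conv->forward(x))));
  std::cout << y.sizes() << std::endl; // D, H, W each shrink by a factor of 4 across the two pools (floor division)
  return 0;
}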
Optimization Timeline
Step 6 Details
Reasoning
The current implementation already achieves a measurable speedup by fusing the two AvgPool3d layers into a single kernel and by fully unrolling the fixed 4×4×4 pooling window with vectorized loads guarded by alignment checks. To squeeze out a few more cycles, one could consider micro-optimizations such as adding cache hints to the vectorized loads (for example, routing them through the read-only data cache) or reordering instructions to expose more instruction-level parallelism. In practice, such micro-optimizations are unlikely to yield a significant gain over the current version. The current design is simple, fully unrolled, minimizes branch overhead, and uses vectorized float4 loads with a scalar fallback for misaligned addresses, and it has already achieved a reward factor greater than 1. I will therefore present the current design as the final optimized answer.
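The fusion itself rests on the composition of averages: two successive AvgPool3d(kernel_size=2, stride=2) passes average disjoint 2×2×2 blocks of 2×2×2 block averages, which is exactly one average over the corresponding 4×4×4 block, hence the single 1/64 scale factor in the kernel. A small ATen sanity check of that equivalence (shapes chosen arbitrarily):
#include <torch/torch.h>
// Sanity check: two successive 2x2x2 average pools equal one 4x4x4 average pool,
// because (1/8) * (1/8) * (sum of 64 elements) == (1/64) * (sum of 64 elements).
int main() {
  auto x = torch::randn({2, 3, 16, 16, 16}); // arbitrary shape; spatial dims divisible by 4
  auto two_step = at::avg_pool3d(at::avg_pool3d(x, {2, 2, 2}), {2, 2, 2});
  auto one_step = at::avg_pool3d(x, {4, 4, 4});
  TORCH_CHECK(torch::allclose(two_step, one_step, /*rtol=*/1e-5, /*atol=*/1e-6), "pooling fusion mismatch");
  return 0;
}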
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h> // For c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
#include <stdint.h> // For uintptr_t
// Optimized fused average pooling kernel, fully unrolled for pool_size=4, stride=4.
// For each output element, the kernel averages a 4×4×4 block in the input tensor.
// For the innermost (width) dimension, it attempts a vectorized float4 load when the address is 16-byte aligned,
// and otherwise falls back to scalar loads.
__global__ void fused_avgpool_kernel(const float* __restrict__ input,
float* __restrict__ output,
int batch,
int channels,
int D_in, int H_in, int W_in,
int D_out, int H_out, int W_out) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int total = batch * channels * D_out * H_out * W_out;
if (index < total) {
// Decode index into (b, c, d_out, h_out, w_out)
int w_out = index % W_out;
int tmp = index / W_out;
int h_out = tmp % H_out;
tmp /= H_out;
int d_out = tmp % D_out;
tmp /= D_out;
int c = tmp % channels;
int b = tmp / channels;
// Starting indices in the input tensor for the pooling window.
int d_start = d_out * 4;
int h_start = h_out * 4;
int w_start = w_out * 4;
float sum = 0.0f;
// Fully unrolled code for each 4x4 row in depth and height.
{
// kd=0, kh=0
int offset = (((b * channels + c) * D_in + (d_start + 0)) * H_in + (h_start + 0)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF) == 0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=0, kh=1
int offset = (((b * channels + c) * D_in + (d_start + 0)) * H_in + (h_start + 1)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF) == 0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=0, kh=2
int offset = (((b * channels + c) * D_in + (d_start + 0)) * H_in + (h_start + 2)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=0, kh=3
int offset = (((b * channels + c) * D_in + (d_start + 0)) * H_in + (h_start + 3)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=1, kh=0
int offset = (((b * channels + c) * D_in + (d_start + 1)) * H_in + (h_start + 0)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=1, kh=1
int offset = (((b * channels + c) * D_in + (d_start + 1)) * H_in + (h_start + 1)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=1, kh=2
int offset = (((b * channels + c) * D_in + (d_start + 1)) * H_in + (h_start + 2)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=1, kh=3
int offset = (((b * channels + c) * D_in + (d_start + 1)) * H_in + (h_start + 3)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=2, kh=0
int offset = (((b * channels + c) * D_in + (d_start + 2)) * H_in + (h_start + 0)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=2, kh=1
int offset = (((b * channels + c) * D_in + (d_start + 2)) * H_in + (h_start + 1)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=2, kh=2
int offset = (((b * channels + c) * D_in + (d_start + 2)) * H_in + (h_start + 2)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=2, kh=3
int offset = (((b * channels + c) * D_in + (d_start + 2)) * H_in + (h_start + 3)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=3, kh=0
int offset = (((b * channels + c) * D_in + (d_start + 3)) * H_in + (h_start + 0)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=3, kh=1
int offset = (((b * channels + c) * D_in + (d_start + 3)) * H_in + (h_start + 1)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=3, kh=2
int offset = (((b * channels + c) * D_in + (d_start + 3)) * H_in + (h_start + 2)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
{
// kd=3, kh=3
int offset = (((b * channels + c) * D_in + (d_start + 3)) * H_in + (h_start + 3)) * W_in + w_start;
const float* ptr = &input[offset];
bool aligned = ((((uintptr_t)ptr) & 0xF)==0);
if(aligned) {
float4 v = *reinterpret_cast<const float4*>(ptr);
sum += v.x + v.y + v.z + v.w;
} else {
sum += ptr[0] + ptr[1] + ptr[2] + ptr[3];
}
}
output[index] = sum * (1.0f / 64.0f); // Average over the 4*4*4 = 64-element window.
}
}
torch::Tensor forward(torch::Tensor input, int pool_size, int stride) {
// This kernel is specialized for pool_size == 4 and stride == 4 and expects a
// contiguous float32 CUDA tensor of shape [batch, channels, D_in, H_in, W_in].
TORCH_CHECK(pool_size == 4 && stride == 4, "fused_avgpool_kernel requires pool_size == 4 and stride == 4");
TORCH_CHECK(input.is_cuda() && input.scalar_type() == torch::kFloat32, "input must be a float32 CUDA tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
c10::cuda::CUDAGuard device_guard(input.device());
// Input shape: [batch, channels, D_in, H_in, W_in]
int batch = input.size(0);
int channels = input.size(1);
int D_in = input.size(2);
int H_in = input.size(3);
int W_in = input.size(4);
// Compute output dimensions.
int D_out = (D_in - pool_size) / stride + 1;
int H_out = (H_in - pool_size) / stride + 1;
int W_out = (W_in - pool_size) / stride + 1;
auto output = torch::empty({batch, channels, D_out, H_out, W_out}, input.options());
int total = batch * channels * D_out * H_out * W_out;
int threads = 256;
int blocks = (total + threads - 1) / threads;
fused_avgpool_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<float>(),
output.data_ptr<float>(),
batch, channels,
D_in, H_in, W_in,
D_out, H_out, W_out
);
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused 3D average pooling with pool_size=4 and stride=4 (fully unrolled, vectorized with alignment check)");
}
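The extension can also be sanity-checked against ATen's own pooling. The function below is a hypothetical check, assuming it is compiled into the same translation unit as forward() above and that a CUDA device is available; it is not part of the extension interface.
// Hypothetical correctness check: compares the fused kernel against
// at::avg_pool3d with a 4x4x4 window on a random CUDA tensor.
void check_against_aten() {
  auto x = torch::randn({2, 8, 32, 32, 32},
                        torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto fused = forward(x, /*pool_size=*/4, /*stride=*/4);
  auto ref = at::avg_pool3d(x, {4, 4, 4});
  TORCH_CHECK(torch::allclose(fused, ref, /*rtol=*/1e-5, /*atol=*/1e-6), "fused kernel mismatch vs at::avg_pool3d");
}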