3.13. DenseNet Transition Operation
Implements the DenseNet transition operation: it normalizes the input feature maps, applies a ReLU activation, performs a 1×1 convolution to adjust the channel dimension, and then downsamples the spatial resolution with 2×2 average pooling.
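For reference, a minimal PyTorch sketch of the unfused transition block that the fused kernel reproduces is shown below. The class and argument names (TransitionModel, num_input_features, num_output_features) are illustrative assumptions, not taken from the original module.

import torch
import torch.nn as nn

# Reference (unfused) DenseNet transition block; names are illustrative.
class TransitionModel(nn.Module):
    def __init__(self, num_input_features: int, num_output_features: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(num_input_features)     # normalize feature maps
        self.relu = nn.ReLU(inplace=True)                   # activation
        self.conv = nn.Conv2d(num_input_features, num_output_features,
                              kernel_size=1, bias=False)    # 1x1 conv adjusts channels
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)   # halve H and W

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(self.conv(self.relu(self.norm(x))))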
Optimization Timeline
Step 3 Details
Reasoning
Our goal is to further optimize the fused kernel that implements BatchNorm2d, ReLU, 1×1 Conv2d, and AvgPool2d. In the earlier version we precomputed the BN “scale” and “shift” to avoid repeatedly computing sqrt and division, but the inner loop still iterated over the 2×2 pooling window inside the loop over input channels. Since the pooling window is a constant 2×2, we unroll it manually; this removes the overhead of the two nested pooling loops and improves register reuse. Caching the common offset computations further reduces arithmetic overhead, and we use __ldg to load read-only values such as scale, shift, and the convolution weights.

Each thread produces one output pixel (i, j) of the pooled result. We decode b, oc, i, j from the linear thread index, compute the base indices of the 2×2 pooling window, and in a single loop over input channels (ic) accumulate the 1×1 convolution result for each of the four pooling positions (each position derives from one input pixel). After the loop we sum the four partial results and multiply by 0.25f to obtain the averaged output value. A "#pragma unroll" directive on the channel loop suggests unrolling to the compiler where possible.

These changes should reduce per-output computation time and lower overall runtime. The Python module keeps the same interface; only the CUDA kernel below is updated.
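For reference, a minimal host-side sketch of how the BN “scale” and “shift” could be folded from the BatchNorm2d buffers before the kernel launch (eval-mode statistics assumed; the helper name precompute_bn_scale_shift is illustrative):

import torch

def precompute_bn_scale_shift(bn: torch.nn.BatchNorm2d):
    # Fold BatchNorm into a per-channel affine transform y = scale * x + shift:
    #   scale = gamma / sqrt(running_var + eps)
    #   shift = beta - running_mean * scale
    inv_std = torch.rsqrt(bn.running_var + bn.eps)
    scale = bn.weight * inv_std
    shift = bn.bias - bn.running_mean * scale
    return scale.contiguous(), shift.contiguous()

This is the same folding implied by the kernel's s * x + sh expression; it only needs to be recomputed when the BN statistics or affine parameters change.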
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
// Fused kernel performs:
// For each output pixel (b, oc, i, j) in the avg-pooled result:
// it evaluates the fused BN+ReLU+1x1-conv result at the four input positions of its 2x2 pooling window.
// It uses precomputed scale and shift arrays to apply BatchNorm, and then applies ReLU,
// multiplies by conv_weight (for a 1x1 conv), and averages over a 2x2 window.
// Input: x shape [B, C_in, H, W]
// BN parameters: scale, shift of shape [C_in]
// Conv weight: conv_weight shape [C_out, C_in]
// Output shape: [B, C_out, H/2, W/2]
__global__ void fused_kernel(const float* __restrict__ x,
                             const float* __restrict__ scale,        // [C_in]
                             const float* __restrict__ shift,        // [C_in]
                             const float* __restrict__ conv_weight,  // [C_out, C_in]
                             float* __restrict__ out,
                             int B, int C_in, int H, int W, int C_out) {
    int H_out = H / 2;
    int W_out = W / 2;
    // Each thread computes one output element at (b, oc, i, j)
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = B * C_out * H_out * W_out;
    if (index >= total) return;
    // Decode linear index to (b, oc, i, j)
    int j = index % W_out;
    int tmp = index / W_out;
    int i = tmp % H_out;
    tmp = tmp / H_out;
    int oc = tmp % C_out;
    int b = tmp / C_out;
    // Compute base indices for pooling window (2x2)
    int base_h = 2 * i;
    int base_w = 2 * j;
    float conv_sum0 = 0.0f;
    float conv_sum1 = 0.0f;
    float conv_sum2 = 0.0f;
    float conv_sum3 = 0.0f;
    // Precompute base offset for batch b in input x.
    int base_offset = b * C_in * H * W;
    // Loop over input channels; unroll if possible.
    #pragma unroll
    for (int ic = 0; ic < C_in; ic++) {
        // Compute offset for channel ic.
        int offset = base_offset + ic * H * W;
        // Load scale and shift once from read-only cache.
        float s = __ldg(&scale[ic]);
        float sh = __ldg(&shift[ic]);
        // Load conv weight for current output channel and ic.
        int weight_index = oc * C_in + ic;
        float w_val = __ldg(&conv_weight[weight_index]);
        // For pooling window:
        // Position (base_h, base_w)
        int idx0 = offset + base_h * W + base_w;
        float x0 = x[idx0];
        float bn0 = s * x0 + sh;
        float relu0 = bn0 > 0.0f ? bn0 : 0.0f;
        conv_sum0 += w_val * relu0;
        // Position (base_h, base_w+1)
        int idx1 = offset + base_h * W + (base_w + 1);
        float x1 = x[idx1];
        float bn1 = s * x1 + sh;
        float relu1 = bn1 > 0.0f ? bn1 : 0.0f;
        conv_sum1 += w_val * relu1;
        // Position (base_h+1, base_w)
        int idx2 = offset + (base_h + 1) * W + base_w;
        float x2 = x[idx2];
        float bn2 = s * x2 + sh;
        float relu2 = bn2 > 0.0f ? bn2 : 0.0f;
        conv_sum2 += w_val * relu2;
        // Position (base_h+1, base_w+1)
        int idx3 = offset + (base_h + 1) * W + (base_w + 1);
        float x3 = x[idx3];
        float bn3 = s * x3 + sh;
        float relu3 = bn3 > 0.0f ? bn3 : 0.0f;
        conv_sum3 += w_val * relu3;
    }
    // Sum pooling results from 4 positions and average.
    float avg = (conv_sum0 + conv_sum1 + conv_sum2 + conv_sum3) * 0.25f;
    // Write output at (b, oc, i, j)
    int out_index = ((b * C_out + oc) * H_out + i) * W_out + j;
    out[out_index] = avg;
}
torch::Tensor forward(torch::Tensor x,
                      torch::Tensor scale,
                      torch::Tensor shift,
                      torch::Tensor conv_weight) {
    // Ensure execution on the proper device.
    c10::cuda::CUDAGuard device_guard(x.device());
    int B = x.size(0);
    int C_in = x.size(1);
    int H = x.size(2);
    int W = x.size(3);
    int C_out = conv_weight.size(0);
    int H_out = H / 2;
    int W_out = W / 2;
    auto output = torch::empty({B, C_out, H_out, W_out}, x.options());
    int total = B * C_out * H_out * W_out;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    fused_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        scale.data_ptr<float>(),
        shift.data_ptr<float>(),
        conv_weight.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C_in, H, W, C_out
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused BatchNorm2d (with precomputed scale/shift), ReLU, 1x1 Conv2d, and AvgPool2d");
}
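As a usage sketch (not part of the original module), the extension could be built with torch.utils.cpp_extension.load and checked against the unfused reference; the file name fused_transition.cu, the module name, and the test sizes are assumptions.

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Build the extension from the CUDA source above (file name is an assumption).
fused = load(name="fused_transition", sources=["fused_transition.cu"])

B, C_in, C_out, H, W = 4, 64, 32, 56, 56
x = torch.randn(B, C_in, H, W, device="cuda")
bn = torch.nn.BatchNorm2d(C_in).cuda().eval()
conv = torch.nn.Conv2d(C_in, C_out, kernel_size=1, bias=False).cuda()

# Fold BN into per-channel scale/shift (eval-mode statistics).
inv_std = torch.rsqrt(bn.running_var + bn.eps)
scale = (bn.weight * inv_std).contiguous()
shift = (bn.bias - bn.running_mean * scale).contiguous()

weight = conv.weight.view(C_out, C_in).contiguous()  # 1x1 conv weights as [C_out, C_in]
out_fused = fused.forward(x, scale, shift, weight)

with torch.no_grad():
    ref = F.avg_pool2d(conv(torch.relu(bn(x))), kernel_size=2)
print(torch.allclose(out_fused, ref, atol=1e-4))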