3.12. Deep Convolutional Classification
Performs image classification with a deep convolutional architecture: input images pass through five sequential blocks of convolution, activation, and pooling that extract features, which are then flattened and fed through fully connected layers to produce class predictions.
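The five-block structure determines the feature-map size that finally reaches the fully connected layers. The sketch below tracks that shape under assumptions the source does not state: each convolution preserves H and W ("same" padding), each block ends with 2x2 stride-2 pooling, and the channel counts are whatever the caller supplies. `Shape`, `shape_after_blocks`, and `flattened_features` are hypothetical names used only for illustration.

```cpp
#include <cassert>

// Hypothetical helper type: channel count and spatial size of a feature map.
struct Shape {
    int channels;
    int height;
    int width;
};

// Tracks the shape through `num_blocks` conv -> activation -> pool stages.
// Assumption: convolutions keep H and W, and every pool is 2x2 with stride 2,
// so each block halves both spatial dimensions.
Shape shape_after_blocks(Shape in, const int* out_channels, int num_blocks) {
    Shape s = in;
    for (int b = 0; b < num_blocks; ++b) {
        // A fused ReLU + 2x2 max-pool kernel needs even H and W at every block.
        assert(s.height % 2 == 0 && s.width % 2 == 0);
        s.channels = out_channels[b];  // the convolution sets the channel count
        s.height /= 2;                 // 2x2 pooling halves the height
        s.width /= 2;                  // ... and the width
    }
    return s;
}

// Number of inputs the first fully connected layer sees after flattening.
long long flattened_features(Shape s) {
    return static_cast<long long>(s.channels) * s.height * s.width;
}
```

For example, with a 3x224x224 input and assumed VGG-style channel counts {64, 128, 256, 512, 512}, the spatial size goes 224, 112, 56, 28, 14, 7, giving 512 x 7 x 7 = 25088 flattened features. Every size before a pooling step is even, which is exactly what the fused kernel's even-H/W check requires.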
Optimization Timeline
Step 5 Details
Reasoning
We can further improve our fused ReLU+MaxPool kernel by reducing memory latency. One easy optimization is to load each input value with the __ldg intrinsic, which routes the load through the read-only data cache and may reduce memory access overhead. Our previous implementation already fused two operations (ReLU and max pooling) into a single kernel, which significantly reduced runtime because the intermediate ReLU output is never written to or read back from global memory. Using __ldg for every input load may squeeze out a little more performance. Because the input pointer is already declared const and __restrict__, the compiler may emit read-only cache loads on its own, so the gain from the explicit intrinsic should be confirmed by profiling. The Python module remains unchanged: it simply calls the fused operator. This is a legitimate, standard optimization for read-only data on CUDA devices, not a shortcut that changes the computation. Our final updated code follows.
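The reasoning above credits fusion with most of the speedup. A rough global-memory traffic estimate shows why, under a simplifying assumption (not from the source) that each tensor is read or written exactly once per kernel and that caching between kernels is ignored. `Traffic` and `relu_pool_traffic` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Rough global-memory traffic (in bytes) for ReLU followed by 2x2 max pooling
// on a float32 tensor with `elems` = N*C*H*W elements.
struct Traffic {
    std::int64_t unfused;  // separate ReLU kernel, then separate pooling kernel
    std::int64_t fused;    // single fused ReLU+MaxPool kernel
};

Traffic relu_pool_traffic(std::int64_t elems) {
    constexpr std::int64_t kBytesPerFloat = 4;
    assert(elems % 4 == 0);  // even H and W make elems a multiple of 4
    const std::int64_t full = elems * kBytesPerFloat;          // one full-size tensor
    const std::int64_t pooled = (elems / 4) * kBytesPerFloat;  // 2x2 pooling keeps 1 of 4
    Traffic t;
    // Unfused: ReLU reads x and writes y; pooling reads y and writes z.
    t.unfused = full + full + full + pooled;
    // Fused: read x once and write z once; y never touches global memory.
    t.fused = full + pooled;
    return t;
}
```

Under this model the unfused pipeline moves 3.25 tensor-sizes of data and the fused kernel moves 1.25, about 2.6x less. Real speedups also depend on launch overhead and cache behavior, so this is an upper-level intuition rather than a prediction.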
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cuda_runtime.h>
#include <limits>
// Fused ReLU and 2x2 Max Pooling Kernel.
// This kernel assumes a contiguous float32 input tensor of shape [N, C, H, W]
// with even H and W. It applies ReLU (clamping negatives to 0) and then performs
// 2x2 max pooling with stride 2, producing an output tensor of shape
// [N, C, H/2, W/2]. Each thread computes one output element.
__global__ void fused_relu_maxpool_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int N, int C, int H, int W,
    int H_out, int W_out)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * C * H_out * W_out;
    if (index < total)
    {
        // Decode the flat output index into (n, c, h_out, w_out), NCHW order.
        int w_out = index % W_out;
        int tmp = index / W_out;
        int h_out = tmp % H_out;
        tmp /= H_out;
        int c = tmp % C;
        int n = tmp / C;
        // Top-left corner of the 2x2 input window.
        int h_in = h_out * 2;
        int w_in = w_out * 2;
        int in_base = n * (C * H * W) + c * (H * W);
        // Load values using __ldg to go through the read-only data cache.
        float v0 = __ldg(&input[in_base + h_in * W + w_in]);
        float v1 = __ldg(&input[in_base + h_in * W + (w_in + 1)]);
        float v2 = __ldg(&input[in_base + (h_in + 1) * W + w_in]);
        float v3 = __ldg(&input[in_base + (h_in + 1) * W + (w_in + 1)]);
        // Apply ReLU: clamp negatives to 0.
        v0 = fmaxf(v0, 0.0f);
        v1 = fmaxf(v1, 0.0f);
        v2 = fmaxf(v2, 0.0f);
        v3 = fmaxf(v3, 0.0f);
        // Compute the maximum of the 2x2 window.
        float max_val = fmaxf(fmaxf(v0, v1), fmaxf(v2, v3));
        output[index] = max_val;
    }
}
torch::Tensor fused_relu_maxpool(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32");
    TORCH_CHECK(input.dim() == 4, "Input must be 4D [N, C, H, W]");
    c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes the buffer as dense NCHW, so force a contiguous layout.
    input = input.contiguous();
    // The kernel uses 32-bit index arithmetic.
    TORCH_CHECK(input.numel() <= std::numeric_limits<int>::max(),
                "Input is too large for the kernel's 32-bit indexing");
    const int N = input.size(0);
    const int C = input.size(1);
    const int H = input.size(2);
    const int W = input.size(3);
    TORCH_CHECK(H % 2 == 0 && W % 2 == 0, "Height and Width must be even for 2x2 pooling");
    const int H_out = H / 2;
    const int W_out = W / 2;
    auto output = torch::empty({N, C, H_out, W_out}, input.options());
    const int total = N * C * H_out * W_out;
    if (total == 0) {
        return output;  // nothing to compute; a launch with 0 blocks would fail
    }
    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;
    fused_relu_maxpool_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, H, W,
        H_out, W_out);
    // Surface launch-configuration errors immediately instead of at the next sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_relu_maxpool", &fused_relu_maxpool, "Fused ReLU and 2x2 Max Pooling kernel");
}
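To check the kernel's output, a plain C++ reference can mirror its flat-index decoding on the host. The sketch below is a hypothetical helper (`relu_maxpool_reference` is not part of the original extension) that expects a contiguous NCHW float buffer with even H and W. It also illustrates a property the kernel does not exploit: because ReLU is monotonic, max(relu(a), ..., relu(d)) equals relu(max(a, ..., d)), so the four per-value clamps could in principle be replaced by a single clamp after the window max. That variant is an untested suggestion, not the author's code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for fused_relu_maxpool.
std::vector<float> relu_maxpool_reference(const std::vector<float>& input,
                                          int N, int C, int H, int W) {
    assert(H % 2 == 0 && W % 2 == 0);
    assert(input.size() == static_cast<std::size_t>(N) * C * H * W);
    const int H_out = H / 2;
    const int W_out = W / 2;
    std::vector<float> output(static_cast<std::size_t>(N) * C * H_out * W_out);
    for (std::size_t index = 0; index < output.size(); ++index) {
        // Same decoding as the kernel: flat index -> (n, c, h_out, w_out).
        std::size_t tmp = index;
        const int w_out = static_cast<int>(tmp % W_out);
        tmp /= W_out;
        const int h_out = static_cast<int>(tmp % H_out);
        tmp /= H_out;
        const int c = static_cast<int>(tmp % C);
        const int n = static_cast<int>(tmp / C);
        // Offsets of the top and bottom rows of the 2x2 window.
        const std::size_t base = (static_cast<std::size_t>(n) * C + c) * H * W;
        const std::size_t top = base + static_cast<std::size_t>(2 * h_out) * W + 2 * w_out;
        const std::size_t bottom = top + W;
        // Window max first, then a single ReLU clamp (equivalent by monotonicity).
        const float m = std::fmax(std::fmax(input[top], input[top + 1]),
                                  std::fmax(input[bottom], input[bottom + 1]));
        output[index] = std::fmax(m, 0.0f);
    }
    return output;
}
```

One way to use it is to copy the CUDA result to the host and compare it element-wise against this reference on random inputs; since both compute exact maxima, the values should match exactly.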