3.7. Inception Forward Computation
Performs a forward pass through a deep convolutional architecture that extracts multi-scale features. Inside each Inception module it runs parallel branches of 1x1, 3x3, and 5x5 convolutions alongside a pooling branch, concatenates their outputs along the channel dimension, and then passes the result through additional convolution, pooling, and fully connected layers to produce the final classification scores.
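To make this structure concrete, here is a minimal PyTorch sketch of an Inception module with the four branches described above; the class name, layer names, and channel arguments are illustrative assumptions rather than the benchmark's exact model.

import torch
import torch.nn as nn

# Minimal sketch of a GoogLeNet-style Inception module (names and channel
# arguments are illustrative). Each branch preserves the spatial size so the
# outputs can be concatenated along the channel dimension.
class InceptionModuleSketch(nn.Module):
    def __init__(self, in_ch, c1x1, c3x3_red, c3x3, c5x5_red, c5x5, pool_proj):
        super().__init__()
        self.branch1 = nn.Conv2d(in_ch, c1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_ch, c3x3_red, kernel_size=1),
            nn.Conv2d(c3x3_red, c3x3, kernel_size=3, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_ch, c5x5_red, kernel_size=1),
            nn.Conv2d(c5x5_red, c5x5, kernel_size=5, padding=2),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_ch, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        # Four parallel branches whose outputs are concatenated on dim=1.
        return torch.cat(
            [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)],
            dim=1,
        )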
Optimization Timeline
Step 1 Details
Reasoning
The Inception module in GoogLeNet computes four parallel branches: a plain 1×1 convolution, a 1×1 reduction followed by a 3×3 convolution, a 1×1 reduction followed by a 5×5 convolution, and a max-pooling branch followed by a 1×1 projection. The branch outputs are concatenated along the channel dimension. Although the individual convolution operations are already highly optimized by cuDNN, the final concatenation (torch.cat) is launched as a separate kernel and adds extra memory traffic and kernel-launch overhead.

We can fuse the concatenation of the four branches into a single custom CUDA kernel that copies the already computed branch outputs into one contiguous output tensor. Our fused CUDA operator inception_concat takes four input tensors (one per branch) of shape [B, C_i, H, W], where H and W are identical across branches but the channel counts may differ, and writes an output tensor of shape [B, C_total, H, W] in a single pass. Each output element is handled by determining which branch it belongs to, computing the corresponding source index, and copying the value. We use c10::cuda::CUDAGuard to ensure proper device handling and launch the kernel with a grid-stride loop.

In the Python module we build a custom CUDAModel that replaces the original InceptionModule's concatenation logic with a call to the fused operator. This custom model keeps the architecture of the original GoogLeNet/Inception network and replaces only the torch.cat operation inside the Inception modules; all other operations remain unchanged. The approach helps by reducing the number of separate kernel launches and streamlining memory accesses for the concatenation step, a small but non-negligible overhead in such a complex architecture. We now present the final Python module and CUDA kernel.
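As a minimal sketch of the Python-side replacement, the forward below reuses InceptionModuleSketch from above and swaps torch.cat for the fused operator; the class name and the fused_ops import are illustrative assumptions rather than the exact final module.

# Hypothetical import name for the compiled extension; the actual name depends
# on how the extension is built (see the usage example after the CUDA source).
import fused_ops


class CUDAInceptionModuleSketch(InceptionModuleSketch):
    def forward(self, x):
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        # One fused kernel replaces torch.cat([b1, b2, b3, b4], dim=1).
        return fused_ops.inception_concat(
            b1.contiguous(), b2.contiguous(), b3.contiguous(), b4.contiguous()
        )

The CUDA source for the fused operator follows.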
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
// This kernel performs fused concatenation along the channel dimension
// for four input tensors of shape [B, C_i, H, W].
// It writes into an output tensor of shape [B, C_total, H, W],
// where C_total = C1 + C2 + C3 + C4.
__global__ void inception_concat_kernel(
    const float* __restrict__ inp1, int c1,
    const float* __restrict__ inp2, int c2,
    const float* __restrict__ inp3, int c3,
    const float* __restrict__ inp4, int c4,
    float* __restrict__ out,
    int B, int H, int W) {
  int total_channels = c1 + c2 + c3 + c4;
  int total = B * total_channels * H * W;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Grid-stride loop over every element of the output tensor.
  for (int index = idx; index < total; index += blockDim.x * gridDim.x) {
    // Decompose the flat output index into (b, c, h, w) coordinates.
    int tmp = index;
    int w = tmp % W;
    tmp /= W;
    int h = tmp % H;
    tmp /= H;
    int c = tmp % total_channels;
    int b = tmp / total_channels;
    // Determine which input tensor holds this channel, together with the
    // local channel index and that input's channel count.
    const float* src = nullptr;
    int local_c = 0;
    int src_channels = 0;
    if (c < c1) {
      src = inp1; local_c = c; src_channels = c1;
    } else if (c < c1 + c2) {
      src = inp2; local_c = c - c1; src_channels = c2;
    } else if (c < c1 + c2 + c3) {
      src = inp3; local_c = c - (c1 + c2); src_channels = c3;
    } else {
      src = inp4; local_c = c - (c1 + c2 + c3); src_channels = c4;
    }
    // Compute the source index: each input is a contiguous [B, C_i, H, W] tensor.
    int src_index = b * (src_channels * H * W) + local_c * (H * W) + h * W + w;
    out[index] = src[src_index];
  }
}
torch::Tensor inception_concat(torch::Tensor inp1, torch::Tensor inp2,
                               torch::Tensor inp3, torch::Tensor inp4) {
  // Make sure the kernel is launched on the device that holds the inputs.
  c10::cuda::CUDAGuard device_guard(inp1.device());
  // Validate layout: all inputs must be contiguous float32 CUDA tensors of shape [B, C, H, W].
  TORCH_CHECK(inp1.is_cuda() && inp2.is_cuda() && inp3.is_cuda() && inp4.is_cuda(),
              "all inputs must be CUDA tensors");
  TORCH_CHECK(inp1.dim() == 4 && inp2.dim() == 4 && inp3.dim() == 4 && inp4.dim() == 4,
              "all inputs must be 4D tensors");
  TORCH_CHECK(inp1.is_contiguous() && inp2.is_contiguous() && inp3.is_contiguous() && inp4.is_contiguous(),
              "all inputs must be contiguous");
  TORCH_CHECK(inp1.scalar_type() == torch::kFloat && inp2.scalar_type() == torch::kFloat &&
              inp3.scalar_type() == torch::kFloat && inp4.scalar_type() == torch::kFloat,
              "all inputs must be float32");
  int B = inp1.size(0);
  int H = inp1.size(2);
  int W = inp1.size(3);
  TORCH_CHECK(inp2.size(0) == B && inp3.size(0) == B && inp4.size(0) == B, "Batch size mismatch");
  TORCH_CHECK(inp2.size(2) == H && inp3.size(2) == H && inp4.size(2) == H, "Height mismatch");
  TORCH_CHECK(inp2.size(3) == W && inp3.size(3) == W && inp4.size(3) == W, "Width mismatch");
  int c1 = inp1.size(1);
  int c2 = inp2.size(1);
  int c3 = inp3.size(1);
  int c4 = inp4.size(1);
  int total_channels = c1 + c2 + c3 + c4;
  // Allocate the output tensor of shape [B, total_channels, H, W].
  auto output = torch::empty({B, total_channels, H, W}, inp1.options());
  // One thread per output element, covered by a grid-stride loop in the kernel.
  int total = B * total_channels * H * W;
  int threads = 256;
  int blocks = (total + threads - 1) / threads;
  inception_concat_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
      inp1.data_ptr<float>(), c1,
      inp2.data_ptr<float>(), c2,
      inp3.data_ptr<float>(), c3,
      inp4.data_ptr<float>(), c4,
      output.data_ptr<float>(),
      B, H, W);
  return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("inception_concat", &inception_concat, "Fused concatenation for Inception module branches");
}
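One way to build and sanity-check the extension is with torch.utils.cpp_extension, assuming the source above is saved as inception_concat.cu; the file name, module name, and channel counts below are assumptions for this example.

import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source above (file and module names are assumptions).
fused_ops = load(name="fused_ops", sources=["inception_concat.cu"], verbose=True)

# Compare the fused kernel against torch.cat on random branch-shaped inputs.
B, H, W = 8, 28, 28
branches = [torch.randn(B, c, H, W, device="cuda") for c in (64, 128, 32, 32)]
fused = fused_ops.inception_concat(*branches)
reference = torch.cat(branches, dim=1)
assert torch.equal(fused, reference)

Because the kernel performs a pure element copy, its output should match torch.cat exactly rather than only approximately.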