3.6. Inception Module Operation
Computes multi-scale features by applying four parallel branches to the same input: a 1x1 convolution; a 1x1 convolution followed by a 3x3 convolution; a 1x1 convolution followed by a 5x5 convolution; and a max pooling followed by a 1x1 convolution. The branch outputs are then concatenated along the channel dimension into a single feature tensor.
Progress
Optimization Timeline
Step 2 Details
Reasoning
Our original fused Inception concatenation kernel iterates over every output element and copies it from the corresponding input branch. Although it is correct, it moves only one float per thread. Since the spatial size (H×W) is large and typically divisible by 4, we can vectorize the kernel with float4 loads and stores. This reduces the number of load/store instructions (and threads) by a factor of 4 and can improve effective memory throughput. We do this by reinterpreting the input and output pointers as float4 pointers, which requires every pointer to be 16-byte aligned. We then adjust the indexing. For each sample, the total number of float elements is C_total × H × W, and the number of float4 elements is that count divided by 4. For each float4 element we compute its position in the fused output and determine which branch it belongs to. Because each branch's per-sample block (C_i × H × W floats) is a multiple of 4 whenever H×W is, no float4 straddles two branches. We then copy the whole float4 from that branch. The host code checks whether (H×W) % 4 == 0 and, if so, launches the vectorized kernel; otherwise it falls back to the original scalar kernel. Processing four floats per thread should reduce per-element overhead and improve performance compared with copying each float individually. The Python module is unchanged apart from calling the improved kernel.
#include <algorithm>
#include <cstdint>
#include <limits>

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
// Scalar (non-vectorized) fused concatenation kernel.
//
// Concatenates four NCHW float tensors along the channel axis into `out`,
// laid out as N x (C1+C2+C3+C4) x H x W. Each thread copies one output float
// per iteration of a grid-stride loop, so any 1-D launch configuration
// (including a single block) covers all N*C_total*H*W elements.
//
// All index arithmetic is 64-bit so tensors with more than 2^31 elements do
// not overflow.
//
// Preconditions (not checked here): in1..in4 and `out` are contiguous NCHW
// buffers on the current device, and all inputs share N, H and W.
__global__ void fused_cat_kernel(const float* __restrict__ in1, int C1,
                                 const float* __restrict__ in2, int C2,
                                 const float* __restrict__ in3, int C3,
                                 const float* __restrict__ in4, int C4,
                                 float* __restrict__ out,
                                 int N, int H, int W) {
  const int64_t HW = static_cast<int64_t>(H) * W;
  const int64_t C_total = static_cast<int64_t>(C1) + C2 + C3 + C4;
  const int64_t sample_size = C_total * HW;
  const int64_t total = static_cast<int64_t>(N) * sample_size;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

  for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < total; index += stride) {
    // Decompose the flat output index into (n, c, pos).
    const int64_t n = index / sample_size;
    const int64_t rem = index - n * sample_size;
    const int64_t c = rem / HW;
    const int64_t pos = rem - c * HW;

    // Pick the source branch for channel c and rebase c into that branch.
    const float* src;
    int64_t c_src;
    int64_t C_src;
    if (c < C1) {
      src = in1; c_src = c; C_src = C1;
    } else if (c < C1 + C2) {
      src = in2; c_src = c - C1; C_src = C2;
    } else if (c < C1 + C2 + C3) {
      src = in3; c_src = c - C1 - C2; C_src = C3;
    } else {
      src = in4; c_src = c - C1 - C2 - C3; C_src = C4;
    }
    out[index] = src[(n * C_src + c_src) * HW + pos];
  }
}
// Vectorized (float4) fused concatenation kernel.
//
// Produces the same result as the scalar concatenation kernel, but each thread
// moves four consecutive floats per iteration of a grid-stride loop, so any
// 1-D launch configuration covers all N*C_total*H*W/4 float4 elements.
// Index arithmetic is 64-bit to avoid overflow on large tensors.
//
// Preconditions (not checked here; the caller must guarantee them):
//   * H*W is divisible by 4. Then every branch's per-sample block
//     (C_i*H*W floats) is a whole number of float4s, so no float4 straddles
//     two branches.
//   * in1..in4 and `out` are contiguous NCHW buffers whose base pointers are
//     16-byte aligned, as float4 loads/stores require.
__global__ void fused_cat_vec_kernel(const float* __restrict__ in1, int C1,
                                     const float* __restrict__ in2, int C2,
                                     const float* __restrict__ in3, int C3,
                                     const float* __restrict__ in4, int C4,
                                     float* __restrict__ out,
                                     int N, int H, int W) {
  // float4s per channel, and per sample in each branch.
  const int64_t HW4 = (static_cast<int64_t>(H) * W) / 4;
  const int64_t vec1 = C1 * HW4;
  const int64_t vec2 = C2 * HW4;
  const int64_t vec3 = C3 * HW4;
  const int64_t vec4 = C4 * HW4;
  const int64_t sample_vec_count = vec1 + vec2 + vec3 + vec4;
  const int64_t total_vec = static_cast<int64_t>(N) * sample_vec_count;

  // Loop-invariant reinterpretation of the base pointers.
  const float4* in1_vec = reinterpret_cast<const float4*>(in1);
  const float4* in2_vec = reinterpret_cast<const float4*>(in2);
  const float4* in3_vec = reinterpret_cast<const float4*>(in3);
  const float4* in4_vec = reinterpret_cast<const float4*>(in4);
  float4* out_vec = reinterpret_cast<float4*>(out);

  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total_vec; i += stride) {
    const int64_t n = i / sample_vec_count;
    const int64_t v = i - n * sample_vec_count;  // float4 index within sample n.
    float4 value;
    if (v < vec1) {
      value = in1_vec[n * vec1 + v];
    } else if (v < vec1 + vec2) {
      value = in2_vec[n * vec2 + (v - vec1)];
    } else if (v < vec1 + vec2 + vec3) {
      value = in3_vec[n * vec3 + (v - vec1 - vec2)];
    } else {
      value = in4_vec[n * vec4 + (v - vec1 - vec2 - vec3)];
    }
    out_vec[i] = value;
  }
}
// Concatenates four NCHW float32 CUDA tensors along the channel dimension.
//
// Args:
//   in1..in4: 4-D float32 tensors on the same CUDA device with identical
//             batch (N) and spatial (H, W) sizes; channel counts may differ.
// Returns:
//   A new contiguous tensor of shape (N, C1+C2+C3+C4, H, W), equal to
//   concatenating the inputs along dim 1.
// Throws:
//   c10::Error (via TORCH_CHECK) on invalid inputs or a failed kernel launch.
//
// Strided inputs are materialized with .contiguous() first, because the
// kernels index raw NCHW memory. The float4 kernel is used only when H*W is a
// multiple of 4 AND every data pointer is 16-byte aligned (a view with a
// storage offset may not be); otherwise the scalar kernel runs.
torch::Tensor forward(torch::Tensor in1, torch::Tensor in2,
                      torch::Tensor in3, torch::Tensor in4) {
  TORCH_CHECK(in1.dim() == 4 && in2.dim() == 4 &&
              in3.dim() == 4 && in4.dim() == 4, "All inputs must be 4D");
  // Validate device placement before constructing the CUDA guard, which
  // requires a CUDA device.
  TORCH_CHECK(in1.is_cuda() && in2.is_cuda() && in3.is_cuda() && in4.is_cuda(),
              "All inputs must be CUDA tensors");
  TORCH_CHECK(in2.device() == in1.device() && in3.device() == in1.device() &&
              in4.device() == in1.device(),
              "All inputs must be on the same CUDA device");
  TORCH_CHECK(in1.scalar_type() == torch::kFloat && in2.scalar_type() == torch::kFloat &&
              in3.scalar_type() == torch::kFloat && in4.scalar_type() == torch::kFloat,
              "All inputs must be float32");
  c10::cuda::CUDAGuard device_guard(in1.device());

  const int64_t N = in1.size(0);
  const int64_t H = in1.size(2);
  const int64_t W = in1.size(3);
  TORCH_CHECK(in2.size(0) == N && in3.size(0) == N && in4.size(0) == N,
              "Batch sizes must match");
  TORCH_CHECK(in2.size(2) == H && in2.size(3) == W &&
              in3.size(2) == H && in3.size(3) == W &&
              in4.size(2) == H && in4.size(3) == W,
              "Spatial dimensions must match");
  const int64_t C1 = in1.size(1);
  const int64_t C2 = in2.size(1);
  const int64_t C3 = in3.size(1);
  const int64_t C4 = in4.size(1);
  const int64_t C_total = C1 + C2 + C3 + C4;

  // The kernels receive their sizes as 32-bit ints.
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(N <= kIntMax && H <= kIntMax && W <= kIntMax && C1 <= kIntMax &&
              C2 <= kIntMax && C3 <= kIntMax && C4 <= kIntMax,
              "Tensor dimensions must fit in a 32-bit int");

  auto output = torch::empty({N, C_total, H, W}, in1.options());
  if (output.numel() == 0) {
    return output;  // Nothing to copy; a 0-block launch would be an error.
  }

  in1 = in1.contiguous();
  in2 = in2.contiguous();
  in3 = in3.contiguous();
  in4 = in4.contiguous();

  const float* p1 = in1.data_ptr<float>();
  const float* p2 = in2.data_ptr<float>();
  const float* p3 = in3.data_ptr<float>();
  const float* p4 = in4.data_ptr<float>();
  float* pout = output.data_ptr<float>();

  constexpr int threads = 256;
  // Blocks needed to give every work item a thread, clamped to the grid limit.
  auto num_blocks = [&](int64_t work) {
    return static_cast<int>(std::min<int64_t>((work + threads - 1) / threads, kIntMax));
  };
  auto aligned16 = [](const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
  };

  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  const int64_t HW = H * W;
  const int64_t total = N * C_total * HW;
  const bool use_vec = (HW % 4 == 0) && aligned16(p1) && aligned16(p2) &&
                       aligned16(p3) && aligned16(p4) && aligned16(pout);

  if (use_vec) {
    fused_cat_vec_kernel<<<num_blocks(total / 4), threads, 0, stream>>>(
        p1, static_cast<int>(C1),
        p2, static_cast<int>(C2),
        p3, static_cast<int>(C3),
        p4, static_cast<int>(C4),
        pout,
        static_cast<int>(N), static_cast<int>(H), static_cast<int>(W));
  } else {
    fused_cat_kernel<<<num_blocks(total), threads, 0, stream>>>(
        p1, static_cast<int>(C1),
        p2, static_cast<int>(C2),
        p3, static_cast<int>(C3),
        p4, static_cast<int>(C4),
        pout,
        static_cast<int>(N), static_cast<int>(H), static_cast<int>(W));
  }
  // Kernel launches do not return errors; surface bad configurations here.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "fused_cat kernel launch failed: ",
              cudaGetErrorString(err));
  return output;
}
// Python binding: exposes the C++ `forward` wrapper as `<extension>.forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("forward",
          &forward,
          "Fused concatenation of 4 tensors along channel dimension (with vectorized support)");
}