3.22. EfficientNetB0 Computation
Calculates image classification outputs by performing an initial strided convolution followed by a sequence of mobile inverted bottleneck (MBConv) blocks, which optionally expand channel dimensions, apply depthwise convolutions with ReLU6 activations and batch normalization, and incorporate residual connections when possible. The operation concludes with a final convolution, adaptive average pooling, and a fully connected layer to compute the output class scores.
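The sketch below shows the per-block MBConv decisions described above in plain C++. It assumes, as in common EfficientNetB0 implementations, that the 1x1 expansion convolution exists only when the expansion ratio is not 1. It also assumes that the residual connection is used only when the stride is 1 and the input and output channel counts are equal. The names MBConvConfig, relu6, hidden_channels, has_expansion, and uses_residual are hypothetical. They are not part of the kernel code in this section.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical summary of one MBConv block's hyperparameters (illustrative only).
struct MBConvConfig {
    int in_channels;
    int out_channels;
    int stride;
    int expand_ratio;
};

// ReLU6 clamps an activation to the range [0, 6].
inline float relu6(float x) {
    return std::min(std::max(x, 0.0f), 6.0f);
}

// Number of channels seen by the depthwise convolution after optional expansion.
inline int hidden_channels(const MBConvConfig& cfg) {
    assert(cfg.expand_ratio >= 1);
    return cfg.in_channels * cfg.expand_ratio;
}

// Assumption: a 1x1 expansion convolution is present only when expand_ratio != 1.
inline bool has_expansion(const MBConvConfig& cfg) {
    return cfg.expand_ratio != 1;
}

// Assumption: the identity shortcut is added only when the block keeps both the
// spatial size (stride 1) and the channel count, so input and output shapes match.
inline bool uses_residual(const MBConvConfig& cfg) {
    return cfg.stride == 1 && cfg.in_channels == cfg.out_channels;
}
```

For example, a block configured as {24, 24, 1, 6} expands to 144 channels before the depthwise convolution and keeps its shortcut. A stride-2 block drops the shortcut because its spatial size changes.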
Step 4 Details
Reasoning
The previous version already vectorized the dot product with float4 loads, but its inner loop still processed only one float4 element per iteration. The dimension C is often large and divisible by 4 (e.g. C = 1280 for the EfficientNetB0 classifier input), so manually unrolling the inner loop can speed it up. Unrolling reduces loop overhead and gives the compiler more freedom to schedule instructions. In the new kernel, when (C & 3) == 0, we compute vecC = C / 4 and process the float4 elements in groups of four. The remaining vecC % 4 float4 elements are handled by a short non-unrolled float4 loop. A scalar loop is used only when C is not divisible by 4. We also add #pragma unroll directives as extra hints to the compiler. On the Python side, the module only needs to register the updated kernel. These changes should lower the cost of each output's dot product while still fusing average pooling and the FC computation into a single kernel.
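The control flow described above can be checked on the host with a small C++ mirror of the kernel's Phase 2. This is a sketch under stated assumptions. Float4 stands in for CUDA's built-in float4, and load4 replaces the kernel's reinterpret_cast and __ldg loads. Shared memory and the thread-strided loop over k are omitted. The names Float4, load4, dot4, dot_unrolled, and dot_scalar are hypothetical.

```cpp
#include <cassert>
#include <cstring>

// Host-side stand-in for CUDA's built-in float4 vector type (illustrative only).
struct Float4 {
    float x, y, z, w;
};
static_assert(sizeof(Float4) == 4 * sizeof(float), "Float4 must hold exactly four floats");

// Load the i-th group of four consecutive floats; memcpy avoids the
// strict-aliasing problems a reinterpret_cast would cause in host C++.
inline Float4 load4(const float* base, int i) {
    Float4 v;
    std::memcpy(&v, base + 4 * i, sizeof(Float4));
    return v;
}

inline float dot4(const Float4& a, const Float4& b) {
    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

// Mirrors the kernel: a float4 path unrolled by 4 when C % 4 == 0, a plain
// float4 loop for the leftover vectors, and a scalar loop otherwise.
float dot_unrolled(const float* w, const float* p, int C) {
    float dot = 0.0f;
    if ((C & 3) == 0) {
        int vecC = C >> 2;                  // number of float4 vectors
        int unroll_count = (vecC / 4) * 4;  // largest multiple of 4 <= vecC
        int i = 0;
        for (; i < unroll_count; i += 4) {
            dot += dot4(load4(w, i), load4(p, i));
            dot += dot4(load4(w, i + 1), load4(p, i + 1));
            dot += dot4(load4(w, i + 2), load4(p, i + 2));
            dot += dot4(load4(w, i + 3), load4(p, i + 3));
        }
        for (; i < vecC; ++i) {             // 0 to 3 leftover float4 vectors
            dot += dot4(load4(w, i), load4(p, i));
        }
    } else {
        for (int c = 0; c < C; ++c) {
            dot += w[c] * p[c];
        }
    }
    return dot;
}

// Plain scalar dot product used as the comparison baseline.
float dot_scalar(const float* w, const float* p, int C) {
    float dot = 0.0f;
    for (int c = 0; c < C; ++c) {
        dot += w[c] * p[c];
    }
    return dot;
}
```

The unrolled loop adds the products in a different order than a plain scalar loop. For arbitrary floating-point data, the two results can therefore differ in the last few bits. Comparing them on integer-valued inputs, where every partial sum is exact, isolates the indexing logic from rounding effects.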
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <math.h>
// This fused kernel performs adaptive average pooling (over H and W) and then computes the
// FC output in one operation. Each block processes one sample (n): it first computes the pooled
// feature vector for that sample into shared memory, then computes the dot product with each FC weight row.
// In the vectorized branch (when C is divisible by 4), the float4 loop is manually unrolled by 4,
// and any leftover float4 elements are handled by a short non-unrolled float4 loop.
extern "C" __global__
void fused_avgpool_fc_kernel(const float* __restrict__ input,
                             const float* __restrict__ fc_weight,
                             const float* __restrict__ fc_bias,
                             float* __restrict__ output,
                             int N, int C, int H, int W, int K) {
    // Each block is responsible for one sample.
    int n = blockIdx.x;
    int tid = threadIdx.x;
    int spatial = H * W;
    float inv_spatial = 1.0f / spatial;

    // Shared memory to store pooled results, size = C floats.
    extern __shared__ float sP[];

    // Phase 1: Compute the average-pooled vector for sample n.
    for (int c = tid; c < C; c += blockDim.x) {
        float sum = 0.0f;
        // 64-bit offset so large N*C*H*W products cannot overflow an int.
        size_t base = (static_cast<size_t>(n) * C + c) * spatial;
        #pragma unroll
        for (int i = 0; i < spatial; i++) {
            sum += input[base + i];
        }
        sP[c] = sum * inv_spatial;
    }
    // All pooled values must be in shared memory before any thread starts Phase 2.
    __syncthreads();

    // Phase 2: Compute FC dot products for sample n.
    for (int k = tid; k < K; k += blockDim.x) {
        float dot = 0.0f;
        const float* w_row = fc_weight + static_cast<size_t>(k) * C;
        // If C is divisible by 4, use a vectorized dot product with manual unrolling.
        // The float4 loads assume fc_weight is 16-byte aligned, which holds for a
        // freshly allocated contiguous tensor.
        if ((C & 3) == 0) {
            int vecC = C >> 2;  // C/4 float4 elements
            const float4* w_vec = reinterpret_cast<const float4*>(w_row);
            const float4* sP_vec = reinterpret_cast<const float4*>(sP);
            int i = 0;
            int unroll_count = (vecC / 4) * 4;  // largest multiple of 4 <= vecC
            #pragma unroll
            for (; i < unroll_count; i += 4) {
                float4 w0 = __ldg(&w_vec[i]);
                float4 w1 = __ldg(&w_vec[i + 1]);
                float4 w2 = __ldg(&w_vec[i + 2]);
                float4 w3 = __ldg(&w_vec[i + 3]);
                float4 p0 = sP_vec[i];
                float4 p1 = sP_vec[i + 1];
                float4 p2 = sP_vec[i + 2];
                float4 p3 = sP_vec[i + 3];
                dot += w0.x * p0.x + w0.y * p0.y + w0.z * p0.z + w0.w * p0.w;
                dot += w1.x * p1.x + w1.y * p1.y + w1.z * p1.z + w1.w * p1.w;
                dot += w2.x * p2.x + w2.y * p2.y + w2.z * p2.z + w2.w * p2.w;
                dot += w3.x * p3.x + w3.y * p3.y + w3.z * p3.z + w3.w * p3.w;
            }
            // Process the remaining 0 to 3 float4 elements.
            #pragma unroll
            for (; i < vecC; i++) {
                float4 w_val = __ldg(&w_vec[i]);
                float4 p_val = sP_vec[i];
                dot += w_val.x * p_val.x +
                       w_val.y * p_val.y +
                       w_val.z * p_val.z +
                       w_val.w * p_val.w;
            }
        } else {
            // Fall back to a scalar loop if C is not divisible by 4.
            for (int c = 0; c < C; c++) {
                dot += __ldg(&w_row[c]) * sP[c];
            }
        }
        output[static_cast<size_t>(n) * K + k] = dot + __ldg(&fc_bias[k]);
    }
}
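torch::Tensor fused_avgpool_fc(torch::Tensor input, torch::Tensor fc_weight, torch::Tensor fc_bias) {
    TORCH_CHECK(input.is_cuda() && fc_weight.is_cuda() && fc_bias.is_cuda(),
                "All tensors must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 &&
                fc_weight.scalar_type() == torch::kFloat32 &&
                fc_bias.scalar_type() == torch::kFloat32,
                "All tensors must be float32");
    TORCH_CHECK(input.dim() == 4, "Input tensor must be 4D");
    TORCH_CHECK(fc_weight.dim() == 2, "FC weight must be 2D");
    TORCH_CHECK(fc_bias.dim() == 1, "FC bias must be 1D");

    // Use CUDAGuard to ensure we operate on the correct device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes raw pointers, so it needs contiguous (NCHW / row-major) data.
    input = input.contiguous();
    fc_weight = fc_weight.contiguous();
    fc_bias = fc_bias.contiguous();

    int N = input.size(0);
    int C = input.size(1);
    int H = input.size(2);
    int W = input.size(3);
    int K = fc_weight.size(0);
    TORCH_CHECK(fc_weight.size(1) == C, "FC weight must have shape [K, C]");
    TORCH_CHECK(fc_bias.size(0) == K, "FC bias must have K elements");

    auto output = torch::empty({N, K}, input.options());
    // A grid with zero blocks is an invalid launch configuration.
    if (N == 0) {
        return output;
    }

    // Launch one block per sample.
    int blockSize = 256;
    dim3 blocks(N);
    size_t sharedMemSize = C * sizeof(float);
    TORCH_CHECK(sharedMemSize <= 48 * 1024,
                "C is too large for the default 48 KB dynamic shared memory limit");
    fused_avgpool_fc_kernel<<<blocks, blockSize, sharedMemSize, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        fc_weight.data_ptr<float>(),
        fc_bias.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, H, W, K
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}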
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_avgpool_fc", &fused_avgpool_fc,
          "Fused adaptive average pooling and FC for the EfficientNetB0 final layers (with unrolled vectorized dot product)");
}
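The extension's output can be validated against a straightforward CPU computation of the same math. For each sample n, it computes pooled[c] = (1 / (H*W)) * sum over (h, w) of x[n, c, h, w], then out[n, k] = b[k] + sum over c of W[k, c] * pooled[c]. The sketch below assumes a contiguous NCHW float32 input, a row-major [K, C] weight, and a length-K bias, which are the layouts the kernel indexes. avgpool_fc_reference is a hypothetical helper and is not part of the extension.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the fused kernel. Assumes contiguous NCHW input,
// row-major [K, C] FC weight, and a length-K bias; returns a row-major [N, K] output.
std::vector<float> avgpool_fc_reference(const std::vector<float>& input,
                                        const std::vector<float>& fc_weight,
                                        const std::vector<float>& fc_bias,
                                        int N, int C, int H, int W, int K) {
    const std::size_t spatial = static_cast<std::size_t>(H) * W;
    assert(input.size() == static_cast<std::size_t>(N) * C * spatial);
    assert(fc_weight.size() == static_cast<std::size_t>(K) * C);
    assert(fc_bias.size() == static_cast<std::size_t>(K));

    // Multiply by the reciprocal, as the kernel does, to keep rounding behavior close.
    const float inv_spatial = 1.0f / static_cast<float>(spatial);
    std::vector<float> output(static_cast<std::size_t>(N) * K);
    std::vector<float> pooled(static_cast<std::size_t>(C));
    for (int n = 0; n < N; ++n) {
        // Phase 1: global average over the H*W positions of each channel.
        for (int c = 0; c < C; ++c) {
            const std::size_t base = (static_cast<std::size_t>(n) * C + c) * spatial;
            float sum = 0.0f;
            for (std::size_t i = 0; i < spatial; ++i) {
                sum += input[base + i];
            }
            pooled[c] = sum * inv_spatial;
        }
        // Phase 2: dense layer, out[n, k] = bias[k] + sum_c W[k, c] * pooled[c].
        for (int k = 0; k < K; ++k) {
            float dot = 0.0f;
            for (int c = 0; c < C; ++c) {
                dot += fc_weight[static_cast<std::size_t>(k) * C + c] * pooled[c];
            }
            output[static_cast<std::size_t>(n) * K + k] = dot + fc_bias[k];
        }
    }
    return output;
}
```

The kernel sums in a different order than these loops, especially in the unrolled float4 path. Compare the two outputs with a small tolerance rather than exact equality.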