3.21. Mobile Inverted Bottleneck Convolution Block
Calculates an MBConv operation by conditionally expanding the input channels using a point-wise convolution, applying a depthwise convolution with configurable kernel size and stride, and then projecting the result to the target channel dimension. It includes a residual connection when the input and output dimensions match.
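For reference, the block described above corresponds roughly to the following PyTorch sketch. This is a minimal illustration only: the class name, the default expand_ratio, and the omission of squeeze-and-excitation are assumptions rather than the benchmark's exact definition.

import torch
import torch.nn as nn

class MBConv(nn.Module):
    """Sketch of an MBConv block: expand (1x1) -> depthwise (kxk) -> project (1x1)."""
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, expand_ratio=6):
        super().__init__()
        hidden = in_ch * expand_ratio
        self.use_residual = (stride == 1 and in_ch == out_ch)
        layers = []
        if expand_ratio != 1:
            # Conditional point-wise expansion of the input channels.
            layers += [nn.Conv2d(in_ch, hidden, 1, bias=False),
                       nn.BatchNorm2d(hidden), nn.ReLU6(inplace=True)]
        layers += [
            # Depthwise convolution with configurable kernel size and stride.
            nn.Conv2d(hidden, hidden, kernel_size, stride, kernel_size // 2,
                      groups=hidden, bias=False),
            nn.BatchNorm2d(hidden), nn.ReLU6(inplace=True),
            # Linear projection to the target channel dimension.
            nn.Conv2d(hidden, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return x + out if self.use_residual else out

The fused kernel developed in this step targets the depthwise convolution + BatchNorm + ReLU6 portion of this block.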
Step 2 Details
Reasoning
We can further optimize our fused depthwise convolution kernel by reducing redundant computation and by using the __ldg intrinsic for read-only loads (e.g. the BN parameters, the weights, and the input values). Each thread already computes the convolution sum over its receptive field; what we can avoid is evaluating the full BatchNorm expression per element. Instead of computing

bn_out = bn_weight[c] * ((sum - bn_mean[c]) / sqrt(bn_var[c] + eps)) + bn_bias[c]

for every output element, each thread precomputes the per-channel factors

bn_scale = bn_weight[c] / sqrt(bn_var[c] + eps)
bn_shift = bn_bias[c] - bn_mean[c] * bn_scale

and then obtains result = sum * bn_scale + bn_shift. Using __ldg for accesses to the bn_* arrays and the weights (and even the input) helps ensure these values are fetched through the read-only cache, reducing memory latency. Adding a "#pragma unroll" directive on the small kernel loops may also help when the kernel size is small, and would help most if it were known at compile time. These adjustments lower the per-thread cost and should produce an overall speedup.

Before finalizing, I check that these improvements are genuine and not merely superficial wrappers. The changes reduce arithmetic operations and leverage the read-only cache, both of which are sound optimizations.
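The algebra behind this folding is easy to verify numerically. The following standalone Python check (variable names are illustrative, not taken from the kernel) confirms that sum * bn_scale + bn_shift matches the textbook BatchNorm expression:

import torch

eps = 1e-5
C = 8
conv_sum = torch.randn(1024, C)                  # per-thread convolution sums, one column per channel
w, b = torch.randn(C), torch.randn(C)            # bn_weight, bn_bias
mean, var = torch.randn(C), torch.rand(C) + 0.1  # bn_mean, bn_var (variance kept positive)

# Textbook BatchNorm applied to the convolution output.
ref = w * (conv_sum - mean) / torch.sqrt(var + eps) + b

# Folded form used in the kernel: one multiply and one add per element.
scale = w / torch.sqrt(var + eps)
shift = b - mean * scale
folded = conv_sum * scale + shift

assert torch.allclose(ref, folded, atol=1e-5)

With the equivalence established, the updated kernel source follows.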
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
// Fused depthwise convolution kernel with BatchNorm and ReLU6 activation,
// optimized with pre-computation of BN factors and __ldg for read-only loads.
// Input: tensor [N, C, H_in, W_in]
// Weight: depthwise conv weights of shape [C, k, k] (flattened as [C, k*k])
// BN parameters: bn_weight, bn_bias, bn_mean, bn_var (each of shape [C])
// eps: epsilon for BN, stride, pad, and kernel size k.
// Output: tensor [N, C, H_out, W_out]
__global__ void fused_dwconv_kernel(const float* __restrict__ input,
                                    const float* __restrict__ weight,
                                    const float* __restrict__ bn_weight,
                                    const float* __restrict__ bn_bias,
                                    const float* __restrict__ bn_mean,
                                    const float* __restrict__ bn_var,
                                    float eps,
                                    int stride,
                                    int pad,
                                    int N, int C, int H_in, int W_in,
                                    int k,
                                    int H_out, int W_out,
                                    float* __restrict__ output) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * C * H_out * W_out;
    if (index < total) {
        // Determine indices for n, c, h_out, w_out.
        int n = index / (C * H_out * W_out);
        int rem = index % (C * H_out * W_out);
        int c = rem / (H_out * W_out);
        int pos = rem % (H_out * W_out);
        int h_out = pos / W_out;
        int w_out = pos % W_out;

        // Precompute BN normalization factors for channel c.
        float bn_var_val = __ldg(&bn_var[c]);
        float inv_std = 1.0f / sqrtf(bn_var_val + eps);
        float w_bn = __ldg(&bn_weight[c]);   // load bn_weight[c]
        float bn_scale = w_bn * inv_std;     // scale factor
        float bn_shift = __ldg(&bn_bias[c]) - __ldg(&bn_mean[c]) * bn_scale;

        float sum = 0.0f;
        // Compute the convolution sum over the k x k kernel window,
        // hinting the compiler to unroll the small loops.
        #pragma unroll
        for (int i = 0; i < k; i++) {
            #pragma unroll
            for (int j = 0; j < k; j++) {
                int h_in = h_out * stride - pad + i;
                int w_in = w_out * stride - pad + j;
                if (h_in >= 0 && h_in < H_in && w_in >= 0 && w_in < W_in) {
                    int input_index = n * (C * H_in * W_in) + c * (H_in * W_in) + h_in * W_in + w_in;
                    int weight_index = c * (k * k) + i * k + j;
                    float in_val = __ldg(&input[input_index]);
                    float w_val = __ldg(&weight[weight_index]);
                    sum += in_val * w_val;
                }
            }
        }

        // Apply the folded BatchNorm transform and ReLU6 activation.
        float bn_out = sum * bn_scale + bn_shift;
        float activated = fminf(fmaxf(bn_out, 0.0f), 6.0f);
        output[index] = activated;
    }
}
torch::Tensor fused_dwconv(torch::Tensor input,
                           torch::Tensor weight,
                           torch::Tensor bn_weight,
                           torch::Tensor bn_bias,
                           torch::Tensor bn_mean,
                           torch::Tensor bn_var,
                           float eps,
                           int stride,
                           int pad,
                           int k) {
    c10::cuda::CUDAGuard device_guard(input.device());
    TORCH_CHECK(input.dim() == 4, "input must be a 4D tensor");
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");

    int N = input.size(0);
    int C = input.size(1);
    int H_in = input.size(2);
    int W_in = input.size(3);
    int H_out = (H_in + 2 * pad - k) / stride + 1;
    int W_out = (W_in + 2 * pad - k) / stride + 1;

    auto output = torch::empty({N, C, H_out, W_out}, input.options());

    int total = N * C * H_out * W_out;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    fused_dwconv_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        bn_weight.data_ptr<float>(),
        bn_bias.data_ptr<float>(),
        bn_mean.data_ptr<float>(),
        bn_var.data_ptr<float>(),
        eps,
        stride,
        pad,
        N, C, H_in, W_in,
        k,
        H_out, W_out,
        output.data_ptr<float>()
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_dwconv", &fused_dwconv, "Fused depthwise conv + BatchNorm + ReLU6 (CUDA)");
}
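One way to build and exercise the extension from Python is sketched below. The file name fused_dwconv.cu and the extension name are illustrative, and the reference path is assembled from stock PyTorch ops; the tolerance is loose because the two paths accumulate floating-point error differently.

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Compile the CUDA source above into a loadable extension (file name is an assumption).
ext = load(name="fused_dwconv_ext", sources=["fused_dwconv.cu"], verbose=True)

N, C, H, W = 2, 32, 56, 56
k, stride, pad, eps = 3, 1, 1, 1e-5
x = torch.randn(N, C, H, W, device="cuda")
w = torch.randn(C, 1, k, k, device="cuda")  # depthwise weights, contiguous [C, 1, k, k]
gamma, beta = torch.randn(C, device="cuda"), torch.randn(C, device="cuda")
mean, var = torch.randn(C, device="cuda"), torch.rand(C, device="cuda") + 0.1

out = ext.fused_dwconv(x, w, gamma, beta, mean, var, eps, stride, pad, k)

# Reference path: depthwise conv -> BatchNorm (inference statistics) -> ReLU6.
ref = F.conv2d(x, w, stride=stride, padding=pad, groups=C)
ref = F.batch_norm(ref, mean, var, gamma, beta, training=False, eps=eps)
ref = F.relu6(ref)
print(torch.allclose(out, ref, atol=1e-4))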