3.21. Mobile Inverted Bottleneck Convolution Block

Calculates an MBConv operation by conditionally expanding the input channels using a point-wise convolution, applying a depthwise convolution with configurable kernel size and stride, and then projecting the result to the target channel dimension. It includes a residual connection when the input and output dimensions match.

Progress

Step 2 of 2
Reward
1.01x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize our fused depthwise convolution kernel in two ways. First, we reduce redundant computation. Second, we use the __ldg intrinsic for read-only loads (e.g. for the BN parameters, the weights, and the input values). Each thread already computes the convolution sum over its receptive field. However, we can precompute the BN normalization factors. Instead of evaluating, in every thread,
$$\text{bn\_out} = \text{bn\_weight}[c]\cdot\frac{\text{sum} - \text{bn\_mean}[c]}{\sqrt{\text{bn\_var}[c] + \epsilon}} + \text{bn\_bias}[c],$$
each thread computes once
$$\text{bn\_scale} = \frac{\text{bn\_weight}[c]}{\sqrt{\text{bn\_var}[c] + \epsilon}}, \qquad \text{bn\_shift} = \text{bn\_bias}[c] - \text{bn\_mean}[c]\cdot\text{bn\_scale},$$
and then obtains
$$\text{result} = \text{sum}\cdot\text{bn\_scale} + \text{bn\_shift}.$$
Using __ldg for the bn_* arrays and the weights (and even the input) ensures these values are fetched through the read-only data cache, which can reduce memory latency. Adding a "#pragma unroll" directive on the small kernel loops helps only when the kernel size is known at compile time. When k is a runtime argument, nvcc does not unroll a loop marked with a bare "#pragma unroll". These adjustments can lower the per-thread cost and produce an overall speedup. Before finalizing, I check that these improvements are genuine and not merely superficial wrappers. The changes reduce arithmetic operations and make better use of caching, both of which are sound optimizations.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>

#include <cstdint>
#include <limits>

// Fused depthwise convolution + BatchNorm (using the supplied mean/var) + ReLU6.
//
// For every output element:
//   sum      = sum_{i,j} input[n][c][h_out*stride-pad+i][w_out*stride-pad+j] * weight[c][i][j]
//   bn_scale = bn_weight[c] / sqrt(bn_var[c] + eps)
//   bn_shift = bn_bias[c] - bn_mean[c] * bn_scale
//   output   = min(max(sum * bn_scale + bn_shift, 0), 6)
// Taps that fall outside [0, H_in) x [0, W_in) are skipped, which is
// equivalent to zero padding.
//
// Layout preconditions (not checked here; the host launcher enforces them):
//   input  : densely packed [N, C, H_in, W_in] float32
//   weight : densely packed C*k*k float32 taps, read as [C, k, k]
//   bn_*   : C float32 values each
//   output : densely packed [N, C, H_out, W_out] float32
// Launch: 1-D grid of 1-D blocks of any size. The grid-stride loop visits all
// N*C*H_out*W_out outputs regardless of grid size. Flat indices are 64-bit,
// so tensors with more than 2^31 elements are indexed correctly.
// __ldg (read-only data cache load) requires compute capability 3.5+.
__global__ void fused_dwconv_kernel(const float* __restrict__ input,
                                    const float* __restrict__ weight,
                                    const float* __restrict__ bn_weight,
                                    const float* __restrict__ bn_bias,
                                    const float* __restrict__ bn_mean,
                                    const float* __restrict__ bn_var,
                                    float eps,
                                    int stride,
                                    int pad,
                                    int N, int C, int H_in, int W_in,
                                    int k,
                                    int H_out, int W_out,
                                    float* __restrict__ output) {
    const int64_t out_plane = static_cast<int64_t>(H_out) * W_out;
    const int64_t in_plane = static_cast<int64_t>(H_in) * W_in;
    const int64_t total = static_cast<int64_t>(N) * C * out_plane;
    const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < total; index += grid_stride) {
        // index == ((n * C + c) * H_out + h_out) * W_out + w_out
        const int64_t nc = index / out_plane;  // combined (n, c) plane index
        const int c = static_cast<int>(nc % C);
        const int pos = static_cast<int>(index - nc * out_plane);
        const int h_out = pos / W_out;
        const int w_out = pos % W_out;

        // Fold BatchNorm into one scale and one shift for channel c.
        const float inv_std = 1.0f / sqrtf(__ldg(&bn_var[c]) + eps);
        const float bn_scale = __ldg(&bn_weight[c]) * inv_std;
        const float bn_shift = __ldg(&bn_bias[c]) - __ldg(&bn_mean[c]) * bn_scale;

        // Base pointers are invariant across the k x k window; hoist them.
        // (n*C + c) * H_in * W_in is the start of this (n, c) input plane.
        const float* in_plane_ptr = input + nc * in_plane;
        const float* w_plane = weight + static_cast<int64_t>(c) * k * k;
        const int h_base = h_out * stride - pad;
        const int w_base = w_out * stride - pad;

        // k is a runtime value, so '#pragma unroll' would not unroll these
        // loops; they are left as plain loops. Taps are accumulated in the
        // same (i, j) order as a row-major window scan.
        float sum = 0.0f;
        for (int i = 0; i < k; ++i) {
            const int h_in = h_base + i;
            if (h_in < 0 || h_in >= H_in) continue;  // whole row is padding
            const float* in_row = in_plane_ptr + static_cast<int64_t>(h_in) * W_in;
            const float* w_row = w_plane + i * k;
            for (int j = 0; j < k; ++j) {
                const int w_in = w_base + j;
                if (w_in < 0 || w_in >= W_in) continue;  // padding column
                sum += __ldg(&in_row[w_in]) * __ldg(&w_row[j]);
            }
        }

        // BatchNorm affine step followed by ReLU6 clamp to [0, 6].
        const float bn_out = sum * bn_scale + bn_shift;
        output[index] = fminf(fmaxf(bn_out, 0.0f), 6.0f);
    }
}

// Host launcher for fused_dwconv_kernel (depthwise conv + BatchNorm + ReLU6).
//
// Args:
//   input   - float32 CUDA tensor of shape [N, C, H_in, W_in].
//   weight  - float32 tensor on input's device with exactly C*k*k elements,
//             read as [C, k, k]. A depthwise Conv2d weight [C, 1, k, k] qualifies.
//   bn_weight, bn_bias, bn_mean, bn_var
//           - float32 tensors on input's device with C elements each.
//   eps     - epsilon added to bn_var before the square root.
//   stride  - convolution stride, must be >= 1.
//   pad     - symmetric zero padding, must be >= 0.
//   k       - square kernel size, must be >= 1 and fit in the padded input.
// Returns:
//   float32 tensor [N, C, H_out, W_out] on input's device, where
//   H_out = (H_in + 2*pad - k) / stride + 1 and likewise for W_out.
// Errors:
//   Throws c10::Error (via TORCH_CHECK) on bad device, dtype, shape or
//   parameters, and when the kernel launch fails.
// Notes:
//   Non-contiguous arguments are copied to a contiguous layout first, because
//   the kernel indexes raw row-major memory.
//   The kernel is enqueued on the current CUDA stream without host synchronization.
torch::Tensor fused_dwconv(torch::Tensor input,
                           torch::Tensor weight,
                           torch::Tensor bn_weight,
                           torch::Tensor bn_bias,
                           torch::Tensor bn_mean,
                           torch::Tensor bn_var,
                           float eps,
                           int stride,
                           int pad,
                           int k) {
    TORCH_CHECK(input.dim() == 4, "input must be a 4D tensor");
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat, "input must be float32");
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel dereferences raw float pointers on input's device, so every
    // parameter tensor must be float32 and live on that same device.
    auto check_param = [&input](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.device() == input.device(),
                    name, " must be on the same device as input");
        TORCH_CHECK(t.scalar_type() == torch::kFloat, name, " must be float32");
    };
    check_param(weight, "weight");
    check_param(bn_weight, "bn_weight");
    check_param(bn_bias, "bn_bias");
    check_param(bn_mean, "bn_mean");
    check_param(bn_var, "bn_var");

    TORCH_CHECK(stride >= 1, "stride must be >= 1, got ", stride);
    TORCH_CHECK(pad >= 0, "pad must be >= 0, got ", pad);
    TORCH_CHECK(k >= 1, "kernel size k must be >= 1, got ", k);

    const int64_t N = input.size(0);
    const int64_t C = input.size(1);
    const int64_t H_in = input.size(2);
    const int64_t W_in = input.size(3);

    const int64_t H_padded = H_in + 2 * static_cast<int64_t>(pad);
    const int64_t W_padded = W_in + 2 * static_cast<int64_t>(pad);
    // Checked explicitly because C++ integer division truncates toward zero.
    // For example, (H_padded - k) == -1 with stride 2 would otherwise give
    // H_out == 1 instead of an error.
    TORCH_CHECK(H_padded >= k && W_padded >= k,
                "kernel size ", k, " is larger than the padded input (",
                H_padded, " x ", W_padded, ")");

    TORCH_CHECK(weight.numel() == C * k * k,
                "weight must have C*k*k = ", C * k * k,
                " elements, got ", weight.numel());
    TORCH_CHECK(bn_weight.numel() == C && bn_bias.numel() == C &&
                    bn_mean.numel() == C && bn_var.numel() == C,
                "BatchNorm parameters must each have C = ", C, " elements");

    const int64_t H_out = (H_padded - k) / stride + 1;
    const int64_t W_out = (W_padded - k) / stride + 1;

    // The kernel assumes densely packed row-major buffers.
    torch::Tensor x = input.contiguous();
    torch::Tensor w = weight.contiguous();
    torch::Tensor g = bn_weight.contiguous();
    torch::Tensor b = bn_bias.contiguous();
    torch::Tensor m = bn_mean.contiguous();
    torch::Tensor v = bn_var.contiguous();

    auto output = torch::empty({N, C, H_out, W_out}, x.options());
    const int64_t total = output.numel();
    if (total == 0) {
        // Nothing to compute. A 0-block launch would be an invalid configuration.
        return output;
    }

    constexpr int kThreads = 256;
    const int64_t blocks = (total + kThreads - 1) / kThreads;
    TORCH_CHECK(blocks <= std::numeric_limits<int>::max(),
                "output has too many elements for a 1-D launch: ", total);

    fused_dwconv_kernel<<<static_cast<unsigned int>(blocks), kThreads, 0,
                          c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        w.data_ptr<float>(),
        g.data_ptr<float>(),
        b.data_ptr<float>(),
        m.data_ptr<float>(),
        v.data_ptr<float>(),
        eps,
        stride,
        pad,
        static_cast<int>(N), static_cast<int>(C),
        static_cast<int>(H_in), static_cast<int>(W_in),
        k,
        static_cast<int>(H_out), static_cast<int>(W_out),
        output.data_ptr<float>());
    // Surface launch-configuration errors here instead of at an unrelated later call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: exposes the host launcher as `fused_dwconv` on the
// extension module built under TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module_) {
    module_.def(
        "fused_dwconv",
        &fused_dwconv,
        "Fused depthwise conv + BatchNorm + ReLU6 (CUDA)");
}