3.20. Mobile Inverted Residual Network

Implements an efficient, mobile-optimized neural network that applies a series of standard, depthwise, and pointwise convolutions combined with batch normalization and ReLU6 activations. It dynamically adjusts channel dimensions, leverages inverted residual blocks with skip connections, and aggregates spatial features through adaptive pooling before a final linear transformation for classification.
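
For reference, a minimal PyTorch sketch of the inverted residual block described above (expand, depthwise, project, with ReLU6); the layer parameters here are assumptions, and the benchmarked model may differ in detail:

import torch.nn as nn

class InvertedResidual(nn.Module):
    """Pointwise expansion -> depthwise 3x3 -> linear pointwise projection, with a skip when shapes match."""
    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        hidden = inp * expand_ratio
        self.use_res = (stride == 1 and inp == oup)
        layers = []
        if expand_ratio != 1:
            # Pointwise expansion.
            layers += [nn.Conv2d(inp, hidden, 1, bias=False), nn.BatchNorm2d(hidden), nn.ReLU6(inplace=True)]
        layers += [
            # Depthwise 3x3 convolution (groups == channels).
            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden), nn.ReLU6(inplace=True),
            # Linear pointwise projection (no activation).
            nn.Conv2d(hidden, oup, 1, bias=False), nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv(x) if self.use_res else self.conv(x)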

Progress: Step 3 of 3
Reward: 1.00x

Optimization Timeline

Step 3 Details

Reasoning

We observe that our fused BN+ReLU6 operator has been further optimized with vectorized loads and stores. However, the overall MobileNetV2 forward pass calls this operator in inference mode only for the first and last convolution blocks, and because those blocks are relatively small, our measured runtime is still nearly equal to that of the native implementation.

To push the speedup further, we can fold the BatchNorm parameters into the convolution weights in inference mode. This is a standard inference-time optimization (BN folding) that eliminates the BN layer and its extra memory loads and stores without compromising correctness. In eval mode we compute a “fused” convolution (with adjusted weights and bias) for the first and last conv layers, followed by the ReLU6 activation. This removes an entire GPU kernel launch (our fused BN+ReLU6 kernel) and replaces it with a single convolution call, which cuDNN typically handles very efficiently, followed by ReLU6.

In the new Python module we add a helper method fuse_conv_bn() that computes the new weights and bias; a sketch of it appears below. In forward(), when self.training is False, we use the fused convolution for conv1 and conv_last instead of calling our fused BN+ReLU6 custom operator. The rest of the network and the CUDA extension remain unchanged, so the improvement genuinely helps by eliminating extra memory operations at inference time.
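
A minimal sketch of the BN-folding helper described above, assuming standard Conv2d/BatchNorm2d modules; the name fuse_conv_bn() is taken from the description, and the exact signature is an assumption. Per output channel, the folding computes W' = W * gamma / sqrt(var + eps) and b' = beta + (b - mean) * gamma / sqrt(var + eps):

import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold BatchNorm running statistics into the preceding convolution (inference only)."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      kernel_size=conv.kernel_size, stride=conv.stride,
                      padding=conv.padding, dilation=conv.dilation,
                      groups=conv.groups, bias=True)
    inv_std = torch.rsqrt(bn.running_var + bn.eps)
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels, device=conv.weight.device)
    # Scale each output-channel filter by gamma / sqrt(var + eps).
    fused.weight.copy_(conv.weight * (bn.weight * inv_std).reshape(-1, 1, 1, 1))
    # New bias: beta + (old_bias - running_mean) * gamma / sqrt(var + eps).
    fused.bias.copy_(bn.bias + (b - bn.running_mean) * bn.weight * inv_std)
    return fused

In eval mode, forward() would then compute torch.nn.functional.relu6(fused_conv(x)) for conv1 and conv_last instead of invoking the custom fused BN+ReLU6 kernel.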

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream()
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>

// Scalar kernel: processes one element per thread.
__global__ void fused_bn_relu6_kernel(const float* __restrict__ input,
                                        float* __restrict__ output,
                                        const float* __restrict__ gamma,
                                        const float* __restrict__ beta,
                                        const float* __restrict__ running_mean,
                                        const float* __restrict__ running_var,
                                        const int N, const int C, const int H, const int W,
                                        const float eps) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * C * H * W;
    if (idx < total) {
        // Determine channel index based on NCHW layout.
        int c = (idx / (H * W)) % C;
        float x_val = input[idx];
        float norm = (x_val - running_mean[c]) / sqrtf(running_var[c] + eps);
        float y = gamma[c] * norm + beta[c];
        // Apply ReLU6.
        y = fminf(fmaxf(y, 0.0f), 6.0f);
        output[idx] = y;
    }
}

// Vectorized kernel: processes 4 elements at a time using float4.
__global__ void fused_bn_relu6_vec_kernel(const float4* __restrict__ input,
                                            float4* __restrict__ output,
                                            const float* __restrict__ gamma,
                                            const float* __restrict__ beta,
                                            const float* __restrict__ running_mean,
                                            const float* __restrict__ running_var,
                                            const int N, const int C, const int H, const int W,
                                            const float eps) {
    // Each thread processes a float4 vector from a channel block.
    int inner = H * W;
    int inner_vec = inner / 4; // Number of float4's per (N,C) block.
    int total_vec = N * C * inner_vec;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total_vec) {
        // Recover the channel index from the flat vector index (NCHW layout); the batch
        // and spatial positions are not needed for the per-channel affine transform.
        int c = (idx / inner_vec) % C;
        // Load the vectorized data.
        float4 val = input[idx];
        float mean = running_mean[c];
        float var = running_var[c];
        float gamma_c = gamma[c];
        float beta_c = beta[c];
        float inv_std = rsqrtf(var + eps);
        // Process each of the 4 elements in the float4.
        float x0 = (val.x - mean) * inv_std;
        float x1 = (val.y - mean) * inv_std;
        float x2 = (val.z - mean) * inv_std;
        float x3 = (val.w - mean) * inv_std;
        float y0 = gamma_c * x0 + beta_c;
        float y1 = gamma_c * x1 + beta_c;
        float y2 = gamma_c * x2 + beta_c;
        float y3 = gamma_c * x3 + beta_c;
        y0 = fminf(fmaxf(y0, 0.0f), 6.0f);
        y1 = fminf(fmaxf(y1, 0.0f), 6.0f);
        y2 = fminf(fmaxf(y2, 0.0f), 6.0f);
        y3 = fminf(fmaxf(y3, 0.0f), 6.0f);
        float4 res = { y0, y1, y2, y3 };
        output[idx] = res;
    }
}

torch::Tensor fused_bn_relu6(torch::Tensor input,
                             torch::Tensor gamma,
                             torch::Tensor beta,
                             torch::Tensor running_mean,
                             torch::Tensor running_var,
                             float eps) {
    // The kernels assume a contiguous float32 NCHW CUDA tensor.
    TORCH_CHECK(input.is_cuda(), "fused_bn_relu6: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "fused_bn_relu6: input must be float32");
    TORCH_CHECK(input.dim() == 4, "fused_bn_relu6: input must be a 4D NCHW tensor");
    auto input_contig = input.contiguous();
    const auto N = input_contig.size(0);
    const auto C = input_contig.size(1);
    const auto H = input_contig.size(2);
    const auto W = input_contig.size(3);
    auto output = torch::empty_like(input_contig);
    int total = input_contig.numel();
    int threads = 256;
    c10::cuda::CUDAGuard device_guard(input_contig.device());
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // Use the vectorized kernel when H*W is a multiple of 4 and both buffers are 16-byte aligned.
    bool vec_ok = ((H * W) % 4 == 0) &&
                  (reinterpret_cast<std::uintptr_t>(input_contig.data_ptr<float>()) % 16 == 0) &&
                  (reinterpret_cast<std::uintptr_t>(output.data_ptr<float>()) % 16 == 0);
    if (vec_ok) {
        int inner_vec = (H * W) / 4;
        int total_vec = N * C * inner_vec;
        int blocks = (total_vec + threads - 1) / threads;
        const float4* input_vec = reinterpret_cast<const float4*>(input_contig.data_ptr<float>());
        float4* output_vec = reinterpret_cast<float4*>(output.data_ptr<float>());
        fused_bn_relu6_vec_kernel<<<blocks, threads, 0, stream>>>(
            input_vec, output_vec,
            gamma.data_ptr<float>(),
            beta.data_ptr<float>(),
            running_mean.data_ptr<float>(),
            running_var.data_ptr<float>(),
            N, C, H, W, eps);
    } else {
        int blocks = (total + threads - 1) / threads;
        fused_bn_relu6_kernel<<<blocks, threads, 0, stream>>>(
            input_contig.data_ptr<float>(),
            output.data_ptr<float>(),
            gamma.data_ptr<float>(),
            beta.data_ptr<float>(),
            running_mean.data_ptr<float>(),
            running_var.data_ptr<float>(),
            N, C, H, W, eps);
    }
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bn_relu6", &fused_bn_relu6, "Fused BatchNorm and ReLU6 Kernel (with vectorized support for inference)");
}
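
A hypothetical usage sketch for the extension above, built on the fly with torch.utils.cpp_extension.load (the source file name fused_bn_relu6.cu is an assumption), checking the custom kernel against native BatchNorm + ReLU6 in eval mode:

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Compile and load the CUDA extension (source file name is an assumption).
fused_ops = load(name="fused_bn_relu6_ext", sources=["fused_bn_relu6.cu"], verbose=True)

x = torch.randn(8, 32, 112, 112, device="cuda")
bn = torch.nn.BatchNorm2d(32).cuda().eval()

with torch.no_grad():
    y = fused_ops.fused_bn_relu6(x, bn.weight, bn.bias,
                                 bn.running_mean, bn.running_var, bn.eps)
    ref = F.relu6(bn(x))  # native BatchNorm (eval) followed by ReLU6
print(torch.allclose(y, ref, atol=1e-5))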