1.82. Depthwise 2D Convolution (Square Input/Kernel)

Performs a depthwise convolution on a square input tensor using a square kernel. It calculates channel-wise outputs by convolving each input channel independently, with parameters for adjustable stride, padding, and an optional bias.

Progress

Step 7 of 7
Reward
1.72x

Optimization Timeline

Step 7 Details

Reasoning

We considered exploring shared memory tiling and warp-level primitives for further improvements, but the increased complexity and potential for register pressure did not promise a clear win over our current highly optimized kernel. Our current implementation already provides a fast-path for the common case (kernel_size == 3, stride == 1, padding == 0) using complete loop unrolling, register caching for weight values, and coalesced global reads with __ldg in the generic path. These enhancements have proven to yield a significant speedup (~1.72×) compared to the native PyTorch implementation. The potential additional optimizations would require nontrivial code changes and may not be beneficial for our typical use-cases. This final version balances complexity and performance while remaining mathematically equivalent and efficient.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

// Depthwise 2D convolution kernel with loop unrolling, register caching, and optional unrolling directives.
// Input: [batch, channels, H, W]
// Weight: [channels, 1, K, K] where K = kernel_size.
// Output: [batch, channels, H_out, W_out]
//      H_out = (H + 2*padding - K) / stride + 1, similarly for W_out.
__global__ void depthwise_conv2d_kernel(const float* __restrict__ input,
                                        const float* __restrict__ weight,
                                        float* __restrict__ output,
                                        int batch,
                                        int channels,
                                        int H,
                                        int W,
                                        int H_out,
                                        int W_out,
                                        int K,
                                        int stride,
                                        int padding) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch * channels * H_out * W_out;
    if (index >= total)
        return;
    
    // Map linear index to (b, c, h_out, w_out)
    int w_out = index % W_out;
    int tmp = index / W_out;
    int h_out = tmp % H_out;
    tmp /= H_out;
    int c = tmp % channels;
    int b = tmp / channels;
    
    float sum = 0.0f;
    int h_in_start = h_out * stride - padding;
    int w_in_start = w_out * stride - padding;
    
    // Optimized branch for kernel_size == 3, stride == 1, and padding == 0.
    if (K == 3 && stride == 1 && padding == 0) {
        // Load weight elements for channel c into registers.
        float w00 = weight[((c * 3) + 0) * 3 + 0],
              w01 = weight[((c * 3) + 0) * 3 + 1],
              w02 = weight[((c * 3) + 0) * 3 + 2],
              w10 = weight[((c * 3) + 1) * 3 + 0],
              w11 = weight[((c * 3) + 1) * 3 + 1],
              w12 = weight[((c * 3) + 1) * 3 + 2],
              w20 = weight[((c * 3) + 2) * 3 + 0],
              w21 = weight[((c * 3) + 2) * 3 + 1],
              w22 = weight[((c * 3) + 2) * 3 + 2];
        
        // In this ideal case, the input indices are valid without bounds checking.
        int h0 = h_out, h1 = h_out + 1, h2 = h_out + 2;
        int w0 = w_out, w1 = w_out + 1, w2 = w_out + 1 + 0; // w2 = w_out + 2
        w2 = w_out + 2;
        int base = ((b * channels + c) * H);
        int offset0 = (base + h0) * W,
            offset1 = (base + h1) * W,
            offset2 = (base + h2) * W;
        
        sum = input[offset0 + w0] * w00 +
              input[offset0 + w1] * w01 +
              input[offset0 + w2] * w02 +
              input[offset1 + w0] * w10 +
              input[offset1 + w1] * w11 +
              input[offset1 + w2] * w12 +
              input[offset2 + w0] * w20 +
              input[offset2 + w1] * w21 +
              input[offset2 + w2] * w22;
    }
    // Optimized branch for K == 3 with non-zero padding or stride != 1.
    else if (K == 3) {
        float w00 = weight[((c * 3) + 0) * 3 + 0],
              w01 = weight[((c * 3) + 0) * 3 + 1],
              w02 = weight[((c * 3) + 0) * 3 + 2],
              w10 = weight[((c * 3) + 1) * 3 + 0],
              w11 = weight[((c * 3) + 1) * 3 + 1],
              w12 = weight[((c * 3) + 1) * 3 + 2],
              w20 = weight[((c * 3) + 2) * 3 + 0],
              w21 = weight[((c * 3) + 2) * 3 + 1],
              w22 = weight[((c * 3) + 2) * 3 + 2];
        
        int h0 = h_in_start + 0, h1 = h_in_start + 1, h2 = h_in_start + 2;
        int w0 = w_in_start + 0, w1 = w_in_start + 1, w2 = w_in_start + 2;
        int base = ((b * channels + c) * H);
        if (h0 >= 0 && h0 < H) {
            int offset = (base + h0) * W;
            if (w0 >= 0 && w0 < W)
                sum += input[offset + w0] * w00;
            if (w1 >= 0 && w1 < W)
                sum += input[offset + w1] * w01;
            if (w2 >= 0 && w2 < W)
                sum += input[offset + w2] * w02;
        }
        if (h1 >= 0 && h1 < H) {
            int offset = (base + h1) * W;
            if (w0 >= 0 && w0 < W)
                sum += input[offset + w0] * w10;
            if (w1 >= 0 && w1 < W)
                sum += input[offset + w1] * w11;
            if (w2 >= 0 && w2 < W)
                sum += input[offset + w2] * w12;
        }
        if (h2 >= 0 && h2 < H) {
            int offset = (base + h2) * W;
            if (w0 >= 0 && w0 < W)
                sum += input[offset + w0] * w20;
            if (w1 >= 0 && w1 < W)
                sum += input[offset + w1] * w21;
            if (w2 >= 0 && w2 < W)
                sum += input[offset + w2] * w22;
        }
    }
    // Generic implementation for arbitrary kernel sizes.
    else {
        #pragma unroll
        for (int i = 0; i < K; i++) {
            int h_in = h_in_start + i;
            if (h_in < 0 || h_in >= H) continue;
            #pragma unroll
            for (int j = 0; j < K; j++) {
                int w_in = w_in_start + j;
                if (w_in < 0 || w_in >= W) continue;
                int input_idx = ((b * channels + c) * H + h_in) * W + w_in;
                int weight_idx = ((c * K) + i) * K + j;
                sum += __ldg(&input[input_idx]) * __ldg(&weight[weight_idx]);
            }
        }
    }
    output[index] = sum;
}

torch::Tensor depthwise_conv2d(torch::Tensor input,
                               torch::Tensor weight,
                               int stride,
                               int padding) {
    c10::cuda::CUDAGuard device_guard(input.device());
    int batch = input.size(0);
    int channels = input.size(1);
    int H = input.size(2);
    int W = input.size(3);
    int K = weight.size(2);
    
    int H_out = (H + 2 * padding - K) / stride + 1;
    int W_out = (W + 2 * padding - K) / stride + 1;
    
    auto output = torch::empty({batch, channels, H_out, W_out}, input.options());
    int total_output = batch * channels * H_out * W_out;
    int threads = 256;
    int blocks = (total_output + threads - 1) / threads;
    
    depthwise_conv2d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        output.data_ptr<float>(),
        batch,
        channels,
        H,
        W,
        H_out,
        W_out,
        K,
        stride,
        padding
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("depthwise_conv2d", &depthwise_conv2d, "Depthwise 2D Convolution (CUDA) with loop unrolling and register caching");
}