1.84. Depthwise 2D Convolution with Square Kernel

Computes a depthwise convolution on a 2D input whose spatial dimensions may be asymmetric (height != width), using a square kernel. Each channel is filtered independently by its own filter, with configurable stride, padding, and bias.
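
Concretely, for each batch index n, channel c, and output position (oy, ox), the kernel below computes

    out[n, c, oy, ox] = sum over i, j in [0, kernel_size) of
        weight[c, i, j] * input[n, c, oy*stride - padding + i, ox*stride - padding + j]

where input positions that fall outside the image contribute zero.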

Progress: Step 6 of 6
Reward: 1.85x

Optimization Timeline

Step 6 Details

Reasoning

The current implementation is already well optimized: it unrolls the inner loops for the common case (kernel_size == 3, stride == 1, padding == 0) and routes all reads through __ldg for read-only caching, while a generic branch with boundary checks handles every other parameter combination. Further micro-optimizations (such as shared-memory tiling or warp-level programming) are unlikely to yield a significant benefit given the small kernel size and the limited data reuse of a depthwise convolution, and they would complicate the code. The present approach strikes a good balance between generality and performance, so no further modifications are necessary.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cuda_runtime.h>

// Depthwise convolution kernel: each thread computes one output pixel for one channel of one batch.
__global__ void depthwise_conv2d_kernel(const float* __restrict__ input,
                                        const float* __restrict__ weight,
                                        float* __restrict__ output,
                                        int batch,
                                        int channels,
                                        int in_h,
                                        int in_w,
                                        int out_h,
                                        int out_w,
                                        int kernel_size,
                                        int stride,
                                        int padding) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch * channels * out_h * out_w;
    if (index < total) {
        // Compute indices for batch, channel, and output spatial coordinates.
        int tmp = out_h * out_w;
        int n = index / (channels * tmp);
        int rem = index % (channels * tmp);
        int c = rem / tmp;
        int pos = rem % tmp;
        int out_y = pos / out_w;
        int out_x = pos % out_w;
        
        float sum = 0.0f;
        // Optimized branch for common case: kernel_size == 3, stride == 1, padding == 0.
        if (kernel_size == 3 && stride == 1 && padding == 0) {
            int w = in_w;  // Cache input width.
            int base_input = ((n * channels + c) * in_h + out_y) * w + out_x;
            int base_weight = c * 9;  // 3x3 kernel has 9 elements.
            sum =  __ldg(&weight[base_weight + 0]) * __ldg(&input[base_input + 0]) +
                   __ldg(&weight[base_weight + 1]) * __ldg(&input[base_input + 1]) +
                   __ldg(&weight[base_weight + 2]) * __ldg(&input[base_input + 2]) +
                   __ldg(&weight[base_weight + 3]) * __ldg(&input[base_input + w]) +
                   __ldg(&weight[base_weight + 4]) * __ldg(&input[base_input + w + 1]) +
                   __ldg(&weight[base_weight + 5]) * __ldg(&input[base_input + w + 2]) +
                   __ldg(&weight[base_weight + 6]) * __ldg(&input[base_input + 2*w]) +
                   __ldg(&weight[base_weight + 7]) * __ldg(&input[base_input + 2*w + 1]) +
                   __ldg(&weight[base_weight + 8]) * __ldg(&input[base_input + 2*w + 2]);
        } else {
            // General case: nested loops with boundary checks.
            for (int i = 0; i < kernel_size; i++) {
                int in_y = out_y * stride - padding + i;  // Row index depends only on i; hoist out of the j loop.
                if (in_y < 0 || in_y >= in_h) continue;
                for (int j = 0; j < kernel_size; j++) {
                    int in_x = out_x * stride - padding + j;
                    if (in_x >= 0 && in_x < in_w) {
                        int input_index = ((n * channels + c) * in_h + in_y) * in_w + in_x;
                        int weight_index = c * (kernel_size * kernel_size) + i * kernel_size + j;
                        sum += __ldg(&weight[weight_index]) * __ldg(&input[input_index]);
                    }
                }
            }
        }
        output[index] = sum;
    }
}

torch::Tensor depthwise_conv2d(torch::Tensor input,
                               torch::Tensor weight,
                               int stride,
                               int padding,
                               int kernel_size) {
    c10::cuda::CUDAGuard device_guard(input.device());
    
    // The kernel assumes contiguous float32 CUDA tensors; fail fast otherwise.
    TORCH_CHECK(input.is_cuda() && weight.is_cuda(), "input and weight must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32,
                "only float32 tensors are supported");
    TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "input and weight must be contiguous");
    
    // Input shape: [batch, channels, in_h, in_w]
    int batch = input.size(0);
    int channels = input.size(1);
    int in_h = input.size(2);
    int in_w = input.size(3);
    
    // Compute output dimensions.
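    // For each dimension: out = floor((in + 2*padding - kernel_size) / stride) + 1;
    // the integer division below implements the floor (numerator assumed non-negative).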
    int out_h = (in_h + 2 * padding - kernel_size) / stride + 1;
    int out_w = (in_w + 2 * padding - kernel_size) / stride + 1;
    
    auto output = torch::empty({batch, channels, out_h, out_w}, input.options());
    
    int total = batch * channels * out_h * out_w;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    depthwise_conv2d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        output.data_ptr<float>(),
        batch,
        channels,
        in_h,
        in_w,
        out_h,
        out_w,
        kernel_size,
        stride,
        padding
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // Surface launch/configuration errors immediately.
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("depthwise_conv2d", &depthwise_conv2d, "Depthwise 2D convolution (CUDA) with loop unrolling for 3x3 kernels");
}
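
A minimal way to exercise the extension from Python (a sketch: the source file name depthwise_conv2d.cu and the module name are assumptions, and the result is checked against PyTorch's grouped convolution, which is the standard way to express a depthwise convolution):

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Build the CUDA source above as an inline extension (file name is an assumption).
ext = load(name="depthwise_conv2d_ext", sources=["depthwise_conv2d.cu"])

channels, k, stride, padding = 8, 3, 1, 0
x = torch.randn(2, channels, 64, 96, device="cuda")  # asymmetric H x W input
w = torch.randn(channels, 1, k, k, device="cuda")    # one k x k filter per channel

out = ext.depthwise_conv2d(x, w, stride, padding, k)
ref = F.conv2d(x, w, stride=stride, padding=padding, groups=channels)
print(torch.allclose(out, ref, atol=1e-5))  # expected: True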