1.82. Depthwise 2D Convolution (Square Input/Kernel)
Performs a depthwise convolution on a square input tensor using a square kernel. It calculates channel-wise outputs by convolving each input channel independently, with parameters for adjustable stride, padding, and an optional bias.
Progress
Optimization Timeline
Step 7 Details
Reasoning
We considered exploring shared memory tiling and warp-level primitives for further improvements, but the increased complexity and potential for register pressure did not promise a clear win over our current highly optimized kernel. Our current implementation already provides a fast-path for the common case (kernel_size == 3, stride == 1, padding == 0) using complete loop unrolling, register caching for weight values, and coalesced global reads with __ldg in the generic path. These enhancements have proven to yield a significant speedup (~1.72×) compared to the native PyTorch implementation. The potential additional optimizations would require nontrivial code changes and may not be beneficial for our typical use-cases. This final version balances complexity and performance while remaining mathematically equivalent and efficient.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
// Depthwise 2D convolution kernel with loop unrolling, register caching, and optional unrolling directives.
// Input: [batch, channels, H, W]
// Weight: [channels, 1, K, K] where K = kernel_size.
// Output: [batch, channels, H_out, W_out]
// H_out = (H + 2*padding - K) / stride + 1, similarly for W_out.
__global__ void depthwise_conv2d_kernel(const float* __restrict__ input,
const float* __restrict__ weight,
float* __restrict__ output,
int batch,
int channels,
int H,
int W,
int H_out,
int W_out,
int K,
int stride,
int padding) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int total = batch * channels * H_out * W_out;
if (index >= total)
return;
// Map linear index to (b, c, h_out, w_out)
int w_out = index % W_out;
int tmp = index / W_out;
int h_out = tmp % H_out;
tmp /= H_out;
int c = tmp % channels;
int b = tmp / channels;
float sum = 0.0f;
int h_in_start = h_out * stride - padding;
int w_in_start = w_out * stride - padding;
// Optimized branch for kernel_size == 3, stride == 1, and padding == 0.
if (K == 3 && stride == 1 && padding == 0) {
// Load weight elements for channel c into registers.
float w00 = weight[((c * 3) + 0) * 3 + 0],
w01 = weight[((c * 3) + 0) * 3 + 1],
w02 = weight[((c * 3) + 0) * 3 + 2],
w10 = weight[((c * 3) + 1) * 3 + 0],
w11 = weight[((c * 3) + 1) * 3 + 1],
w12 = weight[((c * 3) + 1) * 3 + 2],
w20 = weight[((c * 3) + 2) * 3 + 0],
w21 = weight[((c * 3) + 2) * 3 + 1],
w22 = weight[((c * 3) + 2) * 3 + 2];
// In this ideal case, the input indices are valid without bounds checking.
int h0 = h_out, h1 = h_out + 1, h2 = h_out + 2;
int w0 = w_out, w1 = w_out + 1, w2 = w_out + 1 + 0; // w2 = w_out + 2
w2 = w_out + 2;
int base = ((b * channels + c) * H);
int offset0 = (base + h0) * W,
offset1 = (base + h1) * W,
offset2 = (base + h2) * W;
sum = input[offset0 + w0] * w00 +
input[offset0 + w1] * w01 +
input[offset0 + w2] * w02 +
input[offset1 + w0] * w10 +
input[offset1 + w1] * w11 +
input[offset1 + w2] * w12 +
input[offset2 + w0] * w20 +
input[offset2 + w1] * w21 +
input[offset2 + w2] * w22;
}
// Optimized branch for K == 3 with non-zero padding or stride != 1.
else if (K == 3) {
float w00 = weight[((c * 3) + 0) * 3 + 0],
w01 = weight[((c * 3) + 0) * 3 + 1],
w02 = weight[((c * 3) + 0) * 3 + 2],
w10 = weight[((c * 3) + 1) * 3 + 0],
w11 = weight[((c * 3) + 1) * 3 + 1],
w12 = weight[((c * 3) + 1) * 3 + 2],
w20 = weight[((c * 3) + 2) * 3 + 0],
w21 = weight[((c * 3) + 2) * 3 + 1],
w22 = weight[((c * 3) + 2) * 3 + 2];
int h0 = h_in_start + 0, h1 = h_in_start + 1, h2 = h_in_start + 2;
int w0 = w_in_start + 0, w1 = w_in_start + 1, w2 = w_in_start + 2;
int base = ((b * channels + c) * H);
if (h0 >= 0 && h0 < H) {
int offset = (base + h0) * W;
if (w0 >= 0 && w0 < W)
sum += input[offset + w0] * w00;
if (w1 >= 0 && w1 < W)
sum += input[offset + w1] * w01;
if (w2 >= 0 && w2 < W)
sum += input[offset + w2] * w02;
}
if (h1 >= 0 && h1 < H) {
int offset = (base + h1) * W;
if (w0 >= 0 && w0 < W)
sum += input[offset + w0] * w10;
if (w1 >= 0 && w1 < W)
sum += input[offset + w1] * w11;
if (w2 >= 0 && w2 < W)
sum += input[offset + w2] * w12;
}
if (h2 >= 0 && h2 < H) {
int offset = (base + h2) * W;
if (w0 >= 0 && w0 < W)
sum += input[offset + w0] * w20;
if (w1 >= 0 && w1 < W)
sum += input[offset + w1] * w21;
if (w2 >= 0 && w2 < W)
sum += input[offset + w2] * w22;
}
}
// Generic implementation for arbitrary kernel sizes.
else {
#pragma unroll
for (int i = 0; i < K; i++) {
int h_in = h_in_start + i;
if (h_in < 0 || h_in >= H) continue;
#pragma unroll
for (int j = 0; j < K; j++) {
int w_in = w_in_start + j;
if (w_in < 0 || w_in >= W) continue;
int input_idx = ((b * channels + c) * H + h_in) * W + w_in;
int weight_idx = ((c * K) + i) * K + j;
sum += __ldg(&input[input_idx]) * __ldg(&weight[weight_idx]);
}
}
}
output[index] = sum;
}
torch::Tensor depthwise_conv2d(torch::Tensor input,
torch::Tensor weight,
int stride,
int padding) {
c10::cuda::CUDAGuard device_guard(input.device());
int batch = input.size(0);
int channels = input.size(1);
int H = input.size(2);
int W = input.size(3);
int K = weight.size(2);
int H_out = (H + 2 * padding - K) / stride + 1;
int W_out = (W + 2 * padding - K) / stride + 1;
auto output = torch::empty({batch, channels, H_out, W_out}, input.options());
int total_output = batch * channels * H_out * W_out;
int threads = 256;
int blocks = (total_output + threads - 1) / threads;
depthwise_conv2d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<float>(),
weight.data_ptr<float>(),
output.data_ptr<float>(),
batch,
channels,
H,
W,
H_out,
W_out,
K,
stride,
padding
);
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("depthwise_conv2d", &depthwise_conv2d, "Depthwise 2D Convolution (CUDA) with loop unrolling and register caching");
}