1.83. Depthwise Convolution with Asymmetric Kernel
Performs a depthwise 2D convolution on square input data using an asymmetric kernel sized (kernel_size, 1). Each input channel is convolved separately with configurable stride, padding, dilation, and an optional bias.
Progress
Optimization Timeline
Step 5 Details
Reasoning
We can further optimize the kernel by using __ldg to load input values through the read-only data cache. In the safe branch, where the convolution window is fully in bounds, every load is valid, so we unroll the kernel loop explicitly for kernel_size==3 and use __ldg for each access. For other kernel sizes in the safe branch, we also use __ldg across the convolution window. The unsafe branch already performs boundary checks; there we use __ldg only for indices that pass the check. This can reduce memory-access latency and improve throughput on modern GPUs. The Python module remains unchanged, as it only loads the extension and calls the CUDA kernel through it.
#include <cstdint>
#include <limits>

#include <cuda.h>
#include <cuda_runtime.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
// Depthwise 2D convolution kernel for an asymmetric kernel of shape (kernel_size, 1).
//
// Memory layout (as indexed below):
//   input  : contiguous NCHW float32, [B, C, H_in, W_in]
//   output : contiguous NCHW float32, [B, C, H_out, W_out]
//   weight : kernel_size contiguous taps per channel, weight[c * kernel_size + k]
//
// Launch: 1D grid of 1D blocks, no shared memory. Each thread walks the output
// with a grid-stride loop, so any grid size covers every output element
// (ceil(total / blockDim.x) blocks gives one element per thread).
//
// All flat indices are 64-bit, so tensors with more than 2^31 elements do not
// overflow.
//
// Optimization: if the input column and the whole vertical window are in bounds
// (safe branch), boundary checks are skipped and kernel_size == 3 is unrolled by
// hand. Input loads go through __ldg (read-only data cache).
__global__ void depthwise_conv2d_kernel(const float* __restrict__ input,
                                        const float* __restrict__ weight,
                                        float* __restrict__ output,
                                        int B, int C, int H_in, int W_in,
                                        int H_out, int W_out,
                                        int kernel_size, int stride, int padding, int dilation) {
    const int64_t plane_out = static_cast<int64_t>(H_out) * W_out;
    const int64_t plane_in = static_cast<int64_t>(H_in) * W_in;
    const int64_t total = static_cast<int64_t>(B) * C * plane_out;
    // Element distance between two vertically adjacent kernel taps.
    const int64_t tap_step = static_cast<int64_t>(dilation) * W_in;
    const int64_t grid_step = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < total; index += grid_step) {
        // Decompose index into (bc, h, w), where bc = b * C + c selects the
        // (batch, channel) plane. Output is NCHW, so output[index] is (b, c, h, w).
        const int w = static_cast<int>(index % W_out);
        const int h = static_cast<int>((index / W_out) % H_out);
        const int64_t bc = index / plane_out;
        const int c = static_cast<int>(bc % C);

        // Starting input row for this output; the kernel width is 1, so the
        // input column is a single fixed index.
        const int in_h_start = h * stride - padding;
        const int in_w = w * stride - padding;
        // Taps for channel c.
        const float* w_c = weight + static_cast<int64_t>(c) * kernel_size;

        float sum = 0.0f;
        // If the single input column lies in the padding, every tap reads zero
        // and the output stays 0.
        if (in_w >= 0 && in_w < W_in) {
            // Address of (bc, row 0, in_w); rows are W_in elements apart.
            const float* col = input + bc * plane_in + in_w;

            if (in_h_start >= 0 && in_h_start + dilation * (kernel_size - 1) < H_in) {
                // Safe branch: the entire vertical window is inside the input.
                const float* p = col + static_cast<int64_t>(in_h_start) * W_in;
                if (kernel_size == 3) {
                    sum = __ldg(p) * w_c[0] +
                          __ldg(p + tap_step) * w_c[1] +
                          __ldg(p + 2 * tap_step) * w_c[2];
                } else {
                    #pragma unroll
                    for (int k = 0; k < kernel_size; ++k) {
                        sum += __ldg(p + k * tap_step) * w_c[k];
                    }
                }
            } else {
                // Unsafe branch: rows that fall into the padding contribute zero.
                #pragma unroll
                for (int k = 0; k < kernel_size; ++k) {
                    const int in_h = in_h_start + k * dilation;
                    if (in_h >= 0 && in_h < H_in) {
                        sum += __ldg(col + static_cast<int64_t>(in_h) * W_in) * w_c[k];
                    }
                }
            }
        }
        output[index] = sum;
    }
}
// Host entry point: validates arguments, allocates the output, and launches
// depthwise_conv2d_kernel on the current CUDA stream of input's device.
//
//   input    : [B, C, H_in, W_in] float32 CUDA tensor (made contiguous here).
//   weight   : [C, 1, kernel_size, 1] float32 tensor on the same device
//              (made contiguous here).
//   stride   : > 0, applied to both H and W.
//   padding  : >= 0, applied symmetrically to both H and W.
//   dilation : > 0; it only affects H because the kernel width is 1.
// Returns a new [B, C, H_out, W_out] tensor.
// Throws c10::Error (via TORCH_CHECK) on invalid arguments or a failed launch.
//
// Bias is not applied here. NOTE(review): the problem statement mentions an
// optional bias — confirm the caller adds it.
torch::Tensor forward(torch::Tensor input, torch::Tensor weight, int stride, int padding, int dilation) {
    TORCH_CHECK(input.is_cuda(), "depthwise_conv2d: input must be a CUDA tensor");
    TORCH_CHECK(weight.device() == input.device(),
                "depthwise_conv2d: weight must be on the same device as input");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "depthwise_conv2d: input must be float32");
    TORCH_CHECK(weight.scalar_type() == torch::kFloat32, "depthwise_conv2d: weight must be float32");
    TORCH_CHECK(input.dim() == 4, "depthwise_conv2d: input must be 4D [B, C, H, W], got ", input.sizes());
    TORCH_CHECK(weight.dim() == 4 && weight.size(0) == input.size(1) &&
                    weight.size(1) == 1 && weight.size(3) == 1 && weight.size(2) > 0,
                "depthwise_conv2d: weight must have shape [C, 1, kernel_size, 1], got ", weight.sizes());
    // stride == 0 would otherwise divide by zero in the output-size formula.
    TORCH_CHECK(stride > 0, "depthwise_conv2d: stride must be positive, got ", stride);
    TORCH_CHECK(padding >= 0, "depthwise_conv2d: padding must be non-negative, got ", padding);
    TORCH_CHECK(dilation > 0, "depthwise_conv2d: dilation must be positive, got ", dilation);

    // Enforce the operation on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes raw NCHW memory, so strided views must be materialized.
    input = input.contiguous();
    weight = weight.contiguous();

    const int B = static_cast<int>(input.size(0));
    const int C = static_cast<int>(input.size(1));
    const int H_in = static_cast<int>(input.size(2));
    const int W_in = static_cast<int>(input.size(3));
    // Determine kernel_size from weight shape: [C, 1, kernel_size, 1].
    const int kernel_size = static_cast<int>(weight.size(2));

    // The (dilated) kernel must fit the padded input. Without this check C++'s
    // truncating division would report one output row for an oversized kernel.
    const int64_t padded_h = static_cast<int64_t>(H_in) + 2 * static_cast<int64_t>(padding);
    const int64_t padded_w = static_cast<int64_t>(W_in) + 2 * static_cast<int64_t>(padding);
    const int64_t kernel_extent_h = static_cast<int64_t>(dilation) * (kernel_size - 1) + 1;
    TORCH_CHECK(padded_h >= kernel_extent_h && padded_w >= 1,
                "depthwise_conv2d: kernel (extent ", kernel_extent_h,
                ", 1) is larger than the padded input (", padded_h, ", ", padded_w, ")");

    // Standard convolution output-size formula; the kernel width is 1.
    const int H_out = static_cast<int>((padded_h - kernel_extent_h) / stride + 1);
    const int W_out = static_cast<int>((padded_w - 1) / stride + 1);
    auto output = torch::empty({B, C, H_out, W_out}, input.options());

    // Launching a zero-block grid is an invalid configuration; nothing to compute.
    const int64_t total = output.numel();
    if (total == 0) {
        return output;
    }

    // Compute the launch size in 64-bit so large outputs do not overflow the
    // host-side arithmetic; gridDim.x is limited to 2^31 - 1 blocks.
    constexpr int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    TORCH_CHECK(blocks <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "depthwise_conv2d: output too large for a 1D grid (", total, " elements)");

    depthwise_conv2d_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                              c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C, H_in, W_in,
        H_out, W_out,
        kernel_size, stride, padding, dilation
    );
    // Kernel launches do not return errors; surface bad-config failures here.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "depthwise_conv2d_kernel launch failed: ", cudaGetErrorString(launch_err));
    return output;
}
// Python bindings: exposes the host launcher above under the name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def(
        "forward",
        &forward,
        "Custom depthwise conv2d with asymmetric kernel (kernel_size, 1)");
}