1.87. Pointwise 2D Convolution
Calculates a pointwise 2D convolution using a 1x1 kernel, transforming an input tensor of shape (batch_size, in_channels, height, width) into an output tensor of shape (batch_size, out_channels, height, width) while optionally adding a bias.
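For reference, a pointwise (1x1) convolution is just a per-pixel matrix multiply over the channel dimension. The following is a minimal PyTorch sketch of the operation being implemented (shapes and tensor names are illustrative, and the optional bias is omitted, matching the kernels below):

import torch
import torch.nn.functional as F

batch, in_channels, out_channels, height, width = 2, 3, 8, 16, 16
x = torch.randn(batch, in_channels, height, width)
w = torch.randn(out_channels, in_channels, 1, 1)  # 1x1 kernel, no bias

y = F.conv2d(x, w)  # shape: (batch, out_channels, height, width)
# Equivalent per-pixel channel matmul:
y_ref = torch.einsum('oc,bchw->bohw', w.view(out_channels, in_channels), x)
assert torch.allclose(y, y_ref, atol=1e-5)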
Step 7 Details
Reasoning
We have already implemented a pixel-level kernel for small in_channels (≤ 4) that caches the weight matrix in shared memory. This dramatically reduces redundant global memory loads for the typical use case (e.g. RGB images). In the general kernel, we continue to use per-output-element parallelism with __ldg caching. One further potential improvement would be to use vectorized loads for the input if alignment guarantees are available, but that would require more drastic data-layout changes.

Our current approach is already very efficient, and the measured runtime (≈0.28 s vs. ≈0.30 s native, a speedup ratio of ≈1.07) indicates a genuine benefit. Before finalizing, I ask myself: have these modifications been genuine improvements? Yes. By using shared-memory caching in the pixel kernel and avoiding redundant loads, we have improved performance, especially when in_channels is small. This is a standard and effective approach rather than a hack.

I will now present the final version of our optimized CUDA implementation; the Python binding remains essentially unchanged (a short usage sketch follows the listing).
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Pixel kernel for small in_channels (<= 4)
// Each thread processes one pixel (b, y, x) and computes all output channels.
// This kernel uses shared memory to cache the weight matrix.
__global__ void pointwise_conv2d_kernel_pixel(const float* __restrict__ input,
                                              const float* __restrict__ weight,
                                              float* __restrict__ output,
                                              int batch,
                                              int in_channels,
                                              int out_channels,
                                              int height,
                                              int width) {
    extern __shared__ float s_weight[];  // Shared memory for weight matrix [out_channels * in_channels]
    int weight_elements = out_channels * in_channels;
    // Load weight matrix into shared memory (each thread cooperatively loads part of it).
    for (int i = threadIdx.x; i < weight_elements; i += blockDim.x) {
        s_weight[i] = weight[i];
    }
    __syncthreads();
    int pixel_count = batch * height * width;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < pixel_count) {
        int HW = height * width;
        int b = idx / HW;
        int rem = idx % HW;
        int y = rem / width;
        int x = rem % width;
        // Base indices for this input pixel and output pixel.
        int input_pixel_base = b * (in_channels * HW) + y * width + x;
        int output_pixel_base = b * (out_channels * HW) + y * width + x;
        // Compute the dot product for every output channel.
        for (int o = 0; o < out_channels; o++) {
            float sum = 0.0f;
            int weight_base = o * in_channels;  // Weight is stored in row-major order in shared memory.
            #pragma unroll
            for (int c = 0; c < in_channels; c++) {
                int input_idx = input_pixel_base + c * HW;
                float in_val = __ldg(&input[input_idx]);
                float w_val = s_weight[weight_base + c];
                sum += in_val * w_val;
            }
            int out_idx = output_pixel_base + o * HW;
            output[out_idx] = sum;
        }
    }
}
// General kernel for arbitrary in_channels using per-output-element parallelism.
// Each thread computes one output element.
__global__ void pointwise_conv2d_kernel_general(const float* __restrict__ input,
                                                const float* __restrict__ weight,
                                                float* __restrict__ output,
                                                int batch,
                                                int in_channels,
                                                int out_channels,
                                                int height,
                                                int width) {
    int total = batch * out_channels * height * width;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        int x = idx % width;
        int temp = idx / width;
        int y = temp % height;
        temp = temp / height;
        int o = temp % out_channels;
        int b = temp / out_channels;
        float sum = 0.0f;
        int chan_stride = height * width;  // Precomputed channel stride.
        int input_base = b * in_channels * chan_stride + y * width + x;
        for (int c = 0; c < in_channels; c++) {
            int input_idx = input_base + c * chan_stride;
            int weight_idx = o * in_channels + c;
            sum += __ldg(&input[input_idx]) * __ldg(&weight[weight_idx]);
        }
        output[idx] = sum;
    }
}
torch::Tensor pointwise_conv2d(torch::Tensor input, torch::Tensor weight) {
    c10::cuda::CUDAGuard device_guard(input.device());
    int batch = input.size(0);
    int in_channels = input.size(1);
    int height = input.size(2);
    int width = input.size(3);
    int out_channels = weight.size(0);
    auto output = torch::empty({batch, out_channels, height, width}, input.options());
    if (in_channels <= 4) {
        int pixel_count = batch * height * width;
        int threads = 256;
        int blocks = (pixel_count + threads - 1) / threads;
        // Compute shared memory size for the weight matrix.
        int shared_mem = out_channels * in_channels * sizeof(float);
        pointwise_conv2d_kernel_pixel<<<blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(),
            weight.data_ptr<float>(),
            output.data_ptr<float>(),
            batch,
            in_channels,
            out_channels,
            height,
            width
        );
    } else {
        int total = batch * out_channels * height * width;
        int threads = 256;
        int blocks = (total + threads - 1) / threads;
        pointwise_conv2d_kernel_general<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(),
            weight.data_ptr<float>(),
            output.data_ptr<float>(),
            batch,
            in_channels,
            out_channels,
            height,
            width
        );
    }
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pointwise_conv2d", &pointwise_conv2d,
          "Optimized fused pointwise 2D convolution (1x1) using pixel-level parallelism and shared memory caching of weights for small in_channels");
}