2.71. Conv2D Divide and LeakyReLU
Calculates a 2D convolution on the input data, divides the convolution output by a constant value, and applies the LeakyReLU activation with a negative slope of 0.01.
Progress
Optimization Timeline
Step 2 Details
Reasoning
We can further optimize the fused convolution + division + LeakyReLU kernel by unrolling its inner loops. In the typical use case the kernel size and the number of input channels are small (e.g. kernel_size = 3 and in_channels = 3), so unrolling reduces loop overhead and lets the compiler simplify index computations. Note, however, that kernel_size and in_channels are passed to the kernel as runtime arguments. A bare `#pragma unroll` only fully unrolls loops whose trip count is a compile-time constant, so fully unrolling these loops would require making the sizes compile-time constants (for example, template parameters). We also precompute strides such as in_h * in_w to avoid redundant multiplications inside the loop. These changes reduce arithmetic and control overhead without affecting correctness, although the actual speedup should be confirmed by profiling. We now update the CUDA kernel accordingly while leaving the Python module largely unchanged.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
// Fused CUDA kernel: direct 2D convolution (stride 1, no padding, no dilation),
// then division of the biased convolution result by `divisor`, then LeakyReLU.
//
// Memory layout, as indexed below (dense, row-major NCHW, float32):
//   x      : [batch, in_channels, in_h, in_w]
//   weight : [out_channels, in_channels, kernel_size, kernel_size]
//   bias   : [out_channels]
//   out    : [batch, out_channels, in_h - kernel_size + 1, in_w - kernel_size + 1]
//
// Launch: 1-D grid of 1-D blocks. Each thread processes output elements in a
// grid-stride loop, so any positive grid/block size yields a complete result.
// No shared memory is used. Flat indices and offsets are computed in 64-bit so
// tensors with more than 2^31 elements do not overflow.
// A `divisor` of 0 follows IEEE float division (inf/NaN results).
__global__ void fused_conv_div_leakyrelu_kernel(const float* __restrict__ x,
                                                const float* __restrict__ weight,
                                                const float* __restrict__ bias,
                                                float* __restrict__ out,
                                                int batch,
                                                int in_channels,
                                                int in_h,
                                                int in_w,
                                                int out_channels,
                                                int kernel_size,
                                                float divisor,
                                                float negative_slope) {
    // "Valid" convolution output size for stride 1 and no padding.
    const int out_h = in_h - kernel_size + 1;
    const int out_w = in_w - kernel_size + 1;
    // Degenerate sizes: nothing to compute (also prevents two negative extents
    // from multiplying into a bogus positive element count).
    if (out_h <= 0 || out_w <= 0) {
        return;
    }

    // Element counts and strides in 64-bit: the product of four int
    // dimensions can exceed INT_MAX.
    const int64_t total = static_cast<int64_t>(batch) * out_channels * out_h * out_w;
    const int64_t in_plane = static_cast<int64_t>(in_h) * in_w;               // one input channel
    const int64_t k_area = static_cast<int64_t>(kernel_size) * kernel_size;   // one filter slice
    const int64_t x_batch_stride = in_channels * in_plane;
    const int64_t w_oc_stride = in_channels * k_area;
    const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < total;
         index += grid_stride) {
        // Decode the flat NCHW output index into (b, oc, oh, ow).
        int64_t rest = index;
        const int ow = static_cast<int>(rest % out_w);
        rest /= out_w;
        const int oh = static_cast<int>(rest % out_h);
        rest /= out_h;
        const int oc = static_cast<int>(rest % out_channels);
        const int64_t b = rest / out_channels;

        // Loop-invariant bases: top-left input pixel of this output's receptive
        // field in batch b, and the filter for output channel oc.
        const float* x_origin = x + b * x_batch_stride + static_cast<int64_t>(oh) * in_w + ow;
        const float* w_filter = weight + oc * w_oc_stride;

        float sum = bias[oc];
        // NOTE: kernel_size and in_channels are runtime values, so these pragmas
        // cannot fully unroll the loops; full unrolling would need compile-time
        // sizes (e.g. template parameters).
        #pragma unroll
        for (int ic = 0; ic < in_channels; ++ic) {
            const float* x_chan = x_origin + ic * in_plane;
            const float* w_chan = w_filter + ic * k_area;
            #pragma unroll
            for (int kh = 0; kh < kernel_size; ++kh) {
                const float* x_row = x_chan + static_cast<int64_t>(kh) * in_w;
                const float* w_row = w_chan + kh * kernel_size;
                #pragma unroll
                for (int kw = 0; kw < kernel_size; ++kw) {
                    sum += x_row[kw] * w_row[kw];
                }
            }
        }

        // Division by the constant, then LeakyReLU.
        sum = sum / divisor;
        out[index] = (sum >= 0.0f) ? sum : sum * negative_slope;
    }
}
// Host entry point: validates the inputs and launches
// fused_conv_div_leakyrelu_kernel on the current CUDA stream of x's device.
//
// x       : [batch, in_channels, in_h, in_w] float32 CUDA tensor.
// weight  : [out_channels, in_channels, k, k] float32 CUDA tensor (square kernel).
// bias    : [out_channels] float32 CUDA tensor.
// divisor : scalar that the biased convolution result is divided by.
// Returns a new [batch, out_channels, in_h - k + 1, in_w - k + 1] tensor
// (stride 1, no padding) with LeakyReLU (negative slope 0.01) applied.
// Non-contiguous inputs are copied to contiguous memory first, because the
// kernel indexes dense NCHW memory.
// Throws c10::Error (via TORCH_CHECK) on invalid inputs or a failed launch.
torch::Tensor fused_conv_div_leakyrelu(torch::Tensor x,
                                       torch::Tensor weight,
                                       torch::Tensor bias,
                                       float divisor) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "fused_conv_div_leakyrelu: all inputs must be CUDA tensors");
    TORCH_CHECK(weight.device() == x.device() && bias.device() == x.device(),
                "fused_conv_div_leakyrelu: all inputs must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                    weight.scalar_type() == torch::kFloat32 &&
                    bias.scalar_type() == torch::kFloat32,
                "fused_conv_div_leakyrelu: all inputs must be float32");
    TORCH_CHECK(x.dim() == 4, "fused_conv_div_leakyrelu: x must be 4-D (N, C, H, W), got ",
                x.dim(), "-D");
    TORCH_CHECK(weight.dim() == 4,
                "fused_conv_div_leakyrelu: weight must be 4-D (C_out, C_in, k, k), got ",
                weight.dim(), "-D");
    TORCH_CHECK(weight.size(2) == weight.size(3),
                "fused_conv_div_leakyrelu: only square kernels are supported, got ",
                weight.size(2), "x", weight.size(3));
    TORCH_CHECK(weight.size(1) == x.size(1),
                "fused_conv_div_leakyrelu: weight in_channels (", weight.size(1),
                ") must match x channels (", x.size(1), ")");
    TORCH_CHECK(bias.dim() == 1 && bias.size(0) == weight.size(0),
                "fused_conv_div_leakyrelu: bias must have shape [out_channels]");

    // Guard before any allocation/copy so all work happens on x's device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel assumes dense NCHW memory; this is a no-op for contiguous inputs.
    const auto x_c = x.contiguous();
    const auto w_c = weight.contiguous();
    const auto b_c = bias.contiguous();

    const int batch = static_cast<int>(x_c.size(0));
    const int in_channels = static_cast<int>(x_c.size(1));
    const int in_h = static_cast<int>(x_c.size(2));
    const int in_w = static_cast<int>(x_c.size(3));
    const int out_channels = static_cast<int>(w_c.size(0));
    const int kernel_size = static_cast<int>(w_c.size(2));

    // "Valid" convolution output size (stride 1, no padding).
    const int out_h = in_h - kernel_size + 1;
    const int out_w = in_w - kernel_size + 1;
    TORCH_CHECK(out_h > 0 && out_w > 0,
                "fused_conv_div_leakyrelu: input spatial size (", in_h, "x", in_w,
                ") must be at least the kernel size (", kernel_size, ")");

    auto output = torch::empty({batch, out_channels, out_h, out_w}, x_c.options());

    // 64-bit element count; an empty output (e.g. batch == 0) must not launch a
    // zero-block grid, which is an invalid launch configuration.
    const int64_t total = output.numel();
    if (total == 0) {
        return output;
    }

    const int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    constexpr int64_t kMaxGridX = 2147483647;  // gridDim.x limit (2^31 - 1)
    TORCH_CHECK(blocks <= kMaxGridX,
                "fused_conv_div_leakyrelu: output too large for a 1-D launch (", total,
                " elements)");

    const float negative_slope = 0.01f;
    fused_conv_div_leakyrelu_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                                      c10::cuda::getCurrentCUDAStream()>>>(
        x_c.data_ptr<float>(),
        w_c.data_ptr<float>(),
        b_c.data_ptr<float>(),
        output.data_ptr<float>(),
        batch,
        in_channels,
        in_h,
        in_w,
        out_channels,
        kernel_size,
        divisor,
        negative_slope);

    // Surface launch-configuration errors here instead of in a later, unrelated call.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "fused_conv_div_leakyrelu_kernel launch failed: ",
                cudaGetErrorString(launch_err));
    return output;
}
// Python bindings: exposes the host wrapper as
// `<extension>.fused_conv_div_leakyrelu(x, weight, bias, divisor)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "fused_conv_div_leakyrelu",
        &fused_conv_div_leakyrelu,
        "Fused convolution, division, and LeakyReLU activation");
}