2.90. 3D Convolution with Activation and Clamping
Performs a 3D convolution on input data using a specified kernel, applies a LeakyReLU activation, adds a fixed per-output-channel tensor to the activated result, clamps the values between -1 and 1, and finally applies the GELU activation.
Progress
Optimization Timeline
Step 3 Details
Reasoning
We can improve the kernel performance for the typical case of kernel size 3 and in_channels 3 by fully unrolling the convolution loops over the spatial dimensions and channels. Our previous version already had an unrolled branch for k==3; now we add an inner branch when in_channels==3 so that we unroll the channel loop as well. This minimizes loop control overhead and maximizes instruction-level parallelism. The subsequent operations (LeakyReLU, bias addition, addition of the broadcast sum_tensor, clamping, and GELU activation) remain the same. Our Python module remains unchanged except that it calls our updated fused CUDA operator. This change should yield a further speedup over our previous version without compromising correctness. Below are our updated Python and CUDA modules.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
// Fused kernel: performs a direct 3D convolution (stride=1, no padding,
// cubic kernel of side k), then applies, in order:
//  - LeakyReLU (negative_slope = 0.2)
//  - addition of conv_bias[oc] (folded into the conv sum before the LeakyReLU)
//  - addition of the broadcast sum_tensor[oc] (one value per output channel)
//  - clamp to [-1.0, 1.0]
//  - approximate (tanh-based) GELU activation.
//
// Launch: 1D grid, one thread per output element; threads past the end exit
// via the bounds check. No shared memory. Expects contiguous float32 tensors:
//   input  [batch, in_channels, D_in, H_in, W_in]
//   weight [out_channels, in_channels, k, k, k]
//   conv_bias, sum_tensor: at least out_channels floats each.
__global__ void fused_conv_activation_kernel(const float* __restrict__ input,
                                             const float* __restrict__ weight,
                                             const float* __restrict__ conv_bias,
                                             const float* __restrict__ sum_tensor,
                                             float* __restrict__ output,
                                             int batch_size,
                                             int in_channels,
                                             int D_in, int H_in, int W_in,
                                             int out_channels,
                                             int k)
{
    // Output dimensions for a valid (no-padding, stride-1) convolution.
    const int D_out = D_in - k + 1;
    const int H_out = H_in - k + 1;
    const int W_out = W_in - k + 1;
    // 64-bit total/index: batch*channels*spatial can overflow 32 bits.
    const long long total =
        (long long)batch_size * out_channels * D_out * H_out * W_out;
    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total)
        return;

    // Decode the flat index into (b, oc, d, h, w), w fastest.
    long long temp = idx;
    const int w = (int)(temp % W_out); temp /= W_out;
    const int h = (int)(temp % H_out); temp /= H_out;
    const int d = (int)(temp % D_out); temp /= D_out;
    const int oc = (int)(temp % out_channels);
    const int b = (int)(temp / out_channels);

    const int HW = H_in * W_in;          // hoisted: plane stride of the input
    const int in_spatial = D_in * HW;    // per-channel volume of the input
    float sum_val = 0.0f;

    if (k == 3 && in_channels == 3) {
        // Fast path for the common 3x3x3 / 3-channel case. The constant-bound
        // loops are fully unrolled by the compiler (#pragma unroll), giving the
        // same instruction stream — and the same c -> kd -> kh -> kw
        // accumulation order, hence bitwise-identical results — as a manual
        // unroll, without 90 lines of repeated code.
        const int base_input = b * (3 * in_spatial);
        const int base_weight = oc * (3 * 27);      // 27 = 3*3*3 taps/channel
        #pragma unroll
        for (int c = 0; c < 3; ++c) {
            const int ib = base_input + c * in_spatial;
            const int wb = base_weight + c * 27;
            #pragma unroll
            for (int kd = 0; kd < 3; ++kd) {
                #pragma unroll
                for (int kh = 0; kh < 3; ++kh) {
                    #pragma unroll
                    for (int kw = 0; kw < 3; ++kw) {
                        sum_val += input[ib + (d + kd) * HW + (h + kh) * W_in + (w + kw)]
                                 * weight[wb + kd * 9 + kh * 3 + kw];
                    }
                }
            }
        }
    } else {
        // General path for any k / in_channels. (The previous version carried
        // a separate k==3 branch that was byte-identical to this loop; it has
        // been removed as dead duplication.)
        const int kk = k * k;
        for (int c = 0; c < in_channels; ++c) {
            const int ib = b * (in_channels * in_spatial) + c * in_spatial;
            const int wb = oc * (in_channels * kk * k) + c * (kk * k);
            for (int kd = 0; kd < k; ++kd) {
                for (int kh = 0; kh < k; ++kh) {
                    for (int kw = 0; kw < k; ++kw) {
                        sum_val += input[ib + (d + kd) * HW + (h + kh) * W_in + (w + kw)]
                                 * weight[wb + kd * kk + kh * k + kw];
                    }
                }
            }
        }
    }

    // Convolution bias for this output channel.
    sum_val += conv_bias[oc];
    // LeakyReLU, negative_slope = 0.2.
    if (sum_val < 0.0f)
        sum_val *= 0.2f;
    // Broadcast per-channel addend.
    sum_val += sum_tensor[oc];
    // Branchless clamp to [-1, 1].
    sum_val = fminf(fmaxf(sum_val, -1.0f), 1.0f);
    // Approximate GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
    const float x_cube = sum_val * sum_val * sum_val;
    const float tanh_out = tanhf(0.7978845608028654f * (sum_val + 0.044715f * x_cube));
    output[idx] = 0.5f * sum_val * (1.0f + tanh_out);
}
// Host wrapper: validates inputs, allocates the output tensor, and launches
// the fused convolution/activation kernel on the current CUDA stream.
//
// Args:
//   input      [batch, in_channels, D, H, W] float32 CUDA tensor (contiguous)
//   weight     [out_channels, in_channels, k, k, k] float32 CUDA tensor
//   conv_bias  [out_channels] float32 CUDA tensor
//   sum_tensor broadcastable per-output-channel addend (out_channels floats)
// Returns:
//   [batch, out_channels, D-k+1, H-k+1, W-k+1] float32 CUDA tensor.
torch::Tensor forward(torch::Tensor input,
                      torch::Tensor weight,
                      torch::Tensor conv_bias,
                      torch::Tensor sum_tensor) {
    // Fail loudly on inputs the kernel cannot handle (raw pointer indexing
    // assumes contiguous float32 device memory).
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && conv_bias.is_cuda() &&
                sum_tensor.is_cuda(), "all tensors must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32,
                "input and weight must be float32");
    TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(),
                "input and weight must be contiguous");
    TORCH_CHECK(input.dim() == 5, "input must be 5-D (N, C, D, H, W)");
    TORCH_CHECK(weight.dim() == 5, "weight must be 5-D (O, C, k, k, k)");
    TORCH_CHECK(input.size(1) == weight.size(1),
                "in_channels mismatch between input and weight");

    c10::cuda::CUDAGuard device_guard(input.device());
    const int batch_size = input.size(0);
    const int in_channels = input.size(1);
    const int D_in = input.size(2);
    const int H_in = input.size(3);
    const int W_in = input.size(4);
    const int out_channels = weight.size(0);
    const int k = weight.size(2);  // cubic kernel assumed (k x k x k)
    const int D_out = D_in - k + 1;
    const int H_out = H_in - k + 1;
    const int W_out = W_in - k + 1;
    TORCH_CHECK(D_out >= 0 && H_out >= 0 && W_out >= 0,
                "kernel size larger than input spatial dimensions");

    auto output = torch::empty({batch_size, out_channels, D_out, H_out, W_out},
                               input.options());
    // 64-bit count: the product can overflow int for large tensors.
    const int64_t total = output.numel();
    if (total == 0)
        return output;  // avoid an invalid zero-block launch

    const int threads = 256;
    const int blocks = (int)((total + threads - 1) / threads);
    fused_conv_activation_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        conv_bias.data_ptr<float>(),
        sum_tensor.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        in_channels,
        D_in, H_in, W_in,
        out_channels,
        k
    );
    // Surface launch-configuration errors immediately instead of at the next
    // unrelated synchronizing call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Python bindings. Named arguments are declared so the op can be called with
// keywords from Python (backward-compatible: positional calls still work).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward,
          "Fused 3D conv with bias, LeakyReLU, addition, clamp, and GELU activation",
          py::arg("input"),
          py::arg("weight"),
          py::arg("conv_bias"),
          py::arg("sum_tensor"));
}