2.90. 3D Convolution with Activation and Clamping

Performs a 3D convolution on input data using a specified kernel, applies a LeakyReLU activation, adds a fixed tensor to the convolution result, clamps the result between -1 and 1, and finally computes the GELU activation on the data.

Progress

Step 3 of 3
Reward
1.08x

Optimization Timeline

Step 3 Details

Reasoning

We can improve the kernel performance for the typical case of kernel size 3 and in_channels 3 by fully unrolling the convolution loops over the spatial dimensions and channels. Our previous version already had an unrolled branch for k==3; now we add an inner branch when in_channels==3 so that we unroll the channel loop as well. This minimizes loop control overhead and maximizes instruction-level parallelism. The subsequent operations (bias addition, LeakyReLU, addition of the broadcast sum_tensor, clamping, and GELU activation) remain the same. Our Python module remains unchanged except that it calls our updated fused CUDA operator. This change should yield a further speedup over our previous version without compromising correctness. Below are our updated Python and CUDA modules.

#include <torch/extension.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <math.h>

// Fused kernel: computes a "valid" 3D convolution (stride=1, no padding) and
// then applies, in order:
//   1. conv_bias[oc] added to the raw convolution sum
//   2. LeakyReLU with negative_slope = 0.2
//   3. sum_tensor[oc] added (broadcast per output channel)
//   4. clamp to [-1.0, 1.0]
//   5. tanh-approximation GELU
//
// Launch: 1D grid of 1D blocks, one thread per output element
// (batch_size * out_channels * D_out * H_out * W_out). All pointers are
// assumed to reference contiguous float32 data; weight layout is
// [out_channels, in_channels, k, k, k] with a cubic kernel.
__global__ void fused_conv_activation_kernel(const float* __restrict__ input,
                                               const float* __restrict__ weight,
                                               const float* __restrict__ conv_bias,
                                               const float* __restrict__ sum_tensor,
                                               float* __restrict__ output,
                                               int batch_size,
                                               int in_channels,
                                               int D_in, int H_in, int W_in,
                                               int out_channels,
                                               int k)
{
    // Output dimensions for a valid convolution (stride=1, no padding).
    const int D_out = D_in - k + 1;
    const int H_out = H_in - k + 1;
    const int W_out = W_in - k + 1;

    const int total = batch_size * out_channels * D_out * H_out * W_out;
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total)
        return;

    // Decode the flat index into (b, oc, d, h, w).
    int temp = idx;
    const int w = temp % W_out;  temp /= W_out;
    const int h = temp % H_out;  temp /= H_out;
    const int d = temp % D_out;  temp /= D_out;
    const int oc = temp % out_channels;
    const int b  = temp / out_channels;

    const int HW = H_in * W_in;          // hoisted loop invariant
    const int chan_stride = D_in * HW;   // elements per input channel

    float sum_val = 0.0f;

    if (k == 3 && in_channels == 3) {
        // Fast path for the common 3x3x3 kernel with 3 input channels.
        // Fixed trip counts let the compiler fully unroll every loop,
        // matching the previous hand-unrolled version (same accumulation
        // order: channel-major, then kd, kh, kw).
        const float* in_b = input + b * (3 * chan_stride);
        const float* w_oc = weight + oc * 81;   // 81 = 3 channels * 27 taps
        #pragma unroll
        for (int c = 0; c < 3; ++c) {
            const float* in_c = in_b + c * chan_stride;
            const float* w_c  = w_oc + c * 27;
            #pragma unroll
            for (int kd = 0; kd < 3; ++kd)
                #pragma unroll
                for (int kh = 0; kh < 3; ++kh)
                    #pragma unroll
                    for (int kw = 0; kw < 3; ++kw)
                        sum_val += in_c[(d + kd) * HW + (h + kh) * W_in + (w + kw)]
                                 * w_c[kd * 9 + kh * 3 + kw];
        }
    } else {
        // General path for any kernel size / channel count.
        // (The previous version contained two byte-identical branches for
        // "k == 3" and the general case; they are merged here.)
        const int kk = k * k;
        for (int c = 0; c < in_channels; ++c) {
            const float* in_c = input + (b * in_channels + c) * chan_stride;
            const float* w_c  = weight + (oc * in_channels + c) * (kk * k);
            for (int kd = 0; kd < k; ++kd)
                for (int kh = 0; kh < k; ++kh)
                    for (int kw = 0; kw < k; ++kw)
                        sum_val += in_c[(d + kd) * HW + (h + kh) * W_in + (w + kw)]
                                 * w_c[kd * kk + kh * k + kw];
        }
    }

    // 1. Convolution bias for this output channel.
    sum_val += conv_bias[oc];

    // 2. LeakyReLU with negative_slope = 0.2.
    if (sum_val < 0.0f)
        sum_val *= 0.2f;

    // 3. Broadcast per-output-channel addend.
    sum_val += sum_tensor[oc];

    // 4. Clamp to [-1, 1]. (Explicit branches kept so a NaN passes through
    //    unchanged, matching the original behavior.)
    if (sum_val > 1.0f)
        sum_val = 1.0f;
    else if (sum_val < -1.0f)
        sum_val = -1.0f;

    // 5. GELU, tanh approximation:
    //    0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    const float x_cube = sum_val * sum_val * sum_val;
    const float tanh_out = tanhf(0.7978845608028654f * (sum_val + 0.044715f * x_cube));
    output[idx] = 0.5f * sum_val * (1.0f + tanh_out);
}

// Host wrapper: validates inputs, allocates the output tensor, and launches
// the fused convolution/activation kernel on the current CUDA stream.
//
// input:      [batch, in_channels, D, H, W] float32 CUDA tensor
// weight:     [out_channels, in_channels, k, k, k] float32 CUDA tensor
//             (cubic kernel: the three spatial dims must match)
// conv_bias:  per-output-channel bias, indexed as [oc] by the kernel
// sum_tensor: per-output-channel addend, indexed as [oc] by the kernel
//
// Returns: [batch, out_channels, D-k+1, H-k+1, W-k+1] float32 CUDA tensor.
// Throws (TORCH_CHECK) on device/dtype/shape violations.
torch::Tensor forward(torch::Tensor input,
                      torch::Tensor weight,
                      torch::Tensor conv_bias,
                      torch::Tensor sum_tensor) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == at::kFloat, "input must be float32");
    TORCH_CHECK(input.dim() == 5, "input must be 5-D [B, C, D, H, W]");
    TORCH_CHECK(weight.dim() == 5, "weight must be 5-D [OC, C, k, k, k]");
    TORCH_CHECK(weight.size(2) == weight.size(3) && weight.size(2) == weight.size(4),
                "only cubic kernels are supported");
    TORCH_CHECK(weight.size(1) == input.size(1),
                "weight in_channels must match input channels");

    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes raw pointers assuming contiguous layout.
    auto input_c  = input.contiguous();
    auto weight_c = weight.contiguous();
    auto bias_c   = conv_bias.contiguous();
    auto sum_c    = sum_tensor.contiguous();

    const int batch_size  = input_c.size(0);
    const int in_channels = input_c.size(1);
    const int D_in = input_c.size(2);
    const int H_in = input_c.size(3);
    const int W_in = input_c.size(4);

    const int out_channels = weight_c.size(0);
    const int k = weight_c.size(2); // cubic kernel (checked above)

    const int D_out = D_in - k + 1;
    const int H_out = H_in - k + 1;
    const int W_out = W_in - k + 1;
    TORCH_CHECK(D_out > 0 && H_out > 0 && W_out > 0,
                "kernel size larger than input for valid convolution");

    auto output = torch::empty({batch_size, out_channels, D_out, H_out, W_out},
                               input_c.options());

    const int total = batch_size * out_channels * D_out * H_out * W_out;
    if (total == 0)
        return output;  // avoid an invalid zero-block launch

    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;  // ceil-div

    fused_conv_activation_kernel<<<blocks, threads, 0,
                                   c10::cuda::getCurrentCUDAStream()>>>(
        input_c.data_ptr<float>(),
        weight_c.data_ptr<float>(),
        bias_c.data_ptr<float>(),
        sum_c.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        in_channels,
        D_in, H_in, W_in,
        out_channels,
        k
    );
    // Surface launch-configuration errors immediately instead of at the
    // next (possibly distant) synchronizing call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python binding: exposes the fused operator as `forward(input, weight,
// conv_bias, sum_tensor)` on the compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused 3D conv with bias, LeakyReLU, addition, clamp, and GELU activation");
}