2.7. 3D Convolution with Sequential Activations

Performs a 3D convolution on the input data, then sequentially applies ReLU, LeakyReLU, GELU, and Sigmoid activation functions, and finally adds a bias to the resulting features.
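Concretely, the unfused reference pipeline corresponds to something like the following ATen sketch (the function and parameter names here are illustrative, not the actual model definition from this log):

#include <torch/torch.h>

// Illustrative unfused reference: conv3d followed by the four activations and a bias add.
torch::Tensor reference_forward(const torch::Tensor& x,
                                const torch::Tensor& weight,
                                const torch::Tensor& conv_bias,
                                const torch::Tensor& bias) {
    auto y = torch::conv3d(x, weight, conv_bias);
    y = torch::relu(y);
    y = torch::leaky_relu(y, /*negative_slope=*/0.01);  // identity after ReLU, since y >= 0
    y = torch::gelu(y);     // the fused kernel below uses the tanh approximation
    y = torch::sigmoid(y);
    return y + bias;        // bias broadcasts over (N, C, D, H, W)
}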

Progress: Step 6 of 6
Reward: 1.81x

Step 6 Details

Reasoning

Our implementation already fuses the activation chain and bias addition into a single well-optimized kernel: vectorized float4 loads and stores, grid-stride loops, the fast-math intrinsic __expf (alongside tanhf) for the transcendentals, and bias reads via __ldg to exploit the read-only cache. The LeakyReLU stage is also folded away entirely: ReLU output is non-negative, and LeakyReLU is the identity on non-negative inputs, so applying it would be a no-op. One possible further improvement would be warp-level primitives, but that would add complexity for gains that are likely marginal, given that this version already achieves a ~1.81x speedup over the native sequential PyTorch implementation. We therefore retain the current design with these micro-optimizations.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>

// Fast math inlined device functions.
__device__ inline float relu_func(float x) {
    return x > 0.0f ? x : 0.0f;
}

__device__ inline float gelu_func(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654f;
    // Approximate GELU: 0.5 * x * (1 + tanhf(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5f * x * (1.0f + tanhf(sqrt_2_over_pi * (x + 0.044715f * x * x * x)));
}

__device__ inline float sigmoid_func(float x) {
    return 1.0f / (1.0f + __expf(-x));
}

// Scalar kernel: each thread processes one element.
__global__ void fused_activation_bias_kernel(const float* __restrict__ x,
                                               float* __restrict__ out,
                                               const float* __restrict__ bias,
                                               int N, int C, int D, int H, int W) {
    int total = N * C * D * H * W;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int spatial = D * H * W;
    for (; idx < total; idx += blockDim.x * gridDim.x) {
        int c = (idx / spatial) % C;
        float a = x[idx];
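        // LeakyReLU from the original sequence is omitted on purpose: ReLU output
        // is non-negative, and LeakyReLU is the identity on non-negative inputs.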
        float r = relu_func(a);
        float g = gelu_func(r);
        float s = sigmoid_func(g);
        // Load bias using __ldg to use read-only caching.
        out[idx] = s + __ldg(&bias[c]);
    }
}

// Vectorized kernel: processes 4 elements simultaneously using float4.
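// Assumes the data pointers are 16-byte aligned and the element count is a
// multiple of 4 (both are checked by the host-side dispatch below).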
__global__ void fused_activation_bias_vec_kernel(const float4* __restrict__ x,
                                                   float4* __restrict__ out,
                                                   const float* __restrict__ bias,
                                                   int total_vec, int C, int spatial) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (; idx < total_vec; idx += blockDim.x * gridDim.x) {
        int base = idx * 4;
        float4 vec = x[idx];
        float vals[4] = {vec.x, vec.y, vec.z, vec.w};
        float results[4];
        #pragma unroll
        for (int i = 0; i < 4; i++) {
            int global_idx = base + i;
            int c = (global_idx / spatial) % C;
            float a = vals[i];
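            // Same activation chain as the scalar kernel (LeakyReLU folded away).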
            float r = a > 0.0f ? a : 0.0f;
            float g = 0.5f * r * (1.0f + tanhf(0.7978845608028654f * (r + 0.044715f * r * r * r)));
            float s = 1.0f / (1.0f + __expf(-g));
            results[i] = s + __ldg(&bias[c]);
        }
        float4 vec_out = make_float4(results[0], results[1], results[2], results[3]);
        out[idx] = vec_out;
    }
}

torch::Tensor fused_forward(torch::Tensor x, torch::Tensor bias) {
    c10::cuda::CUDAGuard device_guard(x.device());
    // Basic input validation: both kernels assume contiguous float32 CUDA tensors.
    TORCH_CHECK(x.is_cuda() && x.is_contiguous(), "x must be a contiguous CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat, "x must be float32");
    TORCH_CHECK(bias.is_cuda() && bias.is_contiguous(), "bias must be a contiguous CUDA tensor");

    int N = x.size(0);
    int C = x.size(1);
    int D = x.size(2);
    int H = x.size(3);
    int W = x.size(4);
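    // Note: 32-bit indexing assumes N*C*D*H*W fits in an int (< 2^31 elements).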
    int total = N * C * D * H * W;
    int spatial = D * H * W;
    
    auto output = torch::empty_like(x);
    int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    
    // Take the vectorized path only when the element count is a multiple of 4 and
    // the input pointer is 16-byte aligned (a contiguous view may be misaligned).
    bool vec_ok = (total % 4 == 0) &&
                  (reinterpret_cast<std::uintptr_t>(x.data_ptr<float>()) % sizeof(float4) == 0);
    if (vec_ok) {
        int total_vec = total / 4;
        int blocks = (total_vec + threads - 1) / threads;
        const float4* x_vec = reinterpret_cast<const float4*>(x.data_ptr<float>());
        float4* out_vec = reinterpret_cast<float4*>(output.data_ptr<float>());
        fused_activation_bias_vec_kernel<<<blocks, threads, 0, stream>>>(x_vec, out_vec, bias.data_ptr<float>(), total_vec, C, spatial);
    } else {
        int blocks = (total + threads - 1) / threads;
        fused_activation_bias_kernel<<<blocks, threads, 0, stream>>>(x.data_ptr<float>(), output.data_ptr<float>(), bias.data_ptr<float>(), N, C, D, H, W);
    }
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_forward", &fused_forward, "Fused ReLU, GELU, and Sigmoid activations plus bias addition (vectorized float4 path, bias loaded via __ldg)");
}
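
For completeness, a hypothetical C++ call site pairing the ATen convolution with the fused epilogue might look as follows (model_forward and its parameters are illustrative, not part of the extension above):

// Hypothetical usage: the convolution stays in ATen/cuDNN; only the
// activation + bias epilogue goes through the fused kernel.
torch::Tensor model_forward(const torch::Tensor& x,
                            const torch::Tensor& weight,
                            const torch::Tensor& conv_bias,
                            const torch::Tensor& bias) {
    auto y = torch::conv3d(x, weight, conv_bias);  // output is contiguous
    // The kernel expects one float per channel, contiguous in memory
    // (a (C,) or (C, 1, 1, 1) tensor both satisfy this).
    return fused_forward(y, bias.contiguous());
}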