2.48. 3D Convolution with Scaling and Activation

Computes a 3D convolution over the input, scales the convolution output element-wise by a learned per-channel parameter, applies a hyperbolic tangent activation, multiplies the result by a second learned per-channel parameter, and finally applies a sigmoid to produce the transformed output.
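
For reference, the operator sequence can be written as an eager-mode PyTorch module. The sketch below is illustrative rather than the benchmark's definition: the module name and the (out_channels, 1, 1, 1) parameter shape are assumptions, chosen so that the two learned parameters broadcast per channel over the spatial dimensions, matching the per-channel indexing used by the CUDA kernels further down.

import torch
import torch.nn as nn

# Eager-mode sketch of the described operator sequence (assumed shapes).
class ConvScaleTanhScaleSigmoid(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        # Assumption: per-output-channel parameters that broadcast over D, H, W.
        self.scaling = nn.Parameter(torch.randn(out_channels, 1, 1, 1))
        self.bias = nn.Parameter(torch.randn(out_channels, 1, 1, 1))

    def forward(self, x):
        x = self.conv(x)                  # 3D convolution
        x = torch.tanh(x * self.scaling)  # element-wise scale, then tanh
        x = x * self.bias                 # multiply by the second parameter
        return torch.sigmoid(x)           # final sigmoid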

Progress: Step 6 of 6
Reward: 1.66x

Optimization Timeline: Step 6 Details

Reasoning

We reviewed our fused post-processing kernel that scales the convolution output, applies a tanh activation, multiplies by a bias, and then applies a sigmoid activation. In this revision we refactored the inner computation into an inline device function (fused_activation) so the code is more compact and easier for the compiler to inline and optimize. We provide both an element-wise implementation and a vectorized one (using float4) for cases where the spatial dimensions allow vectorization. Careful ordering of operations, use of __ldg to read the per-channel constants through the read-only cache, and inlining of the core activation math reduce instruction count and memory latency, increasing throughput while preserving numerical correctness.
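
As a numerical reference for the kernel below, the fused post-processing is equivalent to the following eager-mode expression. This is a sketch for validation only; the reshape to (1, C, 1, 1, 1) is an assumption about how flat per-channel parameters are broadcast over an NCDHW tensor.

import torch

def fused_reference(y, scaling, bias):
    # y: conv output of shape (N, C, D, H, W); scaling, bias: C per-channel values.
    C = y.size(1)
    s = scaling.reshape(1, C, 1, 1, 1)
    b = bias.reshape(1, C, 1, 1, 1)
    # Same math as fused_activation: sigmoid(tanh(y * s) * b).
    return torch.sigmoid(torch.tanh(y * s) * b)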

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>

// Inline device function to perform the fused activation operation.
// Computes: sigmoid(tanh(x * scaling) * bias)
__device__ inline float fused_activation(float x, float scaling, float bias) {
    float temp = x * scaling;
    temp = tanhf(temp);
    temp *= bias;
    return 1.0f / (1.0f + __expf(-temp));
}

// Element-wise kernel: each thread processes one element.
__global__ void fused_ops_kernel(const float* __restrict__ input,
                                 float* __restrict__ output,
                                 const float* __restrict__ scaling,
                                 const float* __restrict__ bias,
                                 int N, int C, int D, int H, int W) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * C * D * H * W;
    if (index < total) {
        int spatial = D * H * W;
        int rem = index % (C * spatial);
        int c = rem / spatial;
        float x = input[index];
        // Use __ldg for reading constant parameters from read-only cache.
        float s = __ldg(&scaling[c]);
        float b = __ldg(&bias[c]);
        output[index] = fused_activation(x, s, b);
    }
}

// Vectorized kernel using float4: each thread processes 4 contiguous elements.
// Assumes the spatial dimension (D*H*W) is divisible by 4.
__global__ void fused_ops_vec_kernel(const float4* __restrict__ input,
                                     float4* __restrict__ output,
                                     const float* __restrict__ scaling,
                                     const float* __restrict__ bias,
                                     int N, int C, int D, int H, int W,
                                     int spatial) {
    int vec_per_channel = spatial / 4;
    int total_vec = N * C * vec_per_channel;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < total_vec) {
        int channel_index = (index / vec_per_channel) % C;
        float a = __ldg(&scaling[channel_index]);
        float b = __ldg(&bias[channel_index]);
        
        float4 in_vec = input[index];
        float4 out_vec;
        out_vec.x = fused_activation(in_vec.x, a, b);
        out_vec.y = fused_activation(in_vec.y, a, b);
        out_vec.z = fused_activation(in_vec.z, a, b);
        out_vec.w = fused_activation(in_vec.w, a, b);
        output[index] = out_vec;
    }
}

torch::Tensor forward(torch::Tensor input, torch::Tensor scaling, torch::Tensor bias) {
    // The kernels assume contiguous float32 CUDA tensors.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    // Enforce using the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());
    
    // Extract dimensions: input shape is [N, C, D, H, W].
    int N = input.size(0);
    int C = input.size(1);
    int D = input.size(2);
    int H = input.size(3);
    int W = input.size(4);
    
    auto output = torch::empty({N, C, D, H, W}, input.options());
    int total = N * C * D * H * W;
    int threads = 256;
    int spatial = D * H * W;
    
    // Dispatch vectorized kernel if spatial size is divisible by 4.
    if ((spatial % 4) == 0) {
        int vec_per_channel = spatial / 4;
        int total_vec = N * C * vec_per_channel;
        int blocks = (total_vec + threads - 1) / threads;
        const float4* input_vec = reinterpret_cast<const float4*>(input.data_ptr<float>());
        float4* output_vec = reinterpret_cast<float4*>(output.data_ptr<float>());
        fused_ops_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input_vec, output_vec, scaling.data_ptr<float>(), bias.data_ptr<float>(),
            N, C, D, H, W, spatial);
    } else {
        int blocks = (total + threads - 1) / threads;
        fused_ops_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(), output.data_ptr<float>(), 
            scaling.data_ptr<float>(), bias.data_ptr<float>(),
            N, C, D, H, W);
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused conv post-ops: scale, tanh, multiply bias, and sigmoid (vectorized if possible)");
}
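
A possible way to build and sanity-check the extension from Python is sketched below. The source file name "fused_ops.cu" is an assumption, and the comparison uses a loose tolerance because the kernel's sigmoid is built on the fast __expf intrinsic rather than the exact exponential.

import torch
from torch.utils.cpp_extension import load

# Hypothetical build step; "fused_ops.cu" stands in for the CUDA source above.
fused = load(name="fused_ops", sources=["fused_ops.cu"])

N, C, D, H, W = 2, 16, 8, 16, 16
y = torch.randn(N, C, D, H, W, device="cuda")   # stands in for a conv3d output
scaling = torch.randn(C, device="cuda")
bias = torch.randn(C, device="cuda")

out = fused.forward(y, scaling, bias)
ref = torch.sigmoid(torch.tanh(y * scaling.view(1, C, 1, 1, 1))
                    * bias.view(1, C, 1, 1, 1))
# Loose tolerance: __expf is an approximate exponential.
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)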