2.54. Convolution, Scaling, and Dual Activation

Performs a 2D convolution on the input data, scales the resulting feature maps by an element-wise multiplication with a learnable parameter, applies a LeakyReLU activation to introduce non-linearity, and finally calculates the Gaussian Error Linear Unit (GELU) activation to refine the output.

Progress

Step 3 of 3
Reward
1.47x

Optimization Timeline

Step 3 Details

Reasoning

We can try to further optimize our fused CUDA kernel by replacing the standard tanhf function with __tanhf, on the assumption that it is a faster, hardware-level intrinsic. We also give the GELU approximation constants (0.7978845608f ≈ √(2/π) and 0.044715f) names as `const float` locals (A and B). Because they are compile-time literals, they were never actually recomputed per element, so this mainly improves readability. Together with the vectorized (float4) implementation, these changes aim to reduce computational overhead and use more efficient math operations. Our Python wrapper remains the same while the CUDA code is updated accordingly. --- FIX --- The compilation error was caused by __tanhf: the CUDA toolkit used here does not provide an intrinsic with that name, so the identifier is undefined in device code. Instead, we use the standard device function tanhf, replacing __tanhf with tanhf in both the scalar and vectorized kernels. This minor change compiles while preserving our other optimizations. Our Python code remains unchanged.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <climits>
#include <cstdint>

// Scalar fused kernel: one thread per element of a dense NCHW float buffer.
// For flat element index idx, with channel c = (idx / spatial_size) % C, it writes
//   v   = input[idx] * multiplier[c]                       (per-channel scale)
//   y   = v >= 0 ? v : v * neg_slope                        (LeakyReLU)
//   out = 0.5 * y * (1 + tanh(sqrt(2/pi) * (y + 0.044715 * y^3)))   (tanh-form GELU)
// Launch: 1D grid covering at least `total` threads; threads past `total` do nothing.
// Indexing is 32-bit, so `total` must fit in an int.
__global__ void fused_multiply_leakyrelu_gelu_kernel(const float* __restrict__ input,
                                                     const float* __restrict__ multiplier,
                                                     float* __restrict__ output,
                                                     int total,
                                                     int spatial_size, // H*W
                                                     int C,
                                                     float neg_slope) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= total) {
        return;  // tail threads of the last block
    }

    // tanh-GELU constants: sqrt(2/pi) and the cubic-term coefficient.
    const float kSqrt2OverPi = 0.7978845608f;
    const float kCubicCoeff = 0.044715f;

    // In NCHW order each run of spatial_size consecutive elements shares a channel.
    const int channel = (tid / spatial_size) % C;
    const float scaled = __ldg(&input[tid]) * __ldg(&multiplier[channel]);

    const float y = (scaled >= 0.f) ? scaled : scaled * neg_slope;
    const float y3 = y * y * y;
    const float inner = kSqrt2OverPi * (y + kCubicCoeff * y3);
    output[tid] = 0.5f * y * (1.f + tanhf(inner));
}

// Vectorized variant of the fused kernel: each thread processes one float4, i.e. four
// consecutive floats of a dense NCHW buffer, applying per-channel scale -> LeakyReLU ->
// tanh-form GELU to each lane.
// Preconditions:
//   - H*W is divisible by 4, so a float4 never straddles two channels;
//     vec_per_channel == (H*W)/4 and num_vec == (total elements)/4.
//   - input and output are 16-byte aligned, as required for float4 accesses.
// Launch: 1D grid covering at least `num_vec` threads; threads past `num_vec` do nothing.
__global__ void fused_multiply_leakyrelu_gelu_kernel_vec(const float4* __restrict__ input,
                                                         const float* __restrict__ multiplier,
                                                         float4* __restrict__ output,
                                                         int num_vec, 
                                                         int vec_per_channel, // equals (H*W)/4
                                                         int C,
                                                         float neg_slope) {
    const int vid = blockIdx.x * blockDim.x + threadIdx.x;
    if (vid >= num_vec) {
        return;  // tail threads of the last block
    }

    // tanh-GELU constants: sqrt(2/pi) and the cubic-term coefficient.
    const float kSqrt2OverPi = 0.7978845608f;
    const float kCubicCoeff = 0.044715f;

    // All four lanes belong to the same channel, so one multiplier load suffices.
    const float scale = __ldg(&multiplier[(vid / vec_per_channel) % C]);

    // Scale, LeakyReLU, then tanh-approximated GELU for a single lane.
    auto activate = [=](float raw) -> float {
        const float s = raw * scale;
        const float y = (s >= 0.f) ? s : s * neg_slope;
        const float y3 = y * y * y;
        return 0.5f * y * (1.f + tanhf(kSqrt2OverPi * (y + kCubicCoeff * y3)));
    };

    const float4 packed = input[vid];
    output[vid] = make_float4(activate(packed.x), activate(packed.y),
                              activate(packed.z), activate(packed.w));
}

// Fused operation: out = GELU_tanh(LeakyReLU(input * multiplier[c], neg_slope)), where c is
// the channel index of each element of a 4D (N, C, H, W) tensor.
//
// Parameters:
//   input      - 4D float32 CUDA tensor. Non-contiguous inputs (e.g. channels_last) are
//                first copied to dense NCHW, because both kernels derive the channel from
//                the flat NCHW offset.
//   multiplier - float32 tensor on the same device with exactly C elements (one per
//                channel, e.g. shape (C, 1, 1)).
//   neg_slope  - LeakyReLU negative slope.
// Returns a new contiguous NCHW float32 tensor. It is returned without any launch when
// input has no elements.
// Raises (TORCH_CHECK) on wrong device/dtype/rank, a mismatched multiplier size, more than
// INT_MAX elements (kernels use 32-bit indexing), or a failed kernel launch.
//
// Dispatch: the float4 kernel runs when H*W is a multiple of 4 and both buffers are
// 16-byte aligned; otherwise the scalar kernel runs. Both run on the current CUDA stream.
torch::Tensor fused_op(torch::Tensor input, torch::Tensor multiplier, float neg_slope) {
    // Validate arguments up front so misuse fails with a clear message.
    TORCH_CHECK(input.is_cuda(), "fused_op: input must be a CUDA tensor");
    TORCH_CHECK(multiplier.is_cuda(), "fused_op: multiplier must be a CUDA tensor");
    TORCH_CHECK(multiplier.device() == input.device(),
                "fused_op: input and multiplier must be on the same device");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "fused_op: input must be float32");
    TORCH_CHECK(multiplier.scalar_type() == torch::kFloat32, "fused_op: multiplier must be float32");
    TORCH_CHECK(input.dim() == 4, "Input tensor must be 4D (N, C, H, W)");

    c10::cuda::CUDAGuard device_guard(input.device());

    // Kernels index as dense NCHW / dense per-channel vector; .contiguous() is a no-op
    // when the tensor already has that layout.
    torch::Tensor x = input.contiguous();
    torch::Tensor mult = multiplier.contiguous();

    const int64_t channels = x.size(1);
    // The kernels read multiplier[c] for every c in [0, C).
    TORCH_CHECK(mult.numel() == channels,
                "fused_op: multiplier must have exactly C = ", channels,
                " elements, got ", mult.numel());

    // Fresh allocation, so the output is always contiguous NCHW.
    auto output = torch::empty(x.sizes(), x.options());

    const int64_t numel = x.numel();
    if (numel == 0) {
        return output;  // nothing to do; a 0-block launch is an invalid configuration
    }
    TORCH_CHECK(numel <= static_cast<int64_t>(INT_MAX),
                "fused_op: tensors with more than INT_MAX elements are not supported "
                "(kernels use 32-bit indexing)");

    // Every dimension is <= numel <= INT_MAX here, so these narrowings are safe.
    const int N = static_cast<int>(x.size(0));
    const int C = static_cast<int>(channels);
    const int H = static_cast<int>(x.size(2));
    const int W = static_cast<int>(x.size(3));
    const int spatial_size = H * W;
    const int total = static_cast<int>(numel);

    const float* in_ptr = x.data_ptr<float>();
    const float* mult_ptr = mult.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    const int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // float4 loads/stores need 16-byte alignment; a view with a storage offset may not have it.
    const bool aligned16 =
        reinterpret_cast<std::uintptr_t>(in_ptr) % alignof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % alignof(float4) == 0;

    if ((spatial_size % 4) == 0 && aligned16) {
        // H*W % 4 == 0 guarantees no float4 spans two channels.
        const int vec_per_channel = spatial_size / 4;
        const int num_vec = N * C * vec_per_channel;  // == total / 4
        const int blocks = (num_vec + threads - 1) / threads;
        fused_multiply_leakyrelu_gelu_kernel_vec<<<blocks, threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            mult_ptr,
            reinterpret_cast<float4*>(out_ptr),
            num_vec,
            vec_per_channel,
            C,
            neg_slope
        );
    } else {
        // Ceil-divide in 64-bit so total close to INT_MAX cannot overflow.
        const int blocks = static_cast<int>((numel + threads - 1) / threads);
        fused_multiply_leakyrelu_gelu_kernel<<<blocks, threads, 0, stream>>>(
            in_ptr,
            mult_ptr,
            out_ptr,
            total,
            spatial_size,
            C,
            neg_slope
        );
    }
    // Kernel launches do not return errors directly; surface launch failures here.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python binding: exposes fused_op(input, multiplier, neg_slope). Arguments may be passed
// positionally (as before) or by keyword.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_op", &fused_op,
          "Fused per-channel multiply, LeakyReLU, and tanh-approximated GELU (with vectorized and scalar kernels)",
          py::arg("input"), py::arg("multiplier"), py::arg("neg_slope"));
}