1.29. Softplus Activation

Computes the Softplus activation elementwise on the input, $f(x) = \log(1 + e^{x})$, which is a smooth approximation of ReLU.

Progress

Step 4 of 4
Reward
1.15x

Optimization Timeline

Step 4 Details

Reasoning

We can further improve the softplus CUDA kernel with a couple of micro-optimizations. In the element-wise kernel, we use `__ldg()` to load input values through the read-only data cache, which may speed up global-memory reads. For a `const float* __restrict__` pointer the compiler can often emit read-only loads on its own, so the gain is usually small. We also mark the softplus helper `__forceinline__` to force inlining and avoid call overhead. Small `__device__` functions are normally inlined anyway, so this mostly makes the intent explicit. The vectorized kernel already processes four elements per thread via `float4`, so it needs no major change. These changes target memory access and call overhead without altering numerical accuracy, so they are genuine optimizations rather than workarounds.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <math.h>

// Numerically stable Softplus: f(x) = log(1 + exp(x)).
//
// The naive form log1pf(expf(x)) overflows for x > ~88.7: expf(x) becomes
// +inf, so the result is +inf instead of ~x. The equivalent identity
//     f(x) = max(x, 0) + log1p(exp(-|x|))
// keeps the exp argument <= 0, so it can never overflow. log1pf stays accurate
// when exp(-|x|) is tiny. NaN inputs still propagate (fabsf/expf/log1pf of NaN
// is NaN).
//
// __host__ __device__ lets the same math be used from host code (e.g. for a
// CPU reference check). __forceinline__ keeps it inlined into the kernels.
__host__ __device__ __forceinline__ float softplus_func(float x) {
    return fmaxf(x, 0.0f) + log1pf(expf(-fabsf(x)));
}

// Element-wise Softplus kernel (scalar path).
//
// Launch: 1-D grid of 1-D blocks, no shared memory. The body is a grid-stride
// loop, so any grid size is correct. A grid smaller than
// ceil(total / blockDim.x) just makes each thread handle several elements.
// Indices are 64-bit, so tensors larger than INT_MAX elements work. The last
// partial block of a launch near INT_MAX can no longer wrap to a negative
// (yet "in-bounds") index.
//
// input/output: device pointers to `total` floats indexed linearly
// [0, total); no alignment requirement beyond that of float.
__global__ void softplus_kernel(const float* __restrict__ input,
                                float* __restrict__ output,
                                int64_t total) {
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total;
         idx += stride) {
        // __ldg routes the load through the read-only data cache (SM35+).
        const float x = __ldg(&input[idx]);
        output[idx] = softplus_func(x);
    }
}

// Vectorized Softplus kernel: each loop iteration handles one float4 (4 floats).
//
// Launch: 1-D grid of 1-D blocks, no shared memory. A grid-stride loop with
// 64-bit indices makes any grid size correct and avoids 32-bit index wrap.
//
// Preconditions (caller's responsibility):
// - `input` and `output` must be 16-byte aligned (required for float4 access).
// - Each must point to `total_vec` float4 values, i.e. 4 * total_vec floats.
__global__ void softplus_vector_kernel(const float4* __restrict__ input,
                                       float4* __restrict__ output,
                                       int64_t total_vec) {
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total_vec;
         idx += stride) {
        // One 128-bit read-only load, consistent with the scalar kernel's __ldg.
        const float4 in_vec = __ldg(&input[idx]);
        float4 out_vec;
        out_vec.x = softplus_func(in_vec.x);
        out_vec.y = softplus_func(in_vec.y);
        out_vec.z = softplus_func(in_vec.z);
        out_vec.w = softplus_func(in_vec.w);
        output[idx] = out_vec;
    }
}

// Host entry point: returns softplus(input) as a newly allocated tensor.
//
// Requirements: `input` must be a CUDA tensor of dtype float32. The kernels
// index memory linearly, so the input is first made contiguous. The result is
// therefore contiguous.
//
// Dispatch:
// - float4 kernel when the element count is a multiple of 4 and both buffers
//   are 16-byte aligned;
// - scalar kernel otherwise.
// Kernels are enqueued on the current CUDA stream; launch errors are raised.
torch::Tensor forward(torch::Tensor input) {
    // Validate before touching the device: a CPU tensor would otherwise hand
    // a host pointer to the kernel.
    TORCH_CHECK(input.is_cuda(), "softplus forward: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "softplus forward: input must be float32, got ", input.scalar_type());

    // Ensure operations occur on the input's CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // A strided view (e.g. x[:, ::2]) is not a dense linear buffer, so reading
    // numel() floats from data_ptr() would be wrong; make it contiguous.
    const torch::Tensor in = input.contiguous();
    auto output = torch::empty_like(in);

    // numel() is 64-bit; do not truncate it to int.
    const int64_t total = in.numel();
    if (total == 0) {
        // Launching with zero blocks is an invalid configuration.
        return output;
    }

    const int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    const float* in_ptr = in.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    // float4 loads/stores need 16-byte alignment. A contiguous view with a
    // storage offset (e.g. x[1:]) can be misaligned even when numel % 4 == 0.
    const bool aligned =
        reinterpret_cast<std::uintptr_t>(in_ptr) % alignof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % alignof(float4) == 0;

    if (total % 4 == 0 && aligned) {
        const int64_t total_vec = total / 4;
        const int64_t blocks = (total_vec + threads - 1) / threads;
        softplus_vector_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr),
            total_vec);
    } else {
        const int64_t blocks = (total + threads - 1) / threads;
        softplus_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            in_ptr, out_ptr, total);
    }
    // Surface launch-configuration errors here instead of at some later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: exposes the host entry point `forward` under the same name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    const char* const forward_doc =
        "Optimized Softplus activation (CUDA) with vectorization and __ldg";
    mod.def("forward", forward, forward_doc);
}