1.28. Hard Sigmoid Activation

Computes the Hard Sigmoid function element-wise on an input tensor, mapping each value into the range [0, 1] with the piecewise-linear function f(x) = clamp((x + 3) / 6, 0, 1): 0 for x ≤ −3, 1 for x ≥ 3, and x/6 + 1/2 in between.

Progress

Step 4 of 4
Reward
1.25x

Optimization Timeline

Step 4 Details

Reasoning

We can further optimize our HardSigmoid kernel with the CUDA intrinsic `__saturatef`, which clamps a float to [0, 1] in hardware and is therefore likely cheaper than a pair of `fminf`/`fmaxf` comparisons. In the scalar kernel we replace the manual clamp with `__saturatef(fmaf(x, 0.166667f, 0.5f))`. Here `fmaf` is a fused multiply-add that computes x · (1/6) + 0.5 with a single rounding; this is algebraically equal to (x + 3)/6. In the vectorized kernel we apply `__saturatef` in the same way to each component of the `float4` result. This micro-optimization reduces the instruction count and should improve throughput while preserving correctness, with two caveats. First, `0.166667f` only approximates 1/6; `1.0f / 6.0f` is the closest float. Second, `__saturatef` maps NaN to +0, whereas PyTorch's `hardsigmoid` propagates NaN, so NaN inputs need explicit handling to match the reference. The Python code remains nearly identical.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>

// Scalar HardSigmoid kernel: output[i] = clamp((input[i] + 3) / 6, 0, 1) for i in [0, numel).
//
// Launch: 1-D grid of 1-D blocks. Any grid size is valid because the kernel uses a grid-stride
// loop, so a single block covers the whole range, just more slowly.
// Preconditions: `input` and `output` point to at least `numel` device floats. They must not
// alias, because both are declared __restrict__.
//
// The value is computed as __saturatef(fmaf(x, 1/6, 0.5)). fmaf evaluates x/6 + 0.5 with a
// single rounding, and __saturatef clamps the result to [+0, 1]. __saturatef flushes NaN to +0,
// so NaN inputs are passed through explicitly to match torch.nn.functional.hardsigmoid, which
// propagates NaN.
__global__ void hardsigmoid_kernel(const float* __restrict__ input,
                                   float* __restrict__ output,
                                   int64_t numel) {
    // Closest float to 1/6. The former literal 0.166667f was off by about 2e-6 relative.
    constexpr float kOneSixth = 1.0f / 6.0f;

    // 64-bit indices: a 32-bit index can wrap negative near INT_MAX and still pass `idx < numel`.
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < numel;
         idx += stride) {
        const float x = input[idx];
        const float y = __saturatef(fmaf(x, kOneSixth, 0.5f));
        output[idx] = isnan(x) ? x : y;
    }
}

// Vectorized HardSigmoid kernel: applies clamp((x + 3) / 6, 0, 1) to every component of
// input[i] for i in [0, num_vec). Each float4 covers 4 consecutive floats.
//
// Launch: 1-D grid of 1-D blocks. Any grid size is valid because of the grid-stride loop.
// Preconditions: `input` and `output` are 16-byte-aligned device pointers to at least
// `num_vec` float4 values. float4 accesses require that alignment. The pointers must not
// alias, because both are declared __restrict__.
//
// __saturatef flushes NaN to +0, so NaN components are passed through explicitly to match
// torch.nn.functional.hardsigmoid, which propagates NaN.
__global__ void hardsigmoid_kernel_vec(const float4* __restrict__ input,
                                       float4* __restrict__ output,
                                       int64_t num_vec) {
    // Per-component HardSigmoid: fused x/6 + 0.5, hardware clamp to [0, 1], NaN preserved.
    auto hard_sigmoid = [](float x) {
        constexpr float kOneSixth = 1.0f / 6.0f;  // closest float to 1/6
        const float y = __saturatef(fmaf(x, kOneSixth, 0.5f));
        return isnan(x) ? x : y;
    };

    // 64-bit indices avoid signed wrap-around for very large inputs.
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < num_vec;
         idx += stride) {
        const float4 v = input[idx];
        float4 res;
        res.x = hard_sigmoid(v.x);
        res.y = hard_sigmoid(v.y);
        res.z = hard_sigmoid(v.z);
        res.w = hard_sigmoid(v.w);
        output[idx] = res;
    }
}

// Host entry point: returns HardSigmoid(input) as a new tensor with the same shape as `input`.
//
// Requirements: `input` must be a float32 CUDA tensor. Inputs that are not contiguous are made
// contiguous first, because the kernels index memory linearly; the result is then contiguous.
// Kernels are enqueued on the current CUDA stream of input's device and are not synchronized
// here.
//
// Dispatch: when both buffers are 16-byte aligned, the first floor(numel / 4) * 4 elements go
// through the float4 kernel and the remaining 0-3 elements through the scalar kernel.
// Otherwise, for example for a view with a storage offset that is not a multiple of 4 floats,
// every element goes through the scalar kernel.
torch::Tensor forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "hardsigmoid forward: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "hardsigmoid forward: input must be float32, got ", input.scalar_type());

    // Enforce the correct CUDA device (also for the copy that contiguous() may launch).
    c10::cuda::CUDAGuard device_guard(input.device());

    const torch::Tensor x = input.contiguous();
    auto output = torch::empty_like(x);
    const int64_t numel = x.numel();
    if (numel == 0) {
        return output;  // A launch with zero blocks would be an invalid configuration.
    }

    const float* in_ptr = x.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    const int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // float4 loads/stores need 16-byte alignment; numel % 4 == 0 alone does not guarantee it.
    const bool aligned =
        reinterpret_cast<std::uintptr_t>(in_ptr) % alignof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % alignof(float4) == 0;
    const int64_t num_vec = aligned ? numel / 4 : 0;

    if (num_vec > 0) {
        const int64_t blocks = (num_vec + threads - 1) / threads;
        hardsigmoid_kernel_vec<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr),
            num_vec);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Elements not covered by the vectorized launch: a 0-3 element tail, or everything if the
    // buffers are misaligned.
    const int64_t tail_start = num_vec * 4;
    const int64_t tail = numel - tail_start;
    if (tail > 0) {
        const int64_t blocks = (tail + threads - 1) / threads;
        hardsigmoid_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            in_ptr + tail_start,
            out_ptr + tail_start,
            tail);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    return output;
}

// Python bindings: exposes the host function `forward` under the Python name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        &forward,
        "Apply HardSigmoid activation (CUDA) with vectorized optimization using __saturatef");
}