1.26. GELU Activation Computation

Calculates the Gaussian Error Linear Unit (GELU) activation function on the input data, transforming it elementwise using the tanh-based approximation as described in the GELU paper. The output retains the same shape as the input.
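For reference, the tanh-based approximation from the GELU paper is:

\mathrm{GELU}(x) \approx \tfrac{1}{2}\, x \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\, x^{3}\bigr)\right)\right)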

Progress: Step 1 of 1
Reward: 1.15x

Optimization Timeline

Step 1 Details

Reasoning

We need to implement a custom CUDA operator that applies the GELU activation function element-wise to an input tensor. The approximate GELU function is defined as gelu(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³))). Our goal is to launch a CUDA kernel in which each thread processes one or more elements via a grid-stride loop. We use c10::cuda::CUDAGuard and c10::cuda::getCurrentCUDAStream() to ensure proper device management and to launch the kernel on the correct stream. In the Python module, we load the CUDA extension compiled with the flags "-O3" and "--use_fast_math", then wrap the kernel call in a PyTorch module named CUDAModel whose forward method takes an input tensor and returns the GELU-activated output. Taken together, these optimizations (grid-stride loops, intrinsic math functions, fast math, and proper device management) should speed up the computation compared to a native PyTorch call to F.gelu().

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
#include <math.h>

// The GELU activation function using the approximate formulation.
__device__ inline float gelu_approx(float x) {
    const float c = 0.044715f;
    const float k = 0.7978845608028654f; // sqrt(2.0f / M_PI)
    float x_cubed = x * x * x;
    float tanh_out = tanhf(k * (x + c * x_cubed));
    return 0.5f * x * (1.0f + tanh_out);
}

// CUDA kernel to apply GELU activation element-wise.
__global__ void gelu_kernel(const float* __restrict__ input,
                            float* __restrict__ output,
                            int numel) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (; idx < numel; idx += stride) {
        float x = input[idx];
        output[idx] = gelu_approx(x);
    }
}

torch::Tensor gelu(torch::Tensor input) {
    // Validate the input: the kernel assumes a float32 CUDA tensor.
    TORCH_CHECK(input.is_cuda(), "gelu: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "gelu: input must be float32");

    // Operate on the tensor's device and launch on the current stream.
    c10::cuda::CUDAGuard device_guard(input.device());

    // Treat the tensor as a flat 1D array; contiguity makes flat indexing valid.
    input = input.contiguous();
    auto output = torch::empty_like(input);

    const int numel = static_cast<int>(input.numel());
    if (numel == 0) {
        return output;  // avoid launching a kernel with zero blocks
    }

    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;

    gelu_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        numel
    );

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gelu", &gelu, "Apply the GELU activation (approximate) using CUDA");
}
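
The reasoning above mentions a Python module that compiles this extension and wraps it in a CUDAModel class, but that module is not shown in the step output. Below is a minimal sketch of such a wrapper, assuming the CUDA/C++ source above is saved as gelu_cuda.cu; the file name and extension name are illustrative assumptions, not the original code.

# Hypothetical wrapper module (assumes the source above lives in gelu_cuda.cu).
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Compile the extension with the optimization flags mentioned in the reasoning.
gelu_ext = load(
    name="gelu_cuda_ext",                        # illustrative extension name
    sources=["gelu_cuda.cu"],                    # assumed location of the source above
    extra_cflags=["-O3"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=False,
)

class CUDAModel(nn.Module):
    """Applies the approximate (tanh) GELU via the custom CUDA kernel."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return gelu_ext.gelu(x)

# Quick sanity check against PyTorch's built-in tanh-approximation GELU.
if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    y = CUDAModel()(x)
    ref = torch.nn.functional.gelu(x, approximate="tanh")
    print((y - ref).abs().max().item())  # expected to be on the order of float32 rounding error

Using load() rather than load_inline() keeps the PYBIND11_MODULE block in the .cu file as the single source of bindings; the fast-math flag only affects the tanhf intrinsic's precision, which the sanity check above makes visible.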