1.26. GELU Activation Computation
Calculates the Gaussian Error Linear Unit (GELU) activation function on the input data, transforming it elementwise using the tanh-based approximation as described in the GELU paper. The output retains the same shape as the input.
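For reference, the exact GELU is GELU(x) = x · Φ(x), where Φ is the standard normal CDF; the tanh-based approximation used here (from the GELU paper) replaces Φ(x) with 0.5 * (1 + tanh(√(2/π) * (x + 0.044715*x³))).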
Optimization Timeline
Step 1 Details
Reasoning
We need to implement a custom CUDA operator that applies the GELU activation function element-wise to an input tensor. The approximate GELU function is defined as: gelu(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715*x³))).
Our goal is to launch a CUDA kernel in which each thread processes one or more elements via a grid-stride loop. We use c10::cuda::CUDAGuard and c10::cuda::getCurrentCUDAStream() to ensure proper device management and to launch the kernel on the correct stream.
In the Python module, we load the CUDA extension compiled with the flags “-O3” and “--use_fast_math”, then wrap the kernel call in a PyTorch module named CUDAModel whose forward method takes an input tensor and returns the GELU-activated output (a sketch of this wrapper appears after the CUDA source below). These optimizations (grid-stride loops, intrinsic math functions, fast math, and proper device management) should speed up computation compared to a native PyTorch call to F.gelu().
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
#include <math.h>
// The GELU activation function using the approximate formulation.
__device__ inline float gelu_approx(float x) {
    const float c = 0.044715f;
    const float k = 0.7978845608028654f;  // sqrt(2.0f / M_PI)
    float x_cubed = x * x * x;
    float tanh_out = tanhf(k * (x + c * x_cubed));
    return 0.5f * x * (1.0f + tanh_out);
}
// CUDA kernel to apply GELU activation element-wise via a grid-stride loop.
__global__ void gelu_kernel(const float* __restrict__ input,
                            float* __restrict__ output,
                            int64_t numel) {
    // Use 64-bit indexing so tensors larger than INT_MAX elements are handled.
    int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
    int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (; idx < numel; idx += stride) {
        float x = input[idx];
        output[idx] = gelu_approx(x);
    }
}
torch::Tensor gelu(torch::Tensor input) {
    // Ensure the tensor is a float32 CUDA tensor.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    // Launch on the tensor's device and current stream.
    c10::cuda::CUDAGuard device_guard(input.device());
    // Treat the (contiguous) input as a flat 1D array.
    auto input_contig = input.contiguous();
    int64_t numel = input_contig.numel();
    auto output = torch::empty_like(input_contig);
    int threads = 256;
    // The grid-stride loop in the kernel handles any remainder, so a plain
    // ceiling division is sufficient here.
    int blocks = static_cast<int>((numel + threads - 1) / threads);
    gelu_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input_contig.data_ptr<float>(),
        output.data_ptr<float>(),
        numel
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gelu", &gelu, "Apply the GELU activation (approximate) using CUDA");
}
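The Reasoning above also describes the Python side of the operator, which is not shown in the listing. Below is a minimal sketch of how that wrapper might look, assuming the CUDA source above is saved as gelu_kernel.cu and built with torch.utils.cpp_extension.load; the file name, extension name, and test shape are illustrative assumptions, not taken from the original.

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Build the extension with the flags mentioned in the Reasoning above.
# The source file name "gelu_kernel.cu" and module name "gelu_ext" are assumptions.
gelu_ext = load(
    name="gelu_ext",
    sources=["gelu_kernel.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
)

class CUDAModel(nn.Module):
    """Wraps the custom CUDA GELU kernel in a PyTorch module."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return gelu_ext.gelu(x)

if __name__ == "__main__":
    x = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
    y = CUDAModel()(x)
    # The kernel uses the tanh approximation, so the matching PyTorch
    # reference is F.gelu with approximate="tanh".
    ref = torch.nn.functional.gelu(x, approximate="tanh")
    print((y - ref).abs().max().item())

Because both the kernel and the reference use the tanh formulation, the printed maximum difference should be small; note that --use_fast_math trades a little accuracy for speed, so it may exceed plain float32 rounding error slightly.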