1.88. GELU Tanh Approximation
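Calculates the Gaussian Error Linear Unit (GELU) using a tanh-based approximation, transforming the input data via a combination of linear, cubic, and tanh operations: $\mathrm{GELU}(x) \approx 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,(x + 0.044715\,x^3)\right)\right)$.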
Progress
Optimization Timeline
Step 4 Details
Reasoning
To further optimize the GELU kernel, we consolidate the shared computation into an inline device function, which removes code duplication. The helper computes GELU for a single scalar element and uses a fused multiply–add (`fmaf`) to evaluate the inner polynomial $\sqrt{2/\pi}\,x + (\sqrt{2/\pi}\cdot 0.044715)\,x^3$. We then rewrite both the scalar and vectorized kernels to call this helper. This makes the code cleaner and may let the compiler optimize the math further, potentially reducing register pressure and improving throughput. Because the helper is inlined, the generated code may be largely unchanged, so any speedup should be confirmed by profiling. The Python module remains nearly the same.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
#include <algorithm>
#include <cstdint>
// GELU tanh-approximation constants.
// Declared constexpr (rather than __constant__) so they are compile-time
// scalars that device code may use directly. This lets the compiler fold
// products such as kBeta * kAlpha instead of loading both values from
// constant memory at run time.
constexpr float kAlpha = 0.044715f;            // coefficient of the cubic term
constexpr float kBeta = 0.7978845608028654f;   // sqrt(2 / pi)
// Tanh-approximated GELU of one float:
//   0.5 * x * (1 + tanh(kBeta * x + kBeta * kAlpha * x^3))
// The linear and cubic terms of the tanh argument are combined with a
// single fused multiply-add.
__device__ inline float gelu_func(float x) {
    const float cubed = x * x * x;
    const float cubic_term = kBeta * kAlpha * cubed;
    const float tanh_arg = fmaf(kBeta, x, cubic_term);
    const float half_x = 0.5f * x;
    return half_x * (1.0f + tanhf(tanh_arg));
}
// Scalar GELU kernel: one thread per element over a 1-D grid.
// Threads whose flat index is at or beyond `numel` exit without touching memory.
__global__ void gelu_kernel(const float* __restrict__ input,
                            float* __restrict__ output,
                            int numel) {
    const int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i >= numel) {
        return;
    }
    output[i] = gelu_func(input[i]);
}
// Vectorized GELU kernel: each thread loads one float4 (four consecutive
// floats), applies GELU to every lane, and writes one float4 back.
// `input` and `output` must be valid float4 addresses (16-byte aligned).
// Threads whose flat index is at or beyond `num_vec` exit immediately.
__global__ void gelu_kernel_vector(const float4* __restrict__ input,
                                   float4* __restrict__ output,
                                   int num_vec) {
    const int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i >= num_vec) {
        return;
    }
    const float4 in4 = input[i];
    output[i] = make_float4(gelu_func(in4.x), gelu_func(in4.y),
                            gelu_func(in4.z), gelu_func(in4.w));
}
// Launches GELU over one span of `count` consecutive floats on `stream`.
// When both pointers are 16-byte aligned, the float4 kernel handles the first
// (count / 4) * 4 elements and the scalar kernel handles the remaining
// count % 4 tail. Otherwise the scalar kernel handles the whole span.
// Every launch is followed by a launch-error check.
static void launch_gelu_span(const float* in, float* out, int count,
                             cudaStream_t stream) {
    constexpr int kThreads = 256;
    const bool vec_aligned =
        reinterpret_cast<std::uintptr_t>(in) % alignof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(out) % alignof(float4) == 0;

    int done = 0;
    if (vec_aligned) {
        const int num_vec = count / 4;
        if (num_vec > 0) {
            const int blocks = (num_vec + kThreads - 1) / kThreads;
            gelu_kernel_vector<<<blocks, kThreads, 0, stream>>>(
                reinterpret_cast<const float4*>(in),
                reinterpret_cast<float4*>(out), num_vec);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
            done = num_vec * 4;
        }
    }

    const int rest = count - done;
    if (rest > 0) {
        const int blocks = (rest + kThreads - 1) / kThreads;
        gelu_kernel<<<blocks, kThreads, 0, stream>>>(in + done, out + done, rest);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
}

// Applies the tanh-approximated GELU element-wise to a float32 CUDA tensor.
//
// input: a float32 tensor on a CUDA device, of any shape.
// returns: a new tensor of the same shape. When `input` is non-overlapping and
//          dense, the output has the same strides; otherwise the output is
//          row-major contiguous.
// throws: c10::Error if `input` is not a CUDA tensor or is not float32, or if
//         a kernel launch fails.
torch::Tensor gelu(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "gelu: expected a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat,
                "gelu: expected a float32 tensor, got ", input.scalar_type());
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernels walk memory linearly, so the data must occupy exactly
    // numel() consecutive floats starting at data_ptr(). Sliced or expanded
    // tensors are therefore made contiguous first. Dense tensors (including
    // permuted layouts) are used as-is, and empty_like reproduces their strides.
    torch::Tensor src =
        input.is_non_overlapping_and_dense() ? input : input.contiguous();
    auto output = torch::empty_like(src);

    const int64_t numel = src.numel();
    if (numel == 0) {
        return output;  // a 0-block launch would be an invalid configuration
    }

    const float* in_ptr = src.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // The kernels take `int` counts. Split huge tensors into spans that fit,
    // and keep each span a multiple of 4 so float4 alignment is preserved
    // across spans.
    constexpr int64_t kMaxSpan = int64_t{1} << 30;
    for (int64_t offset = 0; offset < numel; offset += kMaxSpan) {
        const int count = static_cast<int>(std::min(kMaxSpan, numel - offset));
        launch_gelu_span(in_ptr + offset, out_ptr + offset, count, stream);
    }
    return output;
}
// Python bindings: exposes the host wrapper `gelu` as `<extension>.gelu(tensor)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gelu",
          gelu,
          "Custom vectorized CUDA GELU activation");
}