3.44. Transformer Block with GELU and Causal Self-Attention

Implements a transformer block in which the input is first layer-normalized and then processed by a masked multi-head self-attention mechanism: query, key, and value projections are combined via a dot product scaled by the square root of the head dimension, and a causal mask restricts each token to attend only to preceding positions. A second normalized branch feeds a two-layer feed-forward network whose hidden activations pass through a tanh-based GELU approximation, with dropout applied to the output.
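For reference, here is a minimal PyTorch sketch of such a block. This is an illustration in the style of minGPT, not the benchmark's own source: the names NewGELU, CausalSelfAttention, and Block, and all hyperparameters, are assumptions.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NewGELU(nn.Module):
    """Tanh-based GELU approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."""
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi)
                                           * (x + 0.044715 * x.pow(3))))

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, max_len, dropout=0.1):
        super().__init__()
        self.n_head, self.head_dim = n_head, n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd)   # fused Q, K, V projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.drop = nn.Dropout(dropout)
        # Causal mask: token i may attend only to tokens j <= i.
        mask = torch.tril(torch.ones(max_len, max_len)).view(1, 1, max_len, max_len)
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Dot product scaled by sqrt(head_dim), then the causal mask.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        y = F.softmax(att, dim=-1) @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.drop(self.proj(y))

class Block(nn.Module):
    def __init__(self, n_embd, n_head, max_len, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, max_len, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), NewGELU(),   # GELU between the two layers
            nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout))

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))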

Progress: Step 2 of 2
Reward: 1.14x

Step 2 Details

Reasoning

The goal is to further optimize the fused GELU activation by vectorizing memory loads and computation. The previous implementation processed one element per loop iteration in a scalar loop. Throughput improves by processing four elements at a time with float4 loads and stores whenever the total element count is divisible by 4; this cuts the number of global memory transactions by a factor of four and issues wider 128-bit accesses.

The host function checks whether the element count is divisible by 4: if so, it launches the vectorized kernel, otherwise it falls back to the scalar kernel. The vectorized kernel loads each float4 (fields x, y, z, w) and applies the GELU activation to each component:

    y = 0.5 * x * (1 + tanh(s * (x + 0.044715 * x^3))),  with s = sqrt(2/pi)

The Python module builds the CUDA extension with -O3 and --use_fast_math and exposes the CUDAModel class. The fused operator is beneficial because it replaces a chain of element-wise PyTorch ops with a single kernel, reducing launch overhead and intermediate memory traffic. The improvement over the scalar version is genuine: the arithmetic per element is unchanged, but the number of memory transactions and loop iterations drops by four times.
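A loader along these lines would build and wrap the extension. This is a sketch: the compiler flags come from the text above, but the filename "newgelu_kernel.cu" and the internals of CUDAModel are assumptions.

import torch
from torch.utils.cpp_extension import load

# Compile the CUDA source listed below into a loadable extension.
ext = load(
    name="newgelu_ext",
    sources=["newgelu_kernel.cu"],        # assumed filename for the .cu listing
    extra_cflags=["-O3"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=True,
)

class CUDAModel(torch.nn.Module):
    """Wrapper class named in the text; only the GELU call is shown here."""
    def forward(self, x):
        return ext.newgelu_forward(x)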

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>   // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>                 // std::uintptr_t for the alignment check

#define THREADS_PER_BLOCK 256

// Scalar kernel: processes one element at a time.
__global__ void newgelu_kernel(const float* __restrict__ input,
                               float* __restrict__ output,
                               int total_elements) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const float s = 0.7978845608028654f; // sqrt(2/pi)
    for (int i = idx; i < total_elements; i += blockDim.x * gridDim.x) {
        float x = __ldg(&input[i]);
        float x2 = x * x;
        float inner = x + 0.044715f * (x2 * x);
        float tanh_val = tanhf(s * inner);
        output[i] = 0.5f * x * (1.0f + tanh_val);
    }
}

// Vectorized kernel: processes four elements at a time using float4.
// Requires 16-byte-aligned pointers (checked on the host side).
__global__ void newgelu_vec_kernel(const float4* __restrict__ input_vec,
                                   float4* __restrict__ output_vec,
                                   int total_vec_elements) {
    const float s = 0.7978845608028654f; // sqrt(2/pi)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (int i = idx; i < total_vec_elements; i += blockDim.x * gridDim.x) {
        float4 in_val = input_vec[i];
        float a0 = in_val.x, a1 = in_val.y, a2 = in_val.z, a3 = in_val.w;

        float x2_0 = a0 * a0;
        float inner0 = a0 + 0.044715f * (x2_0 * a0);
        float y0 = 0.5f * a0 * (1.0f + tanhf(s * inner0));

        float x2_1 = a1 * a1;
        float inner1 = a1 + 0.044715f * (x2_1 * a1);
        float y1 = 0.5f * a1 * (1.0f + tanhf(s * inner1));

        float x2_2 = a2 * a2;
        float inner2 = a2 + 0.044715f * (x2_2 * a2);
        float y2 = 0.5f * a2 * (1.0f + tanhf(s * inner2));

        float x2_3 = a3 * a3;
        float inner3 = a3 + 0.044715f * (x2_3 * a3);
        float y3 = 0.5f * a3 * (1.0f + tanhf(s * inner3));

        float4 out_val;
        out_val.x = y0;
        out_val.y = y1;
        out_val.z = y2;
        out_val.w = y3;
        output_vec[i] = out_val;
    }
}

torch::Tensor newgelu_forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    c10::cuda::CUDAGuard device_guard(input.device());

    input = input.contiguous();  // reinterpreting as float4 requires dense storage
    auto output = torch::empty_like(input);
    int total = static_cast<int>(input.numel());
    int threads = THREADS_PER_BLOCK;
    int blocks;

    const float* in_ptr = input.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    // Use the vectorized kernel when the element count is divisible by 4 and
    // the input buffer is 16-byte aligned (required for float4 loads/stores).
    bool can_vectorize = (total % 4 == 0) &&
                         (reinterpret_cast<std::uintptr_t>(in_ptr) % 16 == 0);
    if (can_vectorize) {
        int total_vec = total / 4;
        blocks = (total_vec + threads - 1) / threads;
        newgelu_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr), total_vec);
    } else {
        blocks = (total + threads - 1) / threads;
        newgelu_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            in_ptr, out_ptr, total);
    }

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("newgelu_forward", &newgelu_forward, "Fused NewGELU activation (CUDA, vectorized when possible)");
}