2.99. Matrix Multiplication with GELU and Softmax

Calculates a matrix multiplication on the input, applies the Gaussian Error Linear Unit (GELU) activation using a tanh-based approximation, and then computes the softmax function along the output dimension.

Progress

Step 5 of 5
Reward
4.44x

Optimization Timeline

Step 5 Details

Reasoning

We can further optimize the kernel's softmax reduction steps by replacing the serial reductions with parallel reductions implemented entirely in shared memory. Instead of having thread 0 run a complete serial loop to compute the max and the sum of exponentials, all threads collaborate in a while-loop reduction that works for arbitrary (non-power-of-2) out_features. Each thread first computes its own activated value (via the fused dot product and GELU) and writes it to shared memory. Next, we perform a parallel reduction to compute the maximum value across out_features; after a __syncthreads() barrier every thread reads that maximum from shared memory and then computes its exponential from its own original value (held in a register) and the maximum. Finally, we perform a second parallel reduction to compute the sum of these exponentials and use it to finalize the softmax output. This approach increases parallelism within each block and yields correct results for any out_features that fits in a single thread block (one thread per output feature). The Python module remains unchanged.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Maximum number of output features safety bound.
// The host launcher in this file uses one thread per output feature
// (blockDim.x = out_features), so out_features cannot exceed the number of
// threads a single block may hold; 1024 is that limit on current CUDA
// architectures (NOTE(review): confirm against the targeted compute capability).
#define MAX_OUT_FEATURES 1024

// Tanh-based GELU approximation:
//   gelu(v) = 0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))
// Evaluated entirely in single precision.
__device__ inline float gelu_approx(float x) {
    const float kCubicCoeff = 0.044715f;
    const float kSqrtTwoOverPi = 0.7978845608028654f; // sqrt(2/pi)

    // Cubic correction term, multiplied left to right.
    const float cubic = kCubicCoeff * x * x * x;
    // Argument of the tanh.
    const float inner = kSqrtTwoOverPi * (x + cubic);
    const float gate = 1.0f + tanhf(inner);
    return 0.5f * x * gate;
}

// Dynamic shared-memory scratch buffer used by the block-wide reductions.
// The launch must supply at least out_features * sizeof(float) bytes.
extern __shared__ float sdata[];

// Fused kernel: y[sample, :] = softmax(gelu(x[sample, :] * weight^T + bias)).
//
// Launch layout:
//   * gridDim.x  = number of samples (one block per sample, blockIdx.x = sample).
//   * blockDim.x >= out_features; thread `tid` produces output feature `tid`.
//     Surplus threads (tid >= out_features) do no work but still take part in
//     every barrier, so any block size >= out_features is safe.
//   * dynamic shared memory: out_features floats.
//
// Memory layout, as implied by the indexing below:
//   x      : [samples, in_features], dense row-major
//   weight : [out_features, in_features], dense row-major
//   bias   : [out_features]
//   y      : [samples, out_features], dense row-major
__global__ void fused_linear_gelu_softmax_kernel(const float* __restrict__ x,
                                                   const float* __restrict__ weight,
                                                   const float* __restrict__ bias,
                                                   float* __restrict__ y,
                                                   int in_features,
                                                   int out_features) {
    const int sample = blockIdx.x;   // which sample is processed
    const int tid = threadIdx.x;     // output feature index

    // Nothing to compute, and sdata would be empty. The condition is uniform
    // across the block, so returning here cannot strand a barrier.
    if (out_features <= 0)
        return;

    // Threads beyond out_features must NOT return early: they still have to
    // reach every __syncthreads() below. They are simply masked out of all
    // memory accesses instead.
    const bool active = tid < out_features;

    float val = 0.0f;
    if (active) {
        // 64-bit offsets so large batch * feature products cannot overflow int.
        const float* x_row = x + static_cast<size_t>(sample) * in_features;
        const float* w_row = weight + static_cast<size_t>(tid) * in_features;

        // Dot product of the input row with this feature's weight row.
        float sum = 0.0f;
        if ((in_features & 3) == 0) {
            // in_features is a multiple of 4: manually unrolled by 4.
            const int m = in_features >> 2;
            for (int k = 0; k < m; k++) {
                const int base = k * 4;
                sum += x_row[base]       * w_row[base]
                     + x_row[base + 1]   * w_row[base + 1]
                     + x_row[base + 2]   * w_row[base + 2]
                     + x_row[base + 3]   * w_row[base + 3];
            }
        } else {
            for (int k = 0; k < in_features; k++) {
                sum += x_row[k] * w_row[k];
            }
        }
        sum += bias[tid];

        // Apply GELU activation and publish it for the max reduction.
        val = gelu_approx(sum);
        sdata[tid] = val;
    }
    __syncthreads();

    // Parallel max reduction that works for any (non power-of-2) length:
    // each round folds the upper part [stride, n) onto [0, n - stride).
    // Writers touch indices < n - stride <= stride and readers touch indices
    // >= stride, so a round has no intra-round hazards.
    int n = out_features;
    while (n > 1) {
        const int stride = (n + 1) >> 1; // ceil(n/2)
        if (tid < n - stride) {
            const float other = sdata[tid + stride];
            if (other > sdata[tid]) {
                sdata[tid] = other;
            }
        }
        __syncthreads();
        n = stride;
    }
    // sdata[0] now holds the row maximum.
    const float max_val = sdata[0];
    // Everyone must have read the maximum before sdata is overwritten.
    __syncthreads();

    // Numerically stable exponential, computed from the register copy of val.
    float exp_val = 0.0f;
    if (active) {
        exp_val = expf(val - max_val);
        sdata[tid] = exp_val;
    }
    __syncthreads();

    // Parallel sum reduction, same folding scheme as above.
    n = out_features;
    while (n > 1) {
        const int stride = (n + 1) >> 1; // ceil(n/2)
        if (tid < n - stride) {
            sdata[tid] += sdata[tid + stride];
        }
        __syncthreads();
        n = stride;
    }
    const float sum_exp = sdata[0];
    // No barrier needed here: shared memory is not written again.

    // Write the softmax output.
    if (active) {
        y[static_cast<size_t>(sample) * out_features + tid] = exp_val / sum_exp;
    }
}

// Host entry point: computes softmax(gelu(x @ weight^T + bias), dim=1) with a
// single fused kernel launch on the current CUDA stream of x's device.
//
// Parameters:
//   x            - float32 CUDA tensor of shape [batch, in_features]
//   weight       - float32 CUDA tensor of shape [out_features, in_features]
//   bias         - float32 CUDA tensor with out_features elements
//   in_features  - must match x.size(1) and weight.size(1)
//   out_features - must match weight.size(0); 1 <= out_features <= MAX_OUT_FEATURES
//                  because the kernel uses one thread per output feature
//
// Returns a new float32 tensor of shape [batch, out_features].
// Raises (via TORCH_CHECK) on device/dtype/shape mismatches, on an
// out-of-range out_features, or if the kernel launch is rejected.
torch::Tensor fused_forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias,
                            int in_features, int out_features) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "fused_forward: x, weight and bias must be CUDA tensors");
    TORCH_CHECK(weight.device() == x.device() && bias.device() == x.device(),
                "fused_forward: x, weight and bias must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "fused_forward: x, weight and bias must be float32");
    TORCH_CHECK(x.dim() == 2 && x.size(1) == in_features,
                "fused_forward: x must have shape [batch, in_features]");
    TORCH_CHECK(weight.dim() == 2 && weight.size(0) == out_features &&
                weight.size(1) == in_features,
                "fused_forward: weight must have shape [out_features, in_features]");
    TORCH_CHECK(bias.numel() == out_features,
                "fused_forward: bias must have out_features elements");
    // One thread per output feature: the block size is out_features.
    TORCH_CHECK(out_features >= 1 && out_features <= MAX_OUT_FEATURES,
                "fused_forward: out_features must be in [1, ", MAX_OUT_FEATURES,
                "], got ", out_features);

    c10::cuda::CUDAGuard device_guard(x.device());

    const int64_t batch_size = x.size(0);
    TORCH_CHECK(batch_size <= 2147483647,
                "fused_forward: batch size exceeds the maximum grid dimension");
    auto output = torch::empty({batch_size, out_features}, x.options());

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (batch_size == 0) {
        return output;
    }

    // The kernel indexes raw pointers as dense row-major arrays.
    const auto x_c = x.contiguous();
    const auto weight_c = weight.contiguous();
    const auto bias_c = bias.contiguous();

    // Launch one block per sample; blockDim.x = out_features.
    const dim3 grid(static_cast<unsigned int>(batch_size));
    const dim3 block(static_cast<unsigned int>(out_features));
    const size_t sharedMemSize = static_cast<size_t>(out_features) * sizeof(float);

    fused_linear_gelu_softmax_kernel<<<grid, block, sharedMemSize, c10::cuda::getCurrentCUDAStream()>>>(
        x_c.data_ptr<float>(),
        weight_c.data_ptr<float>(),
        bias_c.data_ptr<float>(),
        output.data_ptr<float>(),
        in_features,
        out_features
    );

    // Kernel launches do not report errors directly; surface bad launch
    // configurations here instead of returning uninitialized memory.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "fused_linear_gelu_softmax_kernel launch failed: ",
                cudaGetErrorString(launch_status));

    return output;
}

// Python bindings: registers the host launcher as `fused_forward` on the
// extension module whose name is supplied by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_forward", &fused_forward, "Fused Linear, GELU, and Softmax forward pass (CUDA)");
}