1.40. Layer Normalization Computation

Calculates layer normalization by computing the mean and variance over a specified shape (the normalized dimensions) of the input tensor and then normalizing each value as (x - mean) / sqrt(var + eps), without an affine transformation.
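
For reference, the computation this step implements can be written directly in PyTorch. The snippet below is a minimal sketch with illustrative shapes, assuming float32 input and no affine (weight/bias) parameters.

import torch
import torch.nn.functional as F

x = torch.randn(16, 64, 32, 32)                  # illustrative shape: [batch, ...]
dims = tuple(range(1, x.dim()))                  # normalized (non-batch) dimensions
mean = x.mean(dim=dims, keepdim=True)
var = x.var(dim=dims, unbiased=False, keepdim=True)
y = (x - mean) / torch.sqrt(var + 1e-5)          # eps = 1e-5, no affine transform

# Same result as the built-in layer norm without weight/bias.
ref = F.layer_norm(x, x.shape[1:], eps=1e-5)
print(torch.allclose(y, ref, atol=1e-4))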

Progress: Step 2 of 2
Reward: 6.62x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize Layer Normalization by vectorizing the memory accesses when the number of elements M (the product of the normalized dimensions) is divisible by 4. In our previous fused implementation, each sample (one index along the batch dimension) is processed by one block, and each thread loops sequentially over a subset of that sample's elements. In many practical cases (e.g., with normalized_shape = (64, 256, 256)), M is large and divisible by 4, so it pays to reinterpret the input and output as arrays of float4. Each thread then loads four floats at a time and accumulates a partial sum and a partial sum of squares. After the shared-memory reduction over these partial sums, the mean and variance are computed and each thread normalizes its four values. This reduces the number of memory transactions and improves performance.

In the host “forward” function we check whether M % 4 == 0: if so, we launch our vectorized kernel; otherwise we fall back to the scalar kernel. We continue using c10::cuda::CUDAGuard and launch kernels on c10::cuda::getCurrentCUDAStream(). Our Python module remains almost identical, since it only calls the “forward” function from the CUDA extension. Before finishing, I ask: have my improvements genuinely helped? Yes: the vectorized loads and stores should improve memory throughput significantly whenever M is divisible by 4.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Scalar (non-vectorized) kernel for Layer Normalization.
// Each block processes one sample.
__global__ void layer_norm_kernel(const float* __restrict__ input,
                                  float* __restrict__ output,
                                  int M, float eps) {
    int sample = blockIdx.x;           // index of sample (batch dimension)
    int tid = threadIdx.x;
    int num_threads = blockDim.x;
    
    // Pointers for the current sample.
    const float* sample_in = input + sample * M;
    float* sample_out = output + sample * M;
    
    float sum = 0.f;
    float sumsq = 0.f;
    // Each thread processes its subset.
    for (int i = tid; i < M; i += num_threads) {
        float val = sample_in[i];
        sum += val;
        sumsq += val * val;
    }
    
    // Shared memory for reduction.
    extern __shared__ float shared[];
    float* s_sum = shared;               // size: num_threads
    float* s_sumsq = shared + num_threads; // size: num_threads
    
    s_sum[tid] = sum;
    s_sumsq[tid] = sumsq;
    __syncthreads();
    
    // Parallel reduction (assumes blockDim.x is a power of two).
    for (int s = num_threads / 2; s > 0; s >>= 1) {
        if (tid < s) {
            s_sum[tid] += s_sum[tid + s];
            s_sumsq[tid] += s_sumsq[tid + s];
        }
        __syncthreads();
    }
    
    float mean = s_sum[0] / M;
    float var = s_sumsq[0] / M - mean * mean;
    float inv_std = rsqrtf(var + eps);
    
    // Normalize.
    for (int i = tid; i < M; i += num_threads) {
        sample_out[i] = (sample_in[i] - mean) * inv_std;
    }
}

// Vectorized kernel for Layer Normalization when M is divisible by 4.
// The input/output are processed as float4 arrays.
__global__ void vectorized_layer_norm_kernel(const float4* __restrict__ input,
                                             float4* __restrict__ output,
                                             int M, float eps) {
    // M is total number of float elements. Number of float4 elements is M_vec = M/4.
    int M_vec = M / 4;
    int sample = blockIdx.x;  // one sample per block
    int tid = threadIdx.x;
    int num_threads = blockDim.x;

    // Set pointers for vectorized access.
    const float4* sample_in = input + sample * M_vec;
    float4* sample_out = output + sample * M_vec;
    
    float partial_sum = 0.f;
    float partial_sumsq = 0.f;
    for (int i = tid; i < M_vec; i += num_threads) {
        float4 v = sample_in[i];
        float sum_val = v.x + v.y + v.z + v.w;
        float sumsq_val = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
        partial_sum += sum_val;
        partial_sumsq += sumsq_val;
    }
    
    extern __shared__ float shared[];
    float* s_sum = shared;               // size: num_threads
    float* s_sumsq = shared + num_threads; // size: num_threads
    
    s_sum[tid] = partial_sum;
    s_sumsq[tid] = partial_sumsq;
    __syncthreads();
    
    // Parallel reduction over the per-thread partial sums (blockDim.x must be a power of two).
    for (int s = num_threads / 2; s > 0; s >>= 1) {
        if (tid < s) {
            s_sum[tid] += s_sum[tid + s];
            s_sumsq[tid] += s_sumsq[tid + s];
        }
        __syncthreads();
    }
    
    // Compute mean and variance over all M floats.
    float mean = s_sum[0] / M;
    float var = s_sumsq[0] / M - mean * mean;
    float inv_std = rsqrtf(var + eps);
    
    // Normalize vectorized.
    for (int i = tid; i < M_vec; i += num_threads) {
        float4 v = sample_in[i];
        float4 norm;
        norm.x = (v.x - mean) * inv_std;
        norm.y = (v.y - mean) * inv_std;
        norm.z = (v.z - mean) * inv_std;
        norm.w = (v.w - mean) * inv_std;
        sample_out[i] = norm;
    }
}

torch::Tensor forward(torch::Tensor input, float eps) {
    // The kernels below assume a contiguous float32 CUDA tensor.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    
    // Enforce device using CUDAGuard.
    c10::cuda::CUDAGuard device_guard(input.device());
    
    // Input shape: [batch, ...]. Compute total number of elements per sample.
    int batch_size = input.size(0);
    int M = 1;
    for (int i = 1; i < input.dim(); i++) {
        M *= input.size(i);
    }
    
    // Allocate output tensor.
    auto output = torch::empty_like(input);
    
    int threads = 256;
    int blocks = batch_size;
    size_t shared_mem = threads * 2 * sizeof(float);
    
    if (M % 4 == 0) {
        // Use the vectorized kernel (it derives the float4 count M / 4 internally).
        vectorized_layer_norm_kernel<<<blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
            reinterpret_cast<const float4*>(input.data_ptr<float>()),
            reinterpret_cast<float4*>(output.data_ptr<float>()),
            M,
            eps
        );
    } else {
        // Use scalar kernel.
        layer_norm_kernel<<<blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(),
            output.data_ptr<float>(),
            M,
            eps
        );
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused Layer Normalization without affine transformation (CUDA)");
}
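
As the reasoning above notes, the Python module only wraps the extension's “forward” function. The sketch below shows one way this could look; the file name layer_norm_cuda.cu, the extension name, and the inline build via torch.utils.cpp_extension.load are illustrative assumptions, not part of the listing above.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Assumes the CUDA source above is saved as layer_norm_cuda.cu (illustrative name).
layer_norm_cuda = load(name="layer_norm_cuda", sources=["layer_norm_cuda.cu"])

class LayerNormNoAffine(nn.Module):
    """Layer normalization over all non-batch dimensions, without weight/bias."""
    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return layer_norm_cuda.forward(x.contiguous(), self.eps)

# Quick numerical check against PyTorch's built-in implementation.
x = torch.randn(8, 64, 64, 64, device="cuda")   # M = 64*64*64, divisible by 4
out = LayerNormNoAffine()(x)
ref = F.layer_norm(x, x.shape[1:], eps=1e-5)
print(torch.allclose(out, ref, atol=1e-4))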