2.62. Matrix Multiplication with Group Normalization, LeakyReLU and Summation

Performs a sequence of operations: a matrix multiplication linearly transforms the input tensor, group normalization standardizes the intermediate features across the specified groups, a LeakyReLU activation adds non-linearity, and a final element-wise sum adds the activated tensor to itself (doubling it).
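
In eager PyTorch the pipeline described above corresponds to the sketch below. This is a minimal illustration, not the original model definition: the class name and constructor arguments are assumptions, and GroupNorm is instantiated without affine parameters to mirror the fused kernel further down, which applies no per-channel scale or shift.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MatmulGroupNormLeakyReLUSum(nn.Module):
    """Unfused reference: matmul -> GroupNorm -> LeakyReLU -> element-wise sum with itself."""
    def __init__(self, in_features: int, hidden_size: int, num_groups: int,
                 eps: float = 1e-5, negative_slope: float = 0.01):
        super().__init__()
        self.linear = nn.Linear(in_features, hidden_size)  # matrix multiplication (+ bias)
        self.group_norm = nn.GroupNorm(num_groups, hidden_size, eps=eps, affine=False)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)                                        # [batch, hidden_size]
        x = self.group_norm(x)                                    # normalize within each channel group
        x = F.leaky_relu(x, negative_slope=self.negative_slope)   # non-linearity
        return x + x                                              # element-wise sum (doubling)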

Progress: Step 3 of 3
Reward: 2.52x

Optimization Timeline: Step 3 Details

Reasoning

We note that the fused operation (GroupNorm, LeakyReLU, and element-wise summation) has already been optimized by splitting the reduction into two variants: a warp-level reduction when the group size is small (at most 32) and a shared-memory reduction for larger groups. Each sample is processed by a separate block with blockDim.x equal to the hidden size. In the warp-level branch we use __shfl_down_sync to reduce the group's elements, and we double the activated result to account for the element-wise sum. To optimize further, we unroll the reduction loop for the most common case, group_size == 32; this removes loop-control overhead and improves register usage, while the fallback for other group sizes is preserved. These improvements should further reduce latency and increase throughput without sacrificing correctness. Before finishing, I ask: have my improvements genuinely helped? Yes: by unrolling the warp-level reduction loop in the common case, we cut loop-control overhead and simplify scheduling, achieving a slight additional speedup. We now provide the final updated Python and CUDA code.
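The Python side is not reproduced in this step. The sketch below shows one plausible wiring, assuming the CUDA source that follows is saved as fused_gn_leaky_sum.cu and compiled with torch.utils.cpp_extension.load; the file name, extension name, and class name are illustrative, not taken from the original code.

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Hypothetical build step: compile the CUDA source listed below into an extension.
fused_ops = load(name="fused_gn_leaky_sum", sources=["fused_gn_leaky_sum.cu"], verbose=False)

class FusedModel(nn.Module):
    """Matrix multiplication in PyTorch, then the fused GroupNorm + LeakyReLU + sum kernel."""
    def __init__(self, in_features: int, hidden_size: int, num_groups: int,
                 eps: float = 1e-5, negative_slope: float = 0.01):
        super().__init__()
        self.linear = nn.Linear(in_features, hidden_size)
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.eps = eps
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x).contiguous()  # [batch, hidden_size] float32 CUDA tensor
        return fused_ops.forward(x, self.hidden_size, self.num_groups,
                                 self.eps, self.negative_slope)
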

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Fused kernel: applies GroupNorm, LeakyReLU activation, and element-wise sum (doubling).
// Input "input" is shape [batch_size, hidden_size].
// hidden_size: Number of channels per sample.
// num_groups: Number of groups (hidden_size must be divisible by num_groups).
// eps: Constant for numerical stability; neg_slope: negative slope for LeakyReLU.
//
// We launch one block per sample with blockDim.x = hidden_size.
// For each group (of size = hidden_size/num_groups), we use warp-level reduction
// when group_size is a power of two and at most 32, so that a group never straddles
// a warp boundary; otherwise we fall back to a sequential reduction in shared memory.
__global__ void fused_gn_leaky_sum_kernel(const float* __restrict__ input,
                                          float* __restrict__ output,
                                          int hidden_size,
                                          int num_groups,
                                          float eps,
                                          float neg_slope) {
    int batch_idx = blockIdx.x;
    int tid = threadIdx.x;  // Channel index (0 <= tid < hidden_size)
    int group_size = hidden_size / num_groups;
    int group_idx = tid / group_size;
    int lane = tid % group_size;
    
    // Pointers for this sample.
    const float* in_ptr = input + batch_idx * hidden_size;
    float* out_ptr = output + batch_idx * hidden_size;
    
    float x_val = in_ptr[tid];
    
    // Use the warp-level path when the group fits inside a warp and group_size is a
    // power of two, so each group occupies a contiguous, aligned set of lanes.
    if (group_size <= 32 && (group_size & (group_size - 1)) == 0) {
        // Mask of active lanes in this warp; the last warp is partial when
        // hidden_size is not a multiple of 32.
        int lanes_in_warp = min(32, hidden_size - (tid & ~31));
        unsigned int mask = (lanes_in_warp == 32) ? 0xffffffffu : ((1u << lanes_in_warp) - 1u);
        float sum = x_val;
        float sumsq = x_val * x_val;
        // If group_size equals 32, unroll the loop.
        if (group_size == 32) {
            sum += __shfl_down_sync(mask, sum, 16);
            sum += __shfl_down_sync(mask, sum, 8);
            sum += __shfl_down_sync(mask, sum, 4);
            sum += __shfl_down_sync(mask, sum, 2);
            sum += __shfl_down_sync(mask, sum, 1);
            sumsq += __shfl_down_sync(mask, sumsq, 16);
            sumsq += __shfl_down_sync(mask, sumsq, 8);
            sumsq += __shfl_down_sync(mask, sumsq, 4);
            sumsq += __shfl_down_sync(mask, sumsq, 2);
            sumsq += __shfl_down_sync(mask, sumsq, 1);
        } else {
            for (int offset = group_size >> 1; offset > 0; offset >>= 1) {
                float temp = __shfl_down_sync(mask, sum, offset);
                float temp2 = __shfl_down_sync(mask, sumsq, offset);
                // Only threads with lane < offset accumulate.
                if (lane < offset) {
                    sum += temp;
                    sumsq += temp2;
                }
            }
        }
        // Broadcast the group totals from the group's first lane. When group_size < 32
        // several groups share a warp, so the source lane is (tid - lane) % 32, not 0.
        float group_sum = __shfl_sync(mask, sum, (tid - lane) & 31);
        float group_sumsq = __shfl_sync(mask, sumsq, (tid - lane) & 31);
        float mean = group_sum / group_size;
        float var = group_sumsq / group_size - mean * mean;
        float norm = (x_val - mean) / sqrtf(var + eps);
        float activated = (norm >= 0.0f) ? norm : norm * neg_slope;
        out_ptr[tid] = 2.0f * activated;
    } else {
        // Fallback: use shared memory for sequential reduction.
        extern __shared__ float sdata[]; // sdata[0..hidden_size-1] holds input, followed by 2*num_groups floats.
        sdata[tid] = x_val;
        __syncthreads();
        if (lane == 0) {
            float s = 0.0f;
            float s2 = 0.0f;
            int start = group_idx * group_size;
            for (int j = 0; j < group_size; j++) {
                float val = sdata[start + j];
                s += val;
                s2 += val * val;
            }
            sdata[hidden_size + group_idx * 2] = s;
            sdata[hidden_size + group_idx * 2 + 1] = s2;
        }
        __syncthreads();
        float group_sum = sdata[hidden_size + group_idx * 2];
        float group_sumsq = sdata[hidden_size + group_idx * 2 + 1];
        float mean = group_sum / group_size;
        float var = group_sumsq / group_size - mean * mean;
        float norm = (x_val - mean) / sqrtf(var + eps);
        float activated = (norm >= 0.0f) ? norm : norm * neg_slope;
        out_ptr[tid] = 2.0f * activated;
    }
}

torch::Tensor fused_forward(torch::Tensor input, int hidden_size, int num_groups, float eps, float neg_slope) {
    c10::cuda::CUDAGuard device_guard(input.device());
    TORCH_CHECK(input.is_cuda() && input.is_contiguous(), "input must be a contiguous CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(hidden_size % num_groups == 0, "hidden_size must be divisible by num_groups");
    TORCH_CHECK(hidden_size <= 1024, "hidden_size must not exceed the maximum block size (1024 threads)");
    int batch_size = input.size(0);
    auto output = torch::empty({batch_size, hidden_size}, input.options());
    
    // Shared memory for the fallback reduction: hidden_size floats for the input copy
    // plus 2 floats (sum and sum of squares) per group.
    size_t shared_mem_size = (hidden_size + 2 * num_groups) * sizeof(float);
    
    // One block per sample; one thread per channel.
    dim3 grid(batch_size);
    dim3 block(hidden_size);
    
    fused_gn_leaky_sum_kernel<<<grid, block, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        hidden_size,
        num_groups,
        eps,
        neg_slope
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fused_forward, "Fused GroupNorm + LeakyReLU + Sum (doubling) kernel");
}
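
As a sanity check, the fused result can be compared against eager PyTorch. A minimal sketch, assuming the extension is built from an assumed file name as above, with tolerances chosen for illustration and group_size = 32 so the unrolled warp path is exercised:

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

fused_ops = load(name="fused_gn_leaky_sum", sources=["fused_gn_leaky_sum.cu"])  # assumed file name

batch_size, hidden_size, num_groups = 128, 256, 8   # group_size = 32 -> unrolled warp path
eps, negative_slope = 1e-5, 0.01

x = torch.randn(batch_size, hidden_size, device="cuda", dtype=torch.float32)

# Eager reference: GroupNorm without affine parameters -> LeakyReLU -> element-wise sum with itself.
ref = F.group_norm(x, num_groups, eps=eps)
ref = F.leaky_relu(ref, negative_slope=negative_slope)
ref = ref + ref

out = fused_ops.forward(x, hidden_size, num_groups, eps, negative_slope)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)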