2.64. Matrix Multiplication and Activation Chain

Performs a matrix multiplication on the input data, applies a log-sum-exp reduction along the feature dimension, and then sequentially applies two LeakyReLU activations followed by two GELU activations.
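
For reference, the unfused computation corresponds to a PyTorch module along the following lines (a minimal sketch: the class name Model is illustrative, the 0.01 slope matches the kernel's leaky_relu calls below, and note that F.gelu uses the exact GELU whereas the kernel uses the tanh approximation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Unfused reference: linear layer -> logsumexp over features -> 2x LeakyReLU -> 2x GELU."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        y = self.linear(x)                           # [batch, out_features]
        y = torch.logsumexp(y, dim=1, keepdim=True)  # reduce the feature dimension -> [batch, 1]
        y = F.leaky_relu(y, negative_slope=0.01)     # first LeakyReLU
        y = F.leaky_relu(y, negative_slope=0.01)     # second LeakyReLU
        y = F.gelu(y)                                # first GELU
        y = F.gelu(y)                                # second GELU
        return y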

Progress: Step 6 of 6
Reward: 1.02x

Optimization Timeline

Step 6 Details

Reasoning

We can try to further optimize our fast-path fused kernel by adding a few micro-optimizations, such as marking pointer arguments as __restrict__ to help the compiler with alias analysis and reinforcing loop unrolling. However, since our current implementation already obtains a slight speedup (reward ~1.0035, which is acceptable), these micro-optimizations might only produce marginal improvements. Nonetheless, we add __restrict__ qualifiers to our memory pointers and add more comments for clarity, without changing the overall algorithm. Our methodology remains the same:

1. If out_features equals exactly 2*blockDim.x (typical for 512 outputs when blockDim.x is 256), use the fast path where each thread processes two outputs using vectorized loads.
2. Otherwise, fall back to a general GEMM loop.
3. After the GEMM, perform two warp-level reductions (one for the maximum and one for the sum of exp(s_y - m)) using __shfl_down_sync, and store the final reduction values in shared memory so all threads can see the result.
4. Compute the log-sum-exp, then apply two successive LeakyReLU and two GELU activations.

Below is the final updated Python module (unchanged in structure) and the CUDA kernel with added __restrict__ qualifiers and extra comments for clarity.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>

#define BLOCK_THREADS 256
#define WARP_SIZE 32

// Device function: Warp-level reduction to compute maximum.
__inline__ __device__ float warpReduceMax(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        // Use full mask 0xffffffff since we operate on a full warp.
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
    return val;
}

// Device function: Warp-level reduction to compute sum.
__inline__ __device__ float warpReduceSum(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

// Device function for LeakyReLU activation.
__device__ inline float leaky_relu(float x, float slope) {
    return (x >= 0.0f) ? x : slope * x;
}

// Device function for GELU activation (approximation).
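// Uses the tanh approximation: gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).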
__device__ inline float gelu(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654f;
    return 0.5f * x * (1.0f + tanhf(sqrt_2_over_pi * (x + 0.044715f * x * x * x)));
}

// Fused kernel: each block processes one sample.
// __restrict__ qualifiers added to help compiler with aliasing.
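// Launch configuration (see forward() below): one block per sample, BLOCK_THREADS threads per block,
// and out_features * sizeof(float) bytes of dynamic shared memory for the per-sample GEMM results.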
__global__ void fused_gemm_lse_act_kernel(const float* __restrict__ x,      // [batch, in_features]
                                          const float* __restrict__ weight, // [out_features, in_features]
                                          const float* __restrict__ bias,   // [out_features]
                                          float* __restrict__ output,       // [batch, 1]
                                          int in_features,
                                          int out_features) {
    int sample = blockIdx.x;
    const float* x_sample = x + sample * in_features;

    // Dynamic shared memory for GEMM results: s_y array of size out_features.
    extern __shared__ float s_y[];

    // Shared memory for final reduction results.
    __shared__ float final_m;   // Final maximum value.
    __shared__ float final_sum; // Final sum of exp(s_y - m).

    // --- GEMM Phase ---
    // Fast path: if out_features equals exactly 2 * blockDim.x, use fast computation.
    if (out_features == blockDim.x * 2) {
        int j0 = threadIdx.x; 
        int j1 = threadIdx.x + blockDim.x;
        float dot0 = 0.0f, dot1 = 0.0f;
        if ((in_features & 3) == 0) {  // if in_features is divisible by 4, use vectorized loads.
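            // Note: float4 loads assume the sample row and weight rows are 16-byte aligned,
            // which holds when the base pointers are 16-byte aligned and in_features % 4 == 0.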
            int vec_count = in_features >> 2;  // in_features/4
            const float4* x_vec = reinterpret_cast<const float4*>(x_sample);
            const float4* w_vec0 = reinterpret_cast<const float4*>(weight + j0 * in_features);
            const float4* w_vec1 = reinterpret_cast<const float4*>(weight + j1 * in_features);
            #pragma unroll
            for (int k = 0; k < vec_count; k++) {
                float4 xv = x_vec[k];
                float4 wv0 = w_vec0[k];
                float4 wv1 = w_vec1[k];
                dot0 += xv.x * wv0.x + xv.y * wv0.y + xv.z * wv0.z + xv.w * wv0.w;
                dot1 += xv.x * wv1.x + xv.y * wv1.y + xv.z * wv1.z + xv.w * wv1.w;
            }
        } else {
            for (int k = 0; k < in_features; k++) {
                dot0 += x_sample[k] * weight[j0 * in_features + k];
                dot1 += x_sample[k] * weight[j1 * in_features + k];
            }
        }
        dot0 += bias[j0];
        dot1 += bias[j1];
        s_y[j0] = dot0;
        s_y[j1] = dot1;
    } else {
        // Fallback general GEMM loop.
        for (int j = threadIdx.x; j < out_features; j += blockDim.x) {
            float dot = 0.0f;
            const float* w_row = weight + j * in_features;
            if ((in_features & 3) == 0) {
                int vec_count = in_features >> 2;
                const float4* x_vec = reinterpret_cast<const float4*>(x_sample);
                const float4* w_vec = reinterpret_cast<const float4*>(w_row);
                #pragma unroll
                for (int k = 0; k < vec_count; k++) {
                    float4 xv = x_vec[k];
                    float4 wv = w_vec[k];
                    dot += xv.x * wv.x + xv.y * wv.y + xv.z * wv.z + xv.w * wv.w;
                }
            } else {
                for (int k = 0; k < in_features; k++) {
                    dot += x_sample[k] * w_row[k];
                }
            }
            dot += bias[j];
            s_y[j] = dot;
        }
    }
    __syncthreads();

    // --- Maximum Reduction ---
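    // Two-level reduction: each thread takes a strided max over s_y, each warp reduces its
    // partial maxima with shuffles, and the first warp then combines the per-warp results.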
    float thread_max = -INFINITY;
    for (int j = threadIdx.x; j < out_features; j += blockDim.x) {
        thread_max = fmaxf(thread_max, s_y[j]);
    }
    float local_max = warpReduceMax(thread_max);
    int lane = threadIdx.x % WARP_SIZE;
    int warp_id = threadIdx.x / WARP_SIZE;
    __shared__ float s_warp_max[BLOCK_THREADS / WARP_SIZE];
    if (lane == 0) {
        s_warp_max[warp_id] = local_max;
    }
    __syncthreads();
    float final_max_val = -INFINITY;
    if (threadIdx.x < (blockDim.x / WARP_SIZE)) {
        final_max_val = s_warp_max[threadIdx.x];
    }
    final_max_val = warpReduceMax(final_max_val);
    if (threadIdx.x == 0) {
        final_m = final_max_val; // Store final max in shared memory.
    }
    __syncthreads();
    float m = final_m;  // Broadcast final max.

    // --- Sum Reduction ---
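    // Accumulate exp(s_y[j] - m); subtracting the maximum m keeps expf from overflowing,
    // so the final log-sum-exp can be computed as m + log(sum).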
    float thread_sum = 0.0f;
    for (int j = threadIdx.x; j < out_features; j += blockDim.x) {
        thread_sum += expf(s_y[j] - m);
    }
    float local_sum = warpReduceSum(thread_sum);
    __shared__ float s_warp_sum[BLOCK_THREADS / WARP_SIZE];
    if (lane == 0) {
        s_warp_sum[warp_id] = local_sum;
    }
    __syncthreads();
    float final_sum_val = 0.0f;
    if (threadIdx.x < (blockDim.x / WARP_SIZE)) {
        final_sum_val = s_warp_sum[threadIdx.x];
    }
    final_sum_val = warpReduceSum(final_sum_val);
    if (threadIdx.x == 0) {
        final_sum = final_sum_val; // Store final sum in shared memory.
    }
    __syncthreads();
    float s_final = final_sum; // Final sum value.
    
    float lse = m + logf(s_final);

    // --- Activation Chain ---
    float act = lse;
    act = leaky_relu(act, 0.01f);
    act = leaky_relu(act, 0.01f);
    act = gelu(act);
    act = gelu(act);

    // Write the final output for this sample.
    if (threadIdx.x == 0) {
        output[sample] = act;
    }
}

torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias,
                      int in_features, int out_features) {
    c10::cuda::CUDAGuard device_guard(x.device());
    int batch_size = x.size(0);
    auto output = torch::empty({batch_size, 1}, x.options());
    int threads = BLOCK_THREADS;
    int blocks = batch_size;
    // Allocate dynamic shared memory for GEMM results: out_features floats.
    size_t shared_mem_size = out_features * sizeof(float);
    
    fused_gemm_lse_act_kernel<<<blocks, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        in_features,
        out_features
    );
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused GEMM, LogSumExp, double LeakyReLU, double GELU with warp-level reduction and fast GEMM path");
}
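
The Python module itself is not reproduced in this step. As an illustration only, a wrapper along the following lines would expose the kernel to PyTorch, assuming the CUDA listing above is saved as fused_gemm_lse_act.cu and built with torch.utils.cpp_extension.load (the file name, extension name, and class name ModelNew are illustrative assumptions):

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Build the extension from the CUDA source above (compiled on first import).
fused = load(name="fused_gemm_lse_act", sources=["fused_gemm_lse_act.cu"], verbose=False)

class ModelNew(nn.Module):
    """Drop-in module that routes the forward pass through the fused CUDA kernel."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        # The kernel expects contiguous float32 CUDA tensors; the output has shape [batch, 1].
        return fused.forward(
            x.contiguous(),
            self.linear.weight.contiguous(),
            self.linear.bias.contiguous(),
            self.in_features,
            self.out_features,
        )

With out_features = 512 and BLOCK_THREADS = 256, each block takes the fast path and every thread computes two outputs, matching the reasoning above.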