2.22. Matmul Scale Residual Clamp LogSumExp Mish Computation

Performs a sequence of operations on an input tensor: computes a linear transformation via matrix multiplication, scales the result by a given factor, adds the result to itself (a residual connection, which doubles the scaled values), clamps the values within specified bounds, applies a log-sum-exp reduction across the feature dimension, and finally multiplies the outcome by its Mish activation.
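
For orientation, a minimal eager-mode PyTorch sketch of this pipeline could look as follows; the module name, constructor arguments, and shapes are illustrative, and this is a reference rendering of the steps above rather than the fused kernel described below.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MatmulScaleResidualClampLogSumExpMish(nn.Module):
    """Eager reference sketch of the pipeline (not the optimized kernel)."""
    def __init__(self, in_features, hidden_size, scale_factor, clamp_min, clamp_max):
        super().__init__()
        self.linear = nn.Linear(in_features, hidden_size)
        self.scale_factor = scale_factor
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

    def forward(self, x):
        x = self.linear(x)                            # matrix multiplication (+ bias)
        x = x * self.scale_factor                     # scale
        x = x + x                                     # residual add with itself
        x = torch.clamp(x, self.clamp_min, self.clamp_max)
        x = torch.logsumexp(x, dim=1, keepdim=True)   # reduce over the feature dim
        return x * F.mish(x)                          # multiply by Mish activation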

Progress: Step 3 of 3
Reward: 3.75x

Optimization Timeline

Step 3 Details

Reasoning

We can further optimize the fused CUDA kernel by reducing redundant computation. In the previous version, the clamped value is computed twice: once for the maximum reduction and again for the sum-of-exponentials calculation. Instead, we can perform a single loop per thread that computes the clamped value, stores it in a small local array (held in registers), and updates the thread-local maximum. After the global maximum m is obtained via block reduction, each thread reuses its locally stored clamped values to compute its partial sum of exp(y - m). This saves arithmetic and avoids clamping and multiplying by the scale factor a second time.

This approach is effective because each thread processes only a few elements (for instance, with hidden_size = 1024 and blockDim.x = 256, each thread handles 4 elements). We allocate a fixed-size local array to cover the worst-case number of iterations per thread, which bounds the supported hidden_size at MAX_ITERS * blockDim.x. This optimization should yield a performance boost over the native implementation.

Before finishing, I ask myself: have my improvements genuinely helped? Yes, the strategy reduces redundant computation and should improve throughput without compromising correctness. I'll now rewrite the CUDA kernel accordingly while keeping the Python interface the same.
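
To make the strategy concrete, here is a sequential Python sketch of what a single block effectively computes for one sample: one pass that clamps, caches, and tracks the maximum, followed by a second pass that reuses the cached values for the stable log-sum-exp. The function and argument names (row, factor, clamp_min, clamp_max) are illustrative; in the kernel below this work is split across the threads of a block.

import math

def block_logsumexp_mish(row, factor, clamp_min, clamp_max):
    """Sequential sketch of what one CUDA block computes for one sample."""
    # Pass 1: clamp once, cache the values, track the running maximum.
    cached = []
    m = -math.inf
    for x in row:
        y = min(max(x * factor, clamp_min), clamp_max)
        cached.append(y)
        m = max(m, y)
    # Pass 2: reuse the cached clamped values for the stable log-sum-exp.
    s = sum(math.exp(y - m) for y in cached)
    L = m + math.log(s)
    # Mish self-gating: output = L * mish(L), mish(L) = L * tanh(softplus(L)).
    softplus = max(L, 0.0) + math.log1p(math.exp(-abs(L)))  # overflow-safe softplus
    mish = L * math.tanh(softplus)
    return L * mish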

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
#include <math.h>

// Optimized fused kernel: each block processes one sample of the linear output.
// The computations are as follows:
// 1. For each element j in the sample (of length hidden_size):
//    y = clamp( x[i,j] * (2 * scale_factor), clamp_min, clamp_max )
//    Each thread processes multiple indices and stores the computed values in a local array.
// 2. Compute L = logsumexp(y) = m + log(sum(exp(y - m))) via block reduction.
// 3. Compute softplus(L) = log(1 + exp(L)) and mish(L) = L * tanh(softplus(L)).
// 4. Final output = L * mish(L), stored per sample.
__global__ void fused_kernel(const float* __restrict__ input,
                             float* __restrict__ output,
                             int hidden_size,
                             float factor,   // factor = 2 * scale_factor
                             float clamp_min,
                             float clamp_max) {
    // Each block processes one sample.
    int sample = blockIdx.x;
    const float* in_sample = input + sample * hidden_size;
    
    // Shared memory for reductions.
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    
    // Determine the number of iterations (elements) per thread.
    int n_iters = (hidden_size + blockDim.x - 1) / blockDim.x;
    // Per-thread register cache of clamped values. The host launcher checks that
    // hidden_size <= MAX_ITERS * blockDim.x, so local_vals never overflows.
    const int MAX_ITERS = 64;  // Maximum allowed iterations per thread.
    float local_vals[MAX_ITERS];
    
    // First pass: compute clamped value for each element and find local max.
    float local_max = -INFINITY;
    for (int i = 0; i < n_iters; i++) {
        int idx = tid + i * blockDim.x;
        float y = -INFINITY;
        if (idx < hidden_size) {
            y = in_sample[idx] * factor;
            y = fminf(fmaxf(y, clamp_min), clamp_max);
        }
        local_vals[i] = y;
        local_max = fmaxf(local_max, y);
    }
    
    // Store each thread's local maximum in shared memory.
    sdata[tid] = local_max;
    __syncthreads();
    
    // Block-level reduction to get the global maximum m.
    // (The tree reduction assumes blockDim.x is a power of two; the host launcher guarantees this.)
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            float other = sdata[tid + s];
            sdata[tid] = fmaxf(sdata[tid], other);
        }
        __syncthreads();
    }
    float m = sdata[0];
    
    // Second pass: compute partial sum of exp(y - m) using stored clamped values.
    float local_sum = 0.0f;
    for (int i = 0; i < n_iters; i++) {
        int idx = tid + i * blockDim.x;
        if (idx < hidden_size) {
            local_sum += expf(local_vals[i] - m);
        }
    }
    sdata[tid] = local_sum;
    __syncthreads();
    
    // Reduce to obtain the total sum of exponentials.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    
    float sumExp = sdata[0];
    float L = m + logf(sumExp);
    
    // Compute softplus(L) and mish(L).
    float softplus = logf(1.0f + expf(L));
    float mish = L * tanhf(softplus);
    
    // Write the final output for this sample.
    if (tid == 0) {
        output[sample] = L * mish;
    }
}

torch::Tensor forward(torch::Tensor input, float scale_factor, float clamp_min, float clamp_max) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "input must be 2-D [batch, hidden_size]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    input = input.contiguous();

    // Ensure the correct CUDA device using c10::cuda::CUDAGuard.
    c10::cuda::CUDAGuard device_guard(input.device());
    
    // Input shape: [batch, hidden_size]
    int batch_size = input.size(0);
    int hidden_size = input.size(1);
    
    // Allocate output tensor with shape [batch, 1].
    auto output = torch::empty({batch_size, 1}, input.options());
    
    // Combined scaling factor: scaling by scale_factor and then adding the
    // result to itself is equivalent to multiplying by 2 * scale_factor.
    float factor = 2.0f * scale_factor;
    
    // Launch one block per sample. The tree reductions in the kernel assume a
    // power-of-two block size, so round the thread count up to the next power
    // of two (capped at 256).
    int threads = 32;
    while (threads < hidden_size && threads < 256) threads <<= 1;
    TORCH_CHECK(hidden_size <= threads * 64,
                "hidden_size exceeds the per-thread register cache (MAX_ITERS * blockDim.x)");
    int blocks = batch_size;
    size_t shared_mem = threads * sizeof(float);
    
    fused_kernel<<<blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        hidden_size,
        factor,
        clamp_min,
        clamp_max
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Optimized fused scale, residual add, clamp, logsumexp, and mish (applied after a linear layer)");
}