1.23. Softmax Activation Computation

Calculates the softmax activation on an input tensor by transforming raw scores into probabilities along the feature dimension.
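
For reference, the kernel computes the numerically stable form of softmax, subtracting the row maximum before exponentiating:

    softmax(x)_j = exp(x_j - m) / sum_k exp(x_k - m),   where m = max_k x_k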

Progress

Step 2 of 2
Reward
1.08x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize our softmax kernel with warp-level intrinsics to reduce shared memory usage and avoid redundant global memory loads. In the revised kernel, each block processes one row. First, each thread computes a local maximum over its assigned indices; warp-level reductions (using __shfl_down_sync) then combine these values into the block maximum. Next, each thread computes the exponentials (with the maximum subtracted for numerical stability) for the same indices, storing them in a local array while accumulating a local sum. Warp-level reduction again collects the per-thread sums into a total. Finally, each thread writes its normalized exp values (the softmax result) from the local buffer to global memory. This avoids reading the input twice and uses fast warp shuffles instead of storing all partial sums in shared memory. Overall, these improvements should narrow (or even reverse) the performance gap versus the native implementation. The reduction pattern is sketched below, followed by the final updated code.

--- FIX ---

The custom softmax kernel was producing inf values because our reduction across warps did not correctly aggregate the local maxima and local sums from all warps in the block. The original version used __shfl_sync and shared memory in a way that populated only the first warp's lanes, so uninitialized data from the other warps was consumed. To fix this, we refactor the kernel so that each warp first performs an intra-warp reduction (using __shfl_down_sync) and the first thread of each warp writes its result into shared memory. A single thread (tid 0) then loops over the partial results to compute the global maximum (or global sum) and writes it back to shared memory so that all threads can read it. This ensures every thread uses the correct global maximum and sum when computing the final softmax values. Additionally, we restrict the temporary per-thread buffer to a safe maximum of 128 elements and use proper shared memory arrays for the warp results. With these corrections the kernel now produces valid outputs.
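
Taken in isolation, the two-stage reduction described above looks roughly like the following minimal sketch. block_max is a hypothetical helper used only for illustration (the kernel below inlines this logic rather than calling a helper), and it assumes the block size is a multiple of 32 and at most 1024 threads.

// Minimal sketch of the two-stage block-wide max reduction:
// an intra-warp shuffle reduction followed by a shared-memory combine.
__device__ float block_max(float local_max) {
    // Stage 1: reduce within each warp using register shuffles.
    for (int offset = 16; offset > 0; offset /= 2) {
        local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, offset));
    }
    // Stage 2: lane 0 of each warp publishes its partial result.
    __shared__ float warp_max[32];
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;
    if (lane == 0) warp_max[warp] = local_max;
    __syncthreads();
    // A single thread folds the per-warp partials and broadcasts the result.
    if (threadIdx.x == 0) {
        float m = warp_max[0];
        for (int w = 1; w < blockDim.x / 32; ++w) {
            m = fmaxf(m, warp_max[w]);
        }
        warp_max[0] = m;
    }
    __syncthreads();
    return warp_max[0];
}

The same shape is used twice in the kernel: once with fmaxf for the row maximum and once with addition for the row sum.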

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <cuda.h>
#include <math.h>

#ifndef WARP_SIZE
#define WARP_SIZE 32
#endif

// Custom softmax kernel with corrected warp-level reduction.
__global__ void softmax_kernel(const float* __restrict__ input,
                               float* __restrict__ output,
                               int num_cols) {
    // Each block handles one row.
    int row = blockIdx.x;
    int tid = threadIdx.x;
    int blockSize = blockDim.x;
    // Assume blockSize is an exact multiple of WARP_SIZE.
    const int num_warps = blockSize / WARP_SIZE;

    // Step 1: Compute the maximum value for the row.
    float local_max = -INFINITY;
    for (int j = tid; j < num_cols; j += blockSize) {
        float val = input[row * num_cols + j];
        local_max = fmaxf(local_max, val);
    }
    // Intra-warp reduction for max using shuffle.
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        float temp = __shfl_down_sync(0xffffffff, local_max, offset);
        local_max = fmaxf(local_max, temp);
    }
    // Write the first lane of each warp to shared memory.
    __shared__ float s_max[32];
    if ((tid & (WARP_SIZE - 1)) == 0) {
        s_max[tid / WARP_SIZE] = local_max;
    }
    __syncthreads();
    float global_max;
    if (tid == 0) {
        global_max = s_max[0];
        for (int i = 1; i < num_warps; i++) {
            global_max = fmaxf(global_max, s_max[i]);
        }
        s_max[0] = global_max;
    }
    __syncthreads();
    global_max = s_max[0];

    // Step 2: Compute exponentials and local sum.
    float local_sum = 0.0f;
    // Per-thread buffer for exponentials. The host wrapper checks that
    // num_cols <= blockDim.x * 128, so this buffer cannot overflow.
    const int max_iters = 128;
    float localExp[max_iters];
    int count = 0;
    for (int j = tid; j < num_cols; j += blockSize) {
        float val = input[row * num_cols + j];
        float exp_val = expf(val - global_max);
        localExp[count++] = exp_val;
        local_sum += exp_val;
    }
    // Intra-warp reduction for sum.
    float sum_val = local_sum;
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        float temp = __shfl_down_sync(0xffffffff, sum_val, offset);
        sum_val += temp;
    }
    __shared__ float s_sum[32];
    if ((tid & (WARP_SIZE - 1)) == 0) {
        s_sum[tid / WARP_SIZE] = sum_val;
    }
    __syncthreads();
    if (tid == 0) {
        float total = 0.0f;
        for (int i = 0; i < num_warps; i++) {
            total += s_sum[i];
        }
        s_sum[0] = total;
    }
    __syncthreads();
    float total_sum = s_sum[0];

    // Step 3: Write normalized softmax output using the locally stored exp values.
    count = 0;
    for (int j = tid; j < num_cols; j += blockSize) {
        float exp_val = localExp[count++];
        output[row * num_cols + j] = exp_val / total_sum;
    }
}

torch::Tensor softmax_forward(torch::Tensor input) {
    // Ensure operation is performed on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // Validate the input: 2D, float32, contiguous, and on the GPU.
    TORCH_CHECK(input.is_cuda(), "Input tensor must be a CUDA tensor.");
    TORCH_CHECK(input.dim() == 2, "Input tensor must be 2D: [batch_size, num_features].");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input tensor must be float32.");
    TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous.");
    int batch_size = input.size(0);
    int num_cols = input.size(1);

    auto output = torch::empty_like(input);

    // Use a block size of 256 threads per row.
    int threads = 256;
    // The kernel buffers at most 128 exponentials per thread, so each row
    // must fit within threads * 128 columns.
    TORCH_CHECK(num_cols <= threads * 128,
                "num_features exceeds the kernel's per-thread buffer capacity (threads * 128).");

    // Launch kernel: one block per row.
    softmax_kernel<<<batch_size, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        num_cols
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("softmax_forward", &softmax_forward, "Custom softmax forward (CUDA) with optimized warp-level reduction");
}