1.24. LogSoftmax Activation Calculation

Calculates the logarithm of the softmax values along a specified dimension, transforming an input tensor into a normalized log-probability tensor with the same shape.
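The following is a short derivation using standard algebra, not taken from the original page. It shows why the computation can subtract the row maximum without changing the result. Here n is the row length and m is the row maximum, which the kernel calls max_val.

```latex
\log \operatorname{softmax}(x)_j
  = \log \frac{e^{x_j}}{\sum_{k=1}^{n} e^{x_k}}
  = x_j - \log \sum_{k=1}^{n} e^{x_k}
  = x_j - \log\Big(e^{m} \sum_{k=1}^{n} e^{x_k - m}\Big)
  = x_j - m - \log \sum_{k=1}^{n} e^{x_k - m},
  \qquad m = \max_{k} x_k
```

Each shifted term e^{x_k - m} lies in (0, 1], and the term for the maximum equals exactly 1. The sum therefore falls between 1 and n, so the logarithm is always finite.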

Reward: 1.09x

Reasoning

We need to replace the standard LogSoftmax operation with a custom CUDA kernel that computes log-softmax along the given dimension. We assume dim=1, i.e. the computation runs along each row of a 2D input. For each row i and column j:

    max_val = max_j x[i,j]
    sumExp  = Σ_j exp(x[i,j] - max_val)
    output[i,j] = x[i,j] - max_val - log(sumExp)

Subtracting max_val leaves the result unchanged. It also keeps every exponent at or below zero, so expf cannot overflow.

To do this efficiently on the GPU, we use one CUDA block per row. Each thread in the block processes a strided portion of the row. The block then performs parallel reductions in shared memory to compute the row maximum and the sum of exponentials. We use c10::cuda::CUDAGuard so that the kernel runs on the input tensor's device, and we launch it on the current CUDA stream. The Python module loads the extension with the flags "-O3" and "--use_fast_math", and the forward method of the CUDAModel class calls the fused log-softmax operator. The aim of this fused kernel is to avoid multiple kernel launches and extra memory traffic compared with the native log_softmax.
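A plain C++ CPU reference of the same three phases gives a simple baseline for checking the kernel's output. This is a sketch under stated assumptions. The input is taken to be a contiguous row-major [batch, dim] float buffer, and the name log_softmax_rows_ref is hypothetical, not part of the extension.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU reference for row-wise log-softmax on a row-major
// [batch, dim] matrix. It mirrors the kernel's phases: row max, sum of
// shifted exponentials, then x - max - log(sum).
std::vector<float> log_softmax_rows_ref(const std::vector<float>& x,
                                        std::size_t batch, std::size_t dim) {
    assert(x.size() == batch * dim);
    std::vector<float> out(x.size());
    for (std::size_t i = 0; i < batch; ++i) {
        const float* row = x.data() + i * dim;
        float* out_row = out.data() + i * dim;

        // Phase 1: row maximum, starting from -FLT_MAX like the kernel.
        float max_val = -std::numeric_limits<float>::max();
        for (std::size_t j = 0; j < dim; ++j) max_val = std::fmax(max_val, row[j]);

        // Phase 2: sum of exp(x - max); each term is at most 1.
        float sum_exp = 0.0f;
        for (std::size_t j = 0; j < dim; ++j) sum_exp += std::exp(row[j] - max_val);
        const float log_sum = std::log(sum_exp);

        // Phase 3: final log-softmax values.
        for (std::size_t j = 0; j < dim; ++j) out_row[j] = row[j] - max_val - log_sum;
    }
    return out;
}
```

You can validate the kernel by copying the extension's output to the host and comparing it with this reference within a small tolerance. A row such as {1000, 1000} shows why the maximum is subtracted first. In float, exp(1000.0f) overflows to infinity, but the shifted form returns log(0.5) for both entries.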

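#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>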

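// Kernel to compute LogSoftmax for each row of the input.
// Assumes a contiguous, row-major input of shape [batch, dim]; `dim` is the
// length of each row (the log-softmax dimension). blockDim.x must be a power
// of two, because the shared-memory reductions halve the stride each step.
__global__ void log_softmax_kernel(const float* __restrict__ input,
                                   float* __restrict__ output,
                                   int dim) {
    // Each block processes one row (sample)
    int row = blockIdx.x;
    int tid = threadIdx.x;
    int blockSize = blockDim.x;

    // Pointers for the current row (64-bit offset avoids int overflow on large tensors)
    const float* in_row = input + static_cast<int64_t>(row) * dim;
    float* out_row = output + static_cast<int64_t>(row) * dim;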

    extern __shared__ float sdata[];
    
    // Phase 1: Compute the maximum value of the row.
    float local_max = -FLT_MAX;
    for (int j = tid; j < dim; j += blockSize) {
        local_max = fmaxf(local_max, in_row[j]);
    }
    sdata[tid] = local_max;
    __syncthreads();
    
    // Reduction for max.
    for (int s = blockSize / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }
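    float max_val = sdata[0];
    // Every thread must read max_val before sdata is reused in Phase 2;
    // otherwise thread 0 can overwrite sdata[0] while slower warps still need it.
    __syncthreads();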
    
    // Phase 2: Compute the sum of exponentials.
    float local_sum = 0.0f;
    for (int j = tid; j < dim; j += blockSize) {
        local_sum += expf(in_row[j] - max_val);
    }
    sdata[tid] = local_sum;
    __syncthreads();
    
    // Reduction for sum.
    for (int s = blockSize / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    float sum_exp = sdata[0];
    float log_sum = logf(sum_exp);
    
    // Phase 3: Compute final log softmax value.
    for (int j = tid; j < dim; j += blockSize) {
        out_row[j] = in_row[j] - max_val - log_sum;
    }
}

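torch::Tensor log_softmax_forward(torch::Tensor input, int dim) {
    // Validate the input before touching raw device pointers.
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32");
    TORCH_CHECK(input.dim() == 2, "Input must be a 2D tensor");

    // For a 2D tensor, dim=-1 refers to the same dimension as dim=1.
    if (dim < 0) {
        dim += input.dim();
    }
    TORCH_CHECK(dim == 1, "This fused log softmax kernel supports only dim=1");

    // Run on the input's device; the kernel indexes rows assuming a contiguous layout.
    c10::cuda::CUDAGuard device_guard(input.device());
    input = input.contiguous();

    const int batch = static_cast<int>(input.size(0));
    const int ndim = static_cast<int>(input.size(1));  // length of the log-softmax dimension

    // Allocate output tensor.
    auto output = torch::empty({batch, ndim}, input.options());
    if (input.numel() == 0) {
        return output;  // nothing to compute; a launch with zero blocks is invalid
    }

    // Configure kernel launch: one block per row.
    const int threads = 256;  // fixed block size; must be a power of two for the reductions
    const int blocks = batch;
    const size_t shared_mem = threads * sizeof(float);

    log_softmax_kernel<<<blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        ndim
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}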

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("log_softmax_forward", &log_softmax_forward, "Fused LogSoftmax forward (CUDA)");
}
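You can emulate the kernel's shared-memory reductions on the host to see what each thread contributes. This is an illustration only: block_reduce_emulated, block_max and block_sum are hypothetical helper names. The loop over tid runs the threads of one block sequentially. Each pass of the halving loop corresponds to one round that the kernel ends with __syncthreads().

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>

// Hypothetical host-side emulation of one block's tree reduction.
// block_size plays the role of blockDim.x and sdata stands in for shared memory.
float block_reduce_emulated(const std::vector<float>& row, int block_size,
                            float identity,
                            const std::function<float(float, float)>& op) {
    // The halving loop only combines every partial when block_size is a power of two.
    assert(block_size > 0 && (block_size & (block_size - 1)) == 0);
    std::vector<float> sdata(block_size, identity);
    // Per-thread strided accumulation: thread tid sees j = tid, tid + block_size, ...
    for (int tid = 0; tid < block_size; ++tid)
        for (std::size_t j = tid; j < row.size(); j += block_size)
            sdata[tid] = op(sdata[tid], row[j]);
    // Tree reduction: halve the number of active "threads" each round.
    for (int s = block_size / 2; s > 0; s >>= 1)
        for (int tid = 0; tid < s; ++tid)
            sdata[tid] = op(sdata[tid], sdata[tid + s]);
    return sdata[0];
}

// The two reductions the kernel performs, with the same identity values.
float block_max(const std::vector<float>& row, int block_size) {
    return block_reduce_emulated(row, block_size, -std::numeric_limits<float>::max(),
                                 [](float a, float b) { return std::fmax(a, b); });
}

float block_sum(const std::vector<float>& row, int block_size) {
    return block_reduce_emulated(row, block_size, 0.0f,
                                 [](float a, float b) { return a + b; });
}
```

The assertion makes explicit an assumption that the kernel relies on silently. With a block size that is not a power of two (say 3), the halving loop would never fold in one of the partial results. The fixed launch of 256 threads therefore matters.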