1.98. Kullback-Leibler Divergence Calculation

Calculates the Kullback-Leibler divergence between two probability distributions. The logarithm is applied to the predicted distribution, which is then compared against the target distribution, and the summed result is averaged over the batch ('batchmean' reduction).

Progress: Step 4 of 4
Reward: 2.36x

Reasoning

We have already optimized the KL divergence computation with a block-level reduction that finishes with warp-level shuffles. This reduces atomic contention and gave a roughly 2.25× speedup over the native PyTorch implementation. One further improvement would be to unroll the grid-stride loop inside the kernel to reduce loop overhead. However, the kernel performs only a few arithmetic operations per element it loads, so memory bandwidth is likely the limiting factor when processing softmax outputs. The current implementation therefore strikes a good balance between efficiency and readability. It minimizes atomic operations by having each block reduce its partial sums in shared memory, finish with a warp-level reduction using __shfl_down_sync, and issue a single atomicAdd per block. The result is the same quantity the PyTorch reference computes (the sum of t · (log t − log p) divided by the batch size), with no approximations or reduced-precision arithmetic. The only difference is that floating-point atomics make the order of the per-block additions, and hence the last bits of the result, nondeterministic. We therefore keep the block-level plus warp-level reduction. The Python module is essentially unchanged apart from importing and calling the compiled CUDA extension. The final code follows.

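#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // CUDA error-checking macros
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>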

// Optimized kernel using block-level reduction with warp-level reduction for the final stage.
// Each thread computes a local sum over a grid-stride loop and then reduces using shared memory.
// The last warp reduction is performed with __shfl_down_sync for minimal overhead.
__global__ void kl_divergence_reduction_kernel(const float* __restrict__ predictions,
                                               const float* __restrict__ targets,
                                               float* __restrict__ loss_accum,
                                               int total_elements) {
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    int idx = blockIdx.x * blockDim.x + tid;
    int stride = blockDim.x * gridDim.x;
    float local_sum = 0.0f;
    
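    // Grid-stride loop: each thread accumulates a partial sum.
    for (int i = idx; i < total_elements; i += stride) {
        float p = predictions[i];
        float t = targets[i];
        // Use the convention 0 * log(0) = 0, as PyTorch's kl_div does;
        // without this guard a zero target would produce 0 * -inf = NaN.
        if (t > 0.0f) {
            local_sum += t * (logf(t) - logf(p));
        }
    }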
    
    sdata[tid] = local_sum;
    __syncthreads();
    
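    // Tree-reduce in shared memory until 32 partial sums remain in sdata[0..31].
    // Assumes blockDim.x is a power of two and at least 32 (the host launches 256 threads).
    for (unsigned int s = blockDim.x / 2; s >= 32; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    
    // Warp-level reduction: the first warp sums sdata[0..31] with shuffles,
    // leaving the block total in lane 0, which adds it to the global result.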
    if (tid < 32) {
        unsigned int mask = 0xffffffff;
        float sum_val = sdata[tid];
        sum_val += __shfl_down_sync(mask, sum_val, 16);
        sum_val += __shfl_down_sync(mask, sum_val, 8);
        sum_val += __shfl_down_sync(mask, sum_val, 4);
        sum_val += __shfl_down_sync(mask, sum_val, 2);
        sum_val += __shfl_down_sync(mask, sum_val, 1);
        if (tid == 0) {
            atomicAdd(loss_accum, sum_val);
        }
    }
}

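torch::Tensor kl_divergence(torch::Tensor predictions, torch::Tensor targets) {
    // Validate inputs: the kernel reads raw float32 pointers on the GPU.
    TORCH_CHECK(predictions.is_cuda() && targets.is_cuda(), "Predictions and targets must be CUDA tensors");
    TORCH_CHECK(predictions.scalar_type() == torch::kFloat32 && targets.scalar_type() == torch::kFloat32,
                "Predictions and targets must be float32 tensors");
    TORCH_CHECK(predictions.sizes() == targets.sizes(), "Predictions and targets must have the same shape");
    TORCH_CHECK(predictions.dim() >= 1, "Inputs need a batch dimension for 'batchmean' reduction");

    // Ensure the kernel runs on the inputs' device.
    c10::cuda::CUDAGuard device_guard(predictions.device());

    // The kernel indexes the data linearly, so it needs contiguous memory.
    auto preds = predictions.contiguous();
    auto tgts = targets.contiguous();
    int total_elements = preds.numel();

    // Create a scalar (0-dimensional) tensor to accumulate the loss.
    auto loss_tensor = torch::zeros({}, preds.options());

    int threads = 256;
    int blocks = (total_elements + threads - 1) / threads;

    // A zero-block launch is an invalid configuration, so skip empty inputs.
    if (blocks > 0) {
        // Launch the kernel with shared memory allocation (threads * sizeof(float)).
        kl_divergence_reduction_kernel<<<blocks, threads, threads * sizeof(float), c10::cuda::getCurrentCUDAStream()>>>(
            preds.data_ptr<float>(),
            tgts.data_ptr<float>(),
            loss_tensor.data_ptr<float>(),
            total_elements
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // 'batchmean' reduction: divide the summed loss by the batch size (dim 0).
    loss_tensor.div_(preds.size(0));

    return loss_tensor;
}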

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("kl_divergence", &kl_divergence, "Optimized KL Divergence loss using block and warp-level reduction");
}