1.97. Cosine Similarity Loss Calculation

Calculates the cosine similarity between corresponding vectors of two inputs along a specified dimension, and computes the loss as the mean of (1 - cosine similarity) across the batch.

Progress

Step 2 of 2
Reward
3.01x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize the cosine similarity loss kernel by replacing the shared-memory reduction loop with a warp-level reduction using __shfl_down_sync. The previous version summed the partial dot products and squared norms through shared memory. Warp-level shuffle primitives reduce synchronization overhead and the number of memory accesses. Each thread computes local partial sums for the dot product and for the two squared norms, and each warp then combines those values with a warp-synchronous shuffle reduction. The warp leaders (lane 0 of each warp) store their results in shared memory, and the first warp performs a final warp-level reduction over these per-warp values to produce the block's sums. This shorter reduction reduces latency and should yield a further speedup. The Python module remains unchanged because it already loads the CUDA extension with appropriate optimization flags; only the CUDA kernel is updated. Before finalizing, I ask myself: are these improvements genuine? Yes – warp-level reductions are a recognized best practice for speeding up GPU reductions, and they avoid some shared-memory overhead.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Sums `v` across the 32 lanes of the calling warp with __shfl_down_sync; the
// complete total is valid in lane 0 only (other lanes hold partial sums).
// Precondition: all 32 lanes of the warp are active, because the full mask is used.
static __device__ __forceinline__ float cosine_loss_warp_sum(float v) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(0xffffffffu, v, offset);
    }
    return v;
}

// Kernel to compute cosine similarity loss per sample using warp-level reduction.
//
// Launch layout: one block per sample (blockIdx.x selects the sample). blockDim.x
// must be a multiple of 32 and at most 1024. Every shuffle uses the full warp mask,
// and the per-warp partials are staged in 32-entry shared arrays.
// Shared memory: 3 x 32 floats, statically allocated (no dynamic shared memory).
// Input layout: `predictions` and `targets` are read as row-major, densely packed
// [num_samples, dim] float arrays; sample s starts at element offset s * dim.
// Output: atomically adds (1 - cosine) for this sample into *loss_accum, so the
// caller must zero *loss_accum before launch. The order in which blocks add is
// unspecified, so the low-order bits of the float sum may vary between runs.
// NOTE(review): eps is added to the product of the norms (not used as a clamp);
// confirm this matches the reference implementation being reproduced.
__global__ void cosine_similarity_loss_kernel(const float* __restrict__ predictions,
                                              const float* __restrict__ targets,
                                              int dim,
                                              float* __restrict__ loss_accum) {
    const int sample = blockIdx.x;
    const int tid = threadIdx.x;
    const int lane = tid & 31;                      // lane index within the warp
    const int warp_id = tid >> 5;                   // warp index within the block
    const int num_warps = (blockDim.x + 31) >> 5;

    // 64-bit row offset: sample * dim in int overflows once batch * dim >= 2^31.
    const size_t row_offset = static_cast<size_t>(sample) * static_cast<size_t>(dim);
    const float* pred_sample = predictions + row_offset;
    const float* target_sample = targets + row_offset;

    // Block-strided pass over the row: consecutive threads read consecutive
    // elements, so each warp's loads are contiguous.
    float local_dot = 0.0f;
    float local_norm_pred = 0.0f;
    float local_norm_target = 0.0f;
    for (int i = tid; i < dim; i += blockDim.x) {
        const float p = pred_sample[i];
        const float t = target_sample[i];
        local_dot += p * t;
        local_norm_pred += p * p;
        local_norm_target += t * t;
    }

    // Stage 1: reduce each warp's partial sums into its lane 0.
    local_dot = cosine_loss_warp_sum(local_dot);
    local_norm_pred = cosine_loss_warp_sum(local_norm_pred);
    local_norm_target = cosine_loss_warp_sum(local_norm_target);

    // One slot per warp; 32 slots cover the 1024-thread maximum block size.
    __shared__ float warp_dot[32];
    __shared__ float warp_norm_pred[32];
    __shared__ float warp_norm_target[32];

    if (lane == 0) {
        warp_dot[warp_id] = local_dot;
        warp_norm_pred[warp_id] = local_norm_pred;
        warp_norm_target[warp_id] = local_norm_target;
    }

    // Make every warp's partials visible to the first warp (reached by all threads).
    __syncthreads();

    // Stage 2: the first warp reduces the per-warp partials. Lanes beyond the number
    // of warps contribute 0, since their shared slots were never written.
    if (warp_id == 0) {
        float sum_dot = (lane < num_warps) ? warp_dot[lane] : 0.0f;
        float sum_norm_pred = (lane < num_warps) ? warp_norm_pred[lane] : 0.0f;
        float sum_norm_target = (lane < num_warps) ? warp_norm_target[lane] : 0.0f;

        sum_dot = cosine_loss_warp_sum(sum_dot);
        sum_norm_pred = cosine_loss_warp_sum(sum_norm_pred);
        sum_norm_target = cosine_loss_warp_sum(sum_norm_target);

        // Thread 0 writes this sample's loss contribution.
        if (lane == 0) {
            const float eps = 1e-8f;
            const float norm_pred_val = sqrtf(sum_norm_pred);
            const float norm_target_val = sqrtf(sum_norm_target);
            const float cosine = sum_dot / ((norm_pred_val * norm_target_val) + eps);
            atomicAdd(loss_accum, 1.0f - cosine);
        }
    }
}

// Host function wrapping the CUDA kernel.
//
// Computes the mean over the batch of (1 - cosine(predictions[b], targets[b])), with
// the cosine taken along dim 1 of two 2D float32 CUDA tensors of identical shape.
// Returns a 0-dim float32 tensor on the inputs' device.
//
// The kernel is enqueued on the current CUDA stream, and the call returns without
// blocking the host. The division below runs on the same stream, so it is ordered
// after the kernel.
// Raises (TORCH_CHECK) on non-CUDA, non-float32, cross-device, non-2D or
// shape-mismatched inputs, and if the kernel launch reports an error.
// An empty batch yields NaN (0 / 0), since no sample contributes to the sum.
torch::Tensor cosine_similarity_loss(torch::Tensor predictions, torch::Tensor targets) {
    TORCH_CHECK(predictions.is_cuda(), "predictions must be a CUDA tensor");
    TORCH_CHECK(targets.is_cuda(), "targets must be a CUDA tensor");
    TORCH_CHECK(predictions.device() == targets.device(),
                "predictions and targets must be on the same device");
    TORCH_CHECK(predictions.scalar_type() == torch::kFloat, "predictions must be float32");
    TORCH_CHECK(targets.scalar_type() == torch::kFloat, "targets must be float32");

    // Ensure the operation is done on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(predictions.device());

    // Check that predictions and targets are 2D and have the same shape.
    TORCH_CHECK(predictions.dim() == 2, "predictions must be a 2D tensor");
    TORCH_CHECK(targets.dim() == 2, "targets must be a 2D tensor");
    int batch_size = predictions.size(0);
    int dim = predictions.size(1);
    TORCH_CHECK(targets.size(0) == batch_size && targets.size(1) == dim, "Mismatch in dimensions");

    // The kernel addresses row s at element offset s * dim, so it needs densely packed
    // row-major storage. .contiguous() returns the tensor itself when it already is.
    auto pred_contig = predictions.contiguous();
    auto target_contig = targets.contiguous();

    // 0-dim accumulator; the kernel atomically adds each sample's loss into it.
    auto loss_tensor = torch::zeros({}, predictions.options());

    // A zero-block grid is an invalid launch configuration, so skip the launch for an
    // empty batch.
    if (batch_size > 0) {
        const int threads = 256;  // kernel requires a multiple of 32, at most 1024
        const int blocks = batch_size;  // one block per sample
        cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

        cosine_similarity_loss_kernel<<<blocks, threads, 0, stream>>>(
            pred_contig.data_ptr<float>(),
            target_contig.data_ptr<float>(),
            dim,
            loss_tensor.data_ptr<float>()
        );

        // Surface launch-configuration errors now rather than at some unrelated later call.
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess,
                    "cosine_similarity_loss_kernel launch failed: ",
                    cudaGetErrorString(launch_err));
    }

    // Mean over the batch. This is stream-ordered after the kernel, so no host sync is needed.
    loss_tensor.div_(batch_size);
    return loss_tensor;
}

// Python bindings: exposes the host wrapper under the name "cosine_similarity_loss".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def(
        "cosine_similarity_loss",
        &cosine_similarity_loss,
        "Compute cosine similarity loss using a fused CUDA kernel with warp-level reduction");
}