1.99. Triplet Margin Loss Computation

Calculates the Triplet Margin Loss between an anchor, a positive, and a negative input, using a specified margin. The loss compares the anchor-positive distance with the anchor-negative distance and penalizes any triplet in which the anchor-negative distance does not exceed the anchor-positive distance by at least the given margin.
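
Written out for the Euclidean (p = 2) case that the kernel in this section implements, the per-sample loss and the batch mean take the form below. The symbols a_i, p_i, n_i (anchor, positive, negative rows), m (margin), and N (batch size) are notation introduced here for illustration; the kernel additionally adds a small epsilon under the square root when computing each distance.

```latex
\ell_i = \max\bigl(d(a_i, p_i) - d(a_i, n_i) + m,\; 0\bigr), \qquad
d(x, y) = \lVert x - y \rVert_2, \qquad
L = \frac{1}{N} \sum_{i=1}^{N} \ell_i
```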

Progress

Step 3 of 3
Reward: 3.89x

Step 3 Details

Reasoning

This step targets the shared-memory traffic and synchronization overhead of the reduction. The previous kernel used a loop-based shared-memory reduction, which works but can be replaced by a warp-level reduction built on the CUDA shuffle primitive __shfl_down_sync. Each thread first computes a partial sum over a strided slice of the feature dimension, and each warp then reduces its partial sums with __shfl_down_sync, which needs neither shared memory nor a block-wide barrier. Each warp leader writes its result to shared memory, and thread 0 combines the per-warp sums. The kernel is left with a single __syncthreads() call, and shared memory holds only one value per warp for each of the two distances. The rest of the algorithm is unchanged, and the output is still averaged over the batch. The change is intended to speed up the reduction without altering the result.

Fix

The first version of this kernel failed the output check: its result differed numerically from the reference implementation. That version also used a shuffle reduction for the final aggregation of the per-warp results, and the mismatch was attributed to that stage. The final step is therefore simplified. After each warp computes its partial sum via shuffle reduction, the warp leader writes the result to shared memory, and thread 0 sums the per-warp results in a plain loop. This avoids applying shuffle operations to values that were computed in separate warps. The change is expected to bring the output into closer agreement with the reference. The remainder of the kernel is unchanged, and the Python side requires no changes.
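
The two-stage reduction described above (a shuffle reduction inside each warp, then a serial sum of the per-warp results by thread 0) can be sketched on the host to make the data flow concrete. This is a minimal CPU simulation, not device code: the names `warp_reduce_sim` and `block_reduce_sim` are hypothetical, the inner loop only imitates what `__shfl_down_sync` does across 32 lanes, and the block size is assumed to be a multiple of 32.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kWarpSize = 32;  // mirrors WARP_SIZE in the kernel

// Stage 1: imitate the __shfl_down_sync loop for one warp.
// A lane whose source lane would fall outside the warp receives its own
// value, as the device intrinsic does. Only lane 0 ends up with the full sum.
float warp_reduce_sim(std::vector<float> lanes) {
    assert(lanes.size() == static_cast<std::size_t>(kWarpSize));
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        const std::vector<float> prev = lanes;  // every lane reads before any lane writes
        for (int lane = 0; lane < kWarpSize; ++lane) {
            const int src = lane + offset;
            const float other = (src < kWarpSize) ? prev[src] : prev[lane];
            lanes[lane] = prev[lane] + other;
        }
    }
    return lanes[0];
}

// Stage 2: each warp leader stores its sum (the shared-memory array in the
// kernel), then a single thread adds the per-warp sums in order.
float block_reduce_sim(const std::vector<float>& per_thread) {
    assert(per_thread.size() % kWarpSize == 0);  // assumption: only full warps
    const std::size_t num_warps = per_thread.size() / kWarpSize;
    std::vector<float> warp_sums(num_warps);
    for (std::size_t w = 0; w < num_warps; ++w) {
        const std::vector<float> lanes(per_thread.begin() + w * kWarpSize,
                                       per_thread.begin() + (w + 1) * kWarpSize);
        warp_sums[w] = warp_reduce_sim(lanes);
    }
    float total = 0.0f;
    for (float s : warp_sums) {
        total += s;
    }
    return total;
}
```

Because floating-point addition is not associative, the tree-shaped order in stage 1 and the serial order in stage 2 can round differently from a reference that sums in another order, so comparisons against a reference are usually made with a tolerance rather than exact equality.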

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // declares c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

#ifndef WARP_SIZE
#define WARP_SIZE 32
#endif

// Kernel to compute Triplet Margin Loss for one sample using simplified final reduction.
// Each block handles one sample. Each thread computes a partial sum over feature dimensions.
// Warp-level reduction is done using shuffle; then each warp leader writes its sum to shared memory.
// Finally, thread 0 aggregates the per-warp sums with a simple loop.
__global__ void triplet_loss_kernel(const float* __restrict__ anchor,
                                    const float* __restrict__ positive,
                                    const float* __restrict__ negative,
                                    float* __restrict__ loss_out,  // Array of per-sample losses.
                                    int batch, int D, float margin) {
    int sample = blockIdx.x;  // one block per sample.
    int tid = threadIdx.x;
    int blockSize = blockDim.x;

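    // Pointers for the current sample. The row offset is computed in size_t so
    // that sample * D cannot overflow int for large inputs.
    const size_t base = static_cast<size_t>(sample) * static_cast<size_t>(D);
    const float* anchor_ptr   = anchor + base;
    const float* positive_ptr = positive + base;
    const float* negative_ptr = negative + base;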

    float local_pos = 0.0f;
    float local_neg = 0.0f;
    
    // Each thread accumulates over a strided range in the feature dimension.
    for (int j = tid; j < D; j += blockSize) {
        float diff_pos = anchor_ptr[j] - positive_ptr[j];
        float diff_neg = anchor_ptr[j] - negative_ptr[j];
        local_pos += diff_pos * diff_pos;
        local_neg += diff_neg * diff_neg;
    }
    
    // Perform warp-level reduction using shuffle. The full mask assumes every
    // warp is complete, i.e. blockDim.x is a multiple of WARP_SIZE.
    unsigned int mask = 0xFFFFFFFF;
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        local_pos += __shfl_down_sync(mask, local_pos, offset);
        local_neg += __shfl_down_sync(mask, local_neg, offset);
    }
    
    // Each warp leader writes its result to shared memory.
    extern __shared__ float shared[]; 
    // Shared memory is divided into two parts: first for pos, second for neg.
    int numWarps = blockSize / WARP_SIZE;
    float* warpPos = shared;                   // size numWarps
    float* warpNeg = shared + numWarps;          // size numWarps
    
    if ((tid & (WARP_SIZE - 1)) == 0) {
        int warpId = tid / WARP_SIZE;
        warpPos[warpId] = local_pos;
        warpNeg[warpId] = local_neg;
    }
    __syncthreads();
    
    // Thread 0 aggregates the per-warp sums.
    if (tid == 0) {
        float sum_pos = 0.0f;
        float sum_neg = 0.0f;
        for (int i = 0; i < numWarps; i++) {
            sum_pos += warpPos[i];
            sum_neg += warpNeg[i];
        }
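        // eps is added to the squared sum before the square root. Note that
        // PyTorch's pairwise_distance instead adds eps to each elementwise
        // difference, so tiny deviations from that reference are expected.
        float eps = 1e-6f;
        float d_pos = sqrtf(sum_pos + eps);
        float d_neg = sqrtf(sum_neg + eps);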
        float sample_loss = fmaxf(d_pos - d_neg + margin, 0.0f);
        loss_out[sample] = sample_loss;
    }
}

// Forward function: launches one block per sample and computes the mean loss.
torch::Tensor triplet_loss(torch::Tensor anchor, torch::Tensor positive, torch::Tensor negative, float margin) {
    // Ensure operation on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(anchor.device());
    
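    TORCH_CHECK(anchor.is_cuda() && positive.is_cuda() && negative.is_cuda(),
                "Anchor, positive, and negative must be CUDA tensors");
    TORCH_CHECK(anchor.dim() == 2, "Anchor must be a 2D tensor [batch, features]");
    TORCH_CHECK(positive.dim() == 2, "Positive must be a 2D tensor [batch, features]");
    TORCH_CHECK(negative.dim() == 2, "Negative must be a 2D tensor [batch, features]");
    TORCH_CHECK(positive.sizes() == anchor.sizes() && negative.sizes() == anchor.sizes(),
                "Anchor, positive, and negative must have the same shape");
    TORCH_CHECK(anchor.scalar_type() == torch::kFloat32 &&
                positive.scalar_type() == torch::kFloat32 &&
                negative.scalar_type() == torch::kFloat32,
                "Anchor, positive, and negative must be float32 tensors");

    // The kernel indexes each row as ptr + sample * D, which requires contiguous storage.
    anchor = anchor.contiguous();
    positive = positive.contiguous();
    negative = negative.contiguous();

    int batch = static_cast<int>(anchor.size(0));
    int D = static_cast<int>(anchor.size(1));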
    
    // Allocate an array for per-sample losses.
    auto loss_array = torch::empty({batch}, anchor.options());
    
    int threads = 256;
    int blocks = batch;  // one block per sample.
    // Shared memory for per-warp sums: two arrays of size (threads / WARP_SIZE).
    size_t shared_mem_size = 2 * (threads / WARP_SIZE) * sizeof(float);
    
    triplet_loss_kernel<<<blocks, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        anchor.data_ptr<float>(),
        positive.data_ptr<float>(),
        negative.data_ptr<float>(),
        loss_array.data_ptr<float>(),
        batch,
        D,
        margin
    );
    
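    // Surface launch errors (for example, an empty batch gives a zero-sized grid) immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "triplet_loss_kernel launch failed: ", cudaGetErrorString(err));

    // Return the mean loss over the batch.
    return loss_array.mean();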
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("triplet_loss", &triplet_loss, "Fused Triplet Margin Loss (CUDA) with simplified final reduction");
}
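
When checking the kernel's output, a plain CPU reference that follows the same arithmetic can help separate reduction-order effects from real bugs. The sketch below is one such reference under stated assumptions: the name `triplet_loss_reference` is hypothetical, inputs are row-major [batch, D] arrays, epsilon is added under the square root exactly as in the kernel above, and accumulation is done in double precision to keep the reference's own rounding error small.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference for the mean triplet margin loss with Euclidean distance.
// Assumption: anchor, positive, and negative each hold batch * D floats in
// row-major order, matching the layout the kernel expects.
float triplet_loss_reference(const std::vector<float>& anchor,
                             const std::vector<float>& positive,
                             const std::vector<float>& negative,
                             std::size_t batch, std::size_t D, float margin) {
    assert(batch > 0);
    assert(anchor.size() == batch * D);
    assert(positive.size() == batch * D);
    assert(negative.size() == batch * D);

    const double eps = 1e-6;  // same value and placement as in the kernel
    double total = 0.0;
    for (std::size_t b = 0; b < batch; ++b) {
        double sum_pos = 0.0;
        double sum_neg = 0.0;
        for (std::size_t j = 0; j < D; ++j) {
            const double a = anchor[b * D + j];
            const double diff_pos = a - positive[b * D + j];
            const double diff_neg = a - negative[b * D + j];
            sum_pos += diff_pos * diff_pos;
            sum_neg += diff_neg * diff_neg;
        }
        const double d_pos = std::sqrt(sum_pos + eps);
        const double d_neg = std::sqrt(sum_neg + eps);
        total += std::max(d_pos - d_neg + static_cast<double>(margin), 0.0);
    }
    return static_cast<float>(total / static_cast<double>(batch));
}
```

A typical use is to copy the kernel's scalar result back to the host and compare it with this value using a small absolute or relative tolerance, since the kernel accumulates in single precision and in a different order.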