1.94. Mean Squared Error Calculation

Calculates the mean squared error loss by computing the average of the squared differences between predictions and target values.
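
Concretely, for predictions p and targets t with N elements each, the kernel evaluates the standard MSE:

\mathrm{MSE}(p, t) = \frac{1}{N} \sum_{i=1}^{N} \left(p_i - t_i\right)^2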

Progress: Step 4 of 4
Reward: 1.11x

Optimization Timeline: Step 4 Details

Reasoning

We further optimize the MSE loss kernel by annotating the grid-stride loop in both the scalar and vectorized kernels with the #pragma unroll directive. In addition, we increase the number of threads per block from 256 to 512, which launches fewer blocks and therefore issues fewer atomicAdds while improving reduction efficiency. These changes reduce loop overhead and cut the total number of atomic operations, further improving performance over the native version. This is a genuine optimization in that it directly reduces runtime by improving memory throughput and reducing synchronization overhead.
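
As a concrete illustration of the atomicAdd savings (assuming a hypothetical input of 2^20 = 1,048,576 float elements, which the vectorized kernel processes as 262,144 float4 loads): each block contributes exactly one atomicAdd, so a 256-thread launch uses 262,144 / 256 = 1,024 blocks and 1,024 atomicAdds, while the 512-thread launch uses only 512 blocks and 512 atomicAdds.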

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

// Optimized scalar kernel for computing MSE loss using warp-level reduction and unrolled loop.
__global__ void mse_loss_kernel(const float* __restrict__ predictions,
                                const float* __restrict__ targets,
                                float* __restrict__ loss,
                                int num_elements) {
    float sum = 0.0f;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    
    // Unrolled grid-stride loop.
    #pragma unroll
    for (int i = idx; i < num_elements; i += stride) {
        float diff = predictions[i] - targets[i];
        sum += diff * diff;
    }
    
    // Warp-level reduction using shuffle intrinsics.
    unsigned int lane = threadIdx.x & (warpSize - 1);
    for (int offset = warpSize/2; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    
    // Use shared memory reduction: each warp's lane 0 writes its result.
    __shared__ float shared[32]; // Assuming blockDim.x <= 1024.
    int warp_id = threadIdx.x / warpSize;
    if (lane == 0) {
        shared[warp_id] = sum;
    }
    __syncthreads();
    
    // Let the first warp reduce the per-warp partial sums. All 32 lanes of warp 0
    // participate so the full-mask shuffle is well defined; lanes beyond the
    // number of warps contribute 0.
    if (threadIdx.x < warpSize) {
        unsigned int num_warps = (blockDim.x + warpSize - 1) / warpSize;
        float block_sum = (lane < num_warps) ? shared[lane] : 0.0f;
        for (int offset = warpSize/2; offset > 0; offset /= 2) {
            block_sum += __shfl_down_sync(0xffffffff, block_sum, offset);
        }
        if (lane == 0) {
            atomicAdd(loss, block_sum);
        }
    }
}

// Optimized vectorized kernel for MSE loss using float4 with loop unrolling.
__global__ void mse_loss_vec_kernel(const float4* __restrict__ predictions,
                                    const float4* __restrict__ targets,
                                    float* __restrict__ loss,
                                    int num_vec) {
    float sum = 0.0f;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    
    // Unrolled loop over vectorized elements.
    #pragma unroll
    for (int i = idx; i < num_vec; i += stride) {
        float4 p = predictions[i];
        float4 t = targets[i];
        float d0 = p.x - t.x;
        float d1 = p.y - t.y;
        float d2 = p.z - t.z;
        float d3 = p.w - t.w;
        sum += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
    }
    
    // Warp-level reduction.
    unsigned int lane = threadIdx.x & (warpSize - 1);
    for (int offset = warpSize/2; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    
    __shared__ float shared[32];
    int warp_id = threadIdx.x / warpSize;
    if (lane == 0) {
        shared[warp_id] = sum;
    }
    __syncthreads();
    
    // First warp reduces the per-warp partial sums; all 32 lanes participate
    // in the full-mask shuffle, with unused lanes contributing 0.
    if (threadIdx.x < warpSize) {
        unsigned int num_warps = (blockDim.x + warpSize - 1) / warpSize;
        float block_sum = (lane < num_warps) ? shared[lane] : 0.0f;
        for (int offset = warpSize/2; offset > 0; offset /= 2) {
            block_sum += __shfl_down_sync(0xffffffff, block_sum, offset);
        }
        if (lane == 0) {
            atomicAdd(loss, block_sum);
        }
    }
}

torch::Tensor mse_loss(torch::Tensor predictions, torch::Tensor targets) {
    TORCH_CHECK(predictions.sizes() == targets.sizes(), "Predictions and targets must have the same shape.");
    TORCH_CHECK(predictions.is_cuda() && targets.is_cuda(), "Predictions and targets must be CUDA tensors.");
    TORCH_CHECK(predictions.is_contiguous() && targets.is_contiguous(), "Predictions and targets must be contiguous.");
    TORCH_CHECK(predictions.scalar_type() == torch::kFloat32 && targets.scalar_type() == torch::kFloat32,
                "Only float32 tensors are supported.");
    
    // Ensure operations execute on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(predictions.device());
    
    int num_elements = predictions.numel();
    auto loss_tensor = torch::zeros({1}, predictions.options());
    
    // Increase threads per block to 512.
    int threads = 512;
    
    // Check if we can use the vectorized kernel (when num_elements is divisible by 4).
    if (num_elements % 4 == 0) {
        int num_vec = num_elements / 4;
        int blocks = (num_vec + threads - 1) / threads;
        mse_loss_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            reinterpret_cast<const float4*>(predictions.data_ptr<float>()),
            reinterpret_cast<const float4*>(targets.data_ptr<float>()),
            loss_tensor.data_ptr<float>(),
            num_vec
        );
    } else {
        int blocks = (num_elements + threads - 1) / threads;
        // The kernel only uses its fixed 32-float static shared array, so no
        // dynamic shared memory needs to be requested at launch.
        mse_loss_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            predictions.data_ptr<float>(),
            targets.data_ptr<float>(),
            loss_tensor.data_ptr<float>(),
            num_elements
        );
    }
    
    // item() synchronizes with the device and copies the accumulated sum to the host.
    float sum_loss = loss_tensor.item<float>();
    float mse = sum_loss / static_cast<float>(num_elements);
    // Return the mean as a zero-dimensional tensor on the same device as the inputs.
    return torch::full({}, mse, predictions.options());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("mse_loss", &mse_loss, "Compute the mean squared error loss using an optimized CUDA kernel with warp-level reduction, vectorized loads, and loop unrolling");
}
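
Below is a minimal sketch of how the kernel could be sanity-checked against the built-in ATen implementation from C++. The standalone main(), the tensor shapes, and the tolerances are illustrative assumptions; it presumes mse_loss() above is compiled into the same libtorch program rather than loaded through the Python extension mechanism.

// Hypothetical standalone check; assumes mse_loss() above is linked into this libtorch program.
#include <torch/torch.h>
#include <iostream>

int main() {
    torch::manual_seed(0);
    auto opts  = torch::device(torch::kCUDA).dtype(torch::kFloat32);
    auto preds = torch::randn({128, 4096}, opts);
    auto targs = torch::randn({128, 4096}, opts);

    auto ours = mse_loss(preds, targs);                        // custom CUDA kernel above
    auto ref  = torch::nn::functional::mse_loss(preds, targs); // ATen reference (mean reduction by default)

    std::cout << "custom: " << ours.item<float>()
              << "  reference: " << ref.item<float>() << std::endl;
    TORCH_CHECK(torch::allclose(ours, ref, /*rtol=*/1e-4, /*atol=*/1e-5),
                "Custom MSE kernel disagrees with the ATen reference.");
    return 0;
}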