1.96. Smooth L1 (Huber) Loss Computation
Computes the smooth L1 (Huber) loss between prediction values and target values for regression tasks. The loss is quadratic (L2-like) for small errors and linear (L1-like) for large ones, which limits the influence of outliers.
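As a plain reference for the formula described above, here is a minimal CPU sketch in C++. It assumes the fixed threshold of 1.0 used by the CUDA kernel on this page. The function names `smooth_l1_elem` and `smooth_l1_mean` are hypothetical helpers, not part of the extension. A reference like this can be used to check the GPU result on small inputs.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the per-element loss (threshold 1.0):
// quadratic for |diff| < 1, linear otherwise. The two branches meet at
// |diff| = 1 with value 0.5 and slope 1, so the loss is smooth there.
float smooth_l1_elem(float pred, float target) {
    float diff = pred - target;
    float abs_diff = std::fabs(diff);
    return abs_diff < 1.0f ? 0.5f * diff * diff : abs_diff - 0.5f;
}

// Mean reduction over all elements, i.e. sum of element losses / element count.
float smooth_l1_mean(const std::vector<float>& pred, const std::vector<float>& target) {
    assert(pred.size() == target.size() && !pred.empty());
    double sum = 0.0;  // accumulate in double to reduce rounding error in the reference
    for (std::size_t i = 0; i < pred.size(); ++i) {
        sum += smooth_l1_elem(pred[i], target[i]);
    }
    return static_cast<float>(sum / static_cast<double>(pred.size()));
}
```

For predictions {0.5, 3.0, -2.0, 1.0} and targets {0, 0, 0, 1}, the per-element losses are 0.125, 2.5, 1.5 and 0, so `smooth_l1_mean` returns 1.03125.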
Step 2 Details
Reasoning
We can further optimize the Smooth L1 Loss kernel by replacing the full shared-memory reduction loop with a warp-level reduction. In the updated CUDA kernel, each thread accumulates a local sum over a grid-stride loop and then reduces it within its warp using the __shfl_down_sync intrinsic. Lane 0 of each warp writes the warp's result into a small shared-memory array, and the first warp then performs a second shuffle reduction over these per-warp sums. This cuts block-wide synchronization down to a single __syncthreads() and shrinks the shared-memory requirement to 32 floats per block. The rest of the logic is unchanged. The host function also squeezes the one-element accumulator so that the returned loss is a 0-dimensional (scalar) tensor. These changes should reduce latency by moving most of the reduction into register-to-register warp shuffles, although the actual gain needs to be confirmed by profiling.
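The two-stage reduction described above can be emulated on the CPU to see why lane 0 ends up holding the total. The sketch below is a hypothetical C++ model, not CUDA code: `warp_reduce_emulated` mimics the offset 16, 8, 4, 2, 1 sequence of `__shfl_down_sync` calls on one 32-lane warp, and `block_reduce_emulated` adds the second pass over the per-warp sums. Both names are illustrative, and the model assumes full warps and at most 32 warps (1024 threads) per block.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kWarpSize = 32;

// Hypothetical CPU emulation of one warp executing, for offset = 16, 8, 4, 2, 1:
//   v += __shfl_down_sync(0xffffffff, v, offset);
// Lane l reads the value of lane l + offset; if l + offset >= 32 it gets its
// own value back, as __shfl_down_sync does for out-of-range source lanes.
float warp_reduce_emulated(std::array<float, kWarpSize> v) {
    for (int offset = 16; offset > 0; offset /= 2) {
        std::array<float, kWarpSize> shuffled{};
        for (int lane = 0; lane < kWarpSize; ++lane) {
            int src = lane + offset;
            shuffled[lane] = (src < kWarpSize) ? v[src] : v[lane];
        }
        // Lanes exchange values simultaneously, so add from the snapshot.
        for (int lane = 0; lane < kWarpSize; ++lane) {
            v[lane] += shuffled[lane];
        }
    }
    return v[0];  // only lane 0 is guaranteed to hold the full warp sum
}

// Two-stage block reduction: each warp reduces its 32 values, lane 0 stores the
// result in warp_sums, then warp 0 reduces warp_sums. Slots past the number of
// warps stay 0, so they do not change the block total.
float block_reduce_emulated(const std::vector<float>& thread_vals) {
    assert(thread_vals.size() % kWarpSize == 0);  // full warps only
    const std::size_t num_warps = thread_vals.size() / kWarpSize;
    assert(num_warps <= static_cast<std::size_t>(kWarpSize));  // at most 1024 threads
    std::array<float, kWarpSize> warp_sums{};  // zero-initialized
    for (std::size_t w = 0; w < num_warps; ++w) {
        std::array<float, kWarpSize> lanes{};
        for (int l = 0; l < kWarpSize; ++l) {
            lanes[l] = thread_vals[w * kWarpSize + l];
        }
        warp_sums[w] = warp_reduce_emulated(lanes);
    }
    return warp_reduce_emulated(warp_sums);
}
```

With 256 threads each holding 1.0, `block_reduce_emulated` returns 256: the eight per-warp sums occupy slots 0 to 7 and the remaining 24 slots contribute 0, mirroring lanes beyond the warp count contributing 0 in the second stage.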
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
// Optimized CUDA kernel to compute Smooth L1 Loss with warp-level reduction.
//
// Each thread computes a local sum using a grid-stride loop:
// For each element, compute diff = predictions[i]-targets[i], then
// loss = 0.5*(diff^2) if fabs(diff) < 1 else fabs(diff) - 0.5.
// Then a warp-level reduction is performed using __shfl_down_sync.
// The first thread of each warp writes its sum to a shared array,
// and then the first warp reduces these values. Finally, the first thread
// atomically adds the block sum into the global accumulator.
__global__ void smooth_l1_loss_kernel(const float* __restrict__ predictions,
                                      const float* __restrict__ targets,
                                      float* __restrict__ loss_sum,
                                      int total) {
    float local_sum = 0.0f;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = idx; i < total; i += stride) {
        float diff = predictions[i] - targets[i];
        float abs_diff = fabsf(diff);
        float loss = (abs_diff < 1.0f) ? 0.5f * diff * diff : (abs_diff - 0.5f);
        local_sum += loss;
    }

    // Warp-level reduction. Assumes blockDim.x is a multiple of 32, so every
    // warp is full and the 0xffffffff mask names only active lanes.
    unsigned int lane = threadIdx.x & 31;  // lane index within the warp [0, 31]
    for (int offset = 16; offset > 0; offset /= 2) {
        local_sum += __shfl_down_sync(0xffffffff, local_sum, offset);
    }

    // Lane 0 of each warp stores its warp's sum in shared memory.
    __shared__ float warp_sums[32];  // enough for up to 1024 threads per block (32 warps)
    int warpId = threadIdx.x >> 5;   // divide by 32
    if (lane == 0) {
        warp_sums[warpId] = local_sum;
    }
    __syncthreads();

    // The whole first warp reduces the per-warp sums. All 32 lanes must execute
    // the shuffle because it uses the full mask; lanes at or beyond the number
    // of warps contribute 0.
    if (warpId == 0) {
        unsigned int num_warps = blockDim.x >> 5;
        float block_sum = (lane < num_warps) ? warp_sums[lane] : 0.0f;
        for (int offset = 16; offset > 0; offset /= 2) {
            block_sum += __shfl_down_sync(0xffffffff, block_sum, offset);
        }
        // Thread 0 of the block adds the block sum into the global accumulator.
        if (lane == 0) {
            atomicAdd(loss_sum, block_sum);
        }
    }
}
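torch::Tensor forward(torch::Tensor predictions, torch::Tensor targets) {
    TORCH_CHECK(predictions.is_cuda() && targets.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(predictions.scalar_type() == torch::kFloat32 &&
                targets.scalar_type() == torch::kFloat32, "inputs must be float32");
    TORCH_CHECK(predictions.sizes() == targets.sizes(), "predictions and targets must have the same shape");

    // Run on the inputs' CUDA device.
    c10::cuda::CUDAGuard device_guard(predictions.device());

    // Flatten the input tensors (contiguous() copies only if needed).
    predictions = predictions.contiguous().view(-1);
    targets = targets.contiguous().view(-1);
    int total = predictions.size(0);

    // One-element global accumulator, zero-initialized.
    auto loss_sum = torch::zeros({1}, predictions.options());

    const int threads = 256;  // must be a multiple of 32 (see kernel)
    int blocks = (total + threads - 1) / threads;
    if (blocks == 0) {
        blocks = 1;  // empty input: a zero-block launch is invalid; the mean becomes 0/0 = NaN
    }

    // warp_sums is statically allocated in the kernel, so no dynamic shared memory is needed.
    smooth_l1_loss_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        predictions.data_ptr<float>(),
        targets.data_ptr<float>(),
        loss_sum.data_ptr<float>(),
        total
    );

    // Compute mean loss: loss_sum / total.
    auto mean_loss = loss_sum / total;
    // Squeeze the shape-[1] tensor to a 0-dimensional (scalar) tensor.
    return mean_loss.squeeze();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused Smooth L1 (Huber) Loss computation with warp-level reduction");
}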
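Because the kernel uses a grid-stride loop, correctness does not depend on launching one thread per element: any positive grid visits every index exactly once. The CPU sketch below illustrates this. `grid_stride_visits` and `max_elements_per_thread` are hypothetical helpers that only emulate the kernel's index arithmetic, not CUDA execution.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU emulation of the kernel's grid-stride loop:
//   for (i = blockIdx.x * blockDim.x + threadIdx.x; i < total; i += blockDim.x * gridDim.x)
// Returns how many times each element index is visited across the whole grid.
std::vector<int> grid_stride_visits(int total, int blocks, int threads) {
    assert(total >= 0 && blocks > 0 && threads > 0);
    std::vector<int> visits(static_cast<std::size_t>(total), 0);
    const int stride = blocks * threads;
    for (int block = 0; block < blocks; ++block) {
        for (int thread = 0; thread < threads; ++thread) {
            for (int i = block * threads + thread; i < total; i += stride) {
                ++visits[static_cast<std::size_t>(i)];
            }
        }
    }
    return visits;
}

// Largest number of elements any single thread accumulates for this launch shape.
int max_elements_per_thread(int total, int blocks, int threads) {
    const int stride = blocks * threads;
    return (total + stride - 1) / stride;
}
```

With the host code's launch shape, blocks = ceil(total / threads), `max_elements_per_thread` is 1, so each thread's loop body runs at most once and every block issues its own `atomicAdd`. A smaller, fixed grid would still cover all elements while making each thread accumulate more of them and issuing fewer atomic operations; whether that is faster would need to be measured.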