1.100. Hinge Loss Computation
Calculates the hinge loss for binary classification (targets in {-1, +1}) by computing max(0, 1 - predictions * targets) element-wise and averaging the result over all elements of the batch.
Progress
Optimization Timeline
Step 3 Details
Reasoning
Our starting implementation already achieved a significant speedup by fusing multiple operations into a small number of kernel launches. However, we can improve performance further by merging the final reduction and the division into a single kernel. In the previous version, when there were multiple reduction blocks, we launched two separate kernels – one to reduce the block sums and another to divide by the total number of elements. This added kernel-launch overhead and an extra synchronization point. Combining these steps into one kernel (final_reduce_and_divide_kernel) eliminates that extra launch and keeps all data on the GPU, with no copy to the host. The whole computation now takes exactly two kernel launches: the fused hinge/partial-sum kernel and the final reduce-and-divide kernel. The updated code computes the final loss (a 0-dimensional scalar) entirely on the device. Before finishing, I ask: have these changes genuinely helped? They should: merging the two kernels removes one launch and its overhead without changing the computed result. The actual speedup has not been measured here.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>
// Elementwise hinge term max(0, 1 - p * t).
// Written as `v < 0 ? 0 : v` so that a NaN term stays NaN (NaN < 0 is false),
// matching torch.clamp(min=0). The form `v > 0 ? v : 0` would silently turn
// NaN into 0 and hide bad inputs.
__device__ __forceinline__ float hinge_term(float p, float t) {
    const float v = 1.0f - p * t;
    return (v < 0.0f) ? 0.0f : v;
}

// Kernel to compute per-block partial sums of the hinge loss.
//
// Launch requirements:
//   - 1-D grid of 1-D blocks.
//   - blockDim.x must be a power of two (the shared-memory tree reduction
//     halves the active range each step).
//   - Dynamic shared memory of blockDim.x * sizeof(float) bytes.
//   - block_sums must hold gridDim.x floats; block b writes its partial sum
//     to block_sums[b].
// On each pass, a thread handles element `index` and element
// `index + blockDim.x`, so a block covers 2 * blockDim.x consecutive
// elements. All threads then advance by the span of the whole grid until N
// is exhausted.
//
// Index arithmetic is done in 64 bits: for N close to INT_MAX, the sums
// `index + gridSpan` and `index + blockDim.x` would overflow a 32-bit int.
__global__ void hinge_loss_reduce_kernel(const float* __restrict__ predictions,
                                         const float* __restrict__ targets,
                                         float* __restrict__ block_sums,
                                         int N) {
    extern __shared__ float sdata[];
    const int tid = threadIdx.x;
    const int64_t n = N;
    const int64_t blockSpan = 2 * static_cast<int64_t>(blockDim.x);
    const int64_t gridSpan = blockSpan * gridDim.x;

    float sum = 0.0f;
    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockSpan + tid;
         index < n;
         index += gridSpan) {
        sum += hinge_term(predictions[index], targets[index]);
        const int64_t index2 = index + blockDim.x;
        if (index2 < n) {
            sum += hinge_term(predictions[index2], targets[index2]);
        }
    }
    sdata[tid] = sum;
    __syncthreads();

    // Tree reduction in shared memory (requires power-of-two blockDim.x).
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) {
        block_sums[blockIdx.x] = sdata[0];
    }
}
// Single-block kernel that sums the per-block partials and divides by N,
// writing the mean loss to loss[0].
//
// Launch requirements:
//   - Exactly one block; blockDim.x must be a power of two.
//   - Dynamic shared memory of blockDim.x * sizeof(float) bytes.
// block_sums holds numBlocks partial sums (numBlocks may be 0). When both
// numBlocks and N are 0, the result is 0.0f / 0.0f, i.e. NaN.
__global__ void final_reduce_and_divide_kernel(const float* __restrict__ block_sums,
                                               int numBlocks,
                                               int N,
                                               float* __restrict__ loss) {
    extern __shared__ float partials[];
    const unsigned int lane = threadIdx.x;
    const unsigned int width = blockDim.x;

    // Phase 1: each thread accumulates every width-th partial sum, starting at its own index.
    float acc = 0.0f;
    int k = static_cast<int>(lane);
    while (k < numBlocks) {
        acc += block_sums[k];
        k += static_cast<int>(width);
    }
    partials[lane] = acc;
    __syncthreads();

    // Phase 2: pairwise tree reduction, halving the active range each step.
    unsigned int half = width >> 1;
    while (half != 0) {
        if (lane < half) {
            partials[lane] += partials[lane + half];
        }
        __syncthreads();
        half >>= 1;
    }

    // Phase 3: thread 0 turns the total into a mean.
    if (lane == 0) {
        const float count = static_cast<float>(N);
        loss[0] = partials[0] / count;
    }
}
// Host entry point: mean hinge loss, mean(max(0, 1 - predictions * targets)),
// computed over all elements of the two tensors.
//
// Checked preconditions:
//   - Both tensors are float32 CUDA tensors on the same device.
//   - Both tensors have the same number of elements.
//   - numel() fits in an int, because the kernels take `int N`.
// Non-contiguous inputs are made contiguous first, because the kernels read
// them as flat arrays.
//
// Returns a 0-dim float32 tensor on the input device. For empty inputs the
// result is NaN (0 / 0), matching torch.mean on an empty tensor. Both kernels
// are launched on the current CUDA stream; nothing is copied to the host.
torch::Tensor hinge_loss(torch::Tensor predictions, torch::Tensor targets) {
    TORCH_CHECK(predictions.is_cuda(), "predictions must be a CUDA tensor");
    TORCH_CHECK(targets.is_cuda(), "targets must be a CUDA tensor");
    TORCH_CHECK(predictions.device() == targets.device(),
                "predictions and targets must be on the same device");
    TORCH_CHECK(predictions.scalar_type() == torch::kFloat32,
                "predictions must be a float32 tensor");
    TORCH_CHECK(targets.scalar_type() == torch::kFloat32,
                "targets must be a float32 tensor");
    // Check that predictions and targets have the same number of elements.
    TORCH_CHECK(predictions.numel() == targets.numel(), "Predictions and targets must have the same number of elements");
    TORCH_CHECK(predictions.numel() <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "hinge_loss: inputs with more than INT_MAX elements are not supported");

    // Ensure operations run on the correct GPU (after the CUDA check above,
    // so a CPU tensor gets a clear error instead of a guard failure).
    c10::cuda::CUDAGuard device_guard(predictions.device());

    // The kernels index the inputs as flat, densely packed arrays.
    auto preds = predictions.contiguous();
    auto tgts = targets.contiguous();
    const int N = static_cast<int>(preds.numel());

    auto stream = c10::cuda::getCurrentCUDAStream();

    // Stage 1: per-block partial sums. Each block covers 2 * threads elements.
    // threads must be a power of two for the in-kernel tree reduction.
    constexpr int threads = 256;
    // Computed in 64 bits: N + 2 * threads - 1 can overflow int for N near INT_MAX.
    const int blocks = static_cast<int>(
        (static_cast<int64_t>(N) + 2 * threads - 1) / (2 * threads));
    auto block_sums = torch::empty({blocks}, preds.options());
    if (blocks > 0) {
        // Skipped for empty input: a zero-block launch is an invalid configuration.
        hinge_loss_reduce_kernel<<<blocks, threads, threads * sizeof(float), stream>>>(
            preds.data_ptr<float>(),
            tgts.data_ptr<float>(),
            block_sums.data_ptr<float>(),
            N);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Stage 2: one block folds the partial sums and divides by N. With
    // blocks == 0 it sums nothing and writes 0 / 0 = NaN.
    auto loss = torch::empty({}, preds.options());
    constexpr int finalThreads = 256;  // power of two, required by the reduction
    final_reduce_and_divide_kernel<<<1, finalThreads, finalThreads * sizeof(float), stream>>>(
        block_sums.data_ptr<float>(),
        blocks,
        N,
        loss.data_ptr<float>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return loss;
}
// Python bindings for the extension. The named arguments allow keyword calls
// such as hinge_loss(predictions=p, targets=t); positional calls still work.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("hinge_loss", &hinge_loss,
          "Fused hinge loss (mean(max(0, 1 - predictions * targets)))",
          pybind11::arg("predictions"), pybind11::arg("targets"));
}