1.95. Compute Cross Entropy Loss
Calculates the cross entropy loss between predicted logits and target labels for a multi-class classification task.
Progress
Step 3 of 3
Reward
2.14x
Optimization Timeline
Step 3 Details
Reasoning
We can further optimize the block-level reduction by replacing its final stages with warp-level shuffle operations, which cuts shared-memory traffic and synchronization overhead. Instead of continuing the shared-memory reduction once only 32 partial sums remain, each of the first 32 threads loads its partial sum into a register (v), and the warp combines them with __shfl_down_sync, with no further __syncthreads() calls. As before, only one atomicAdd per block is issued, so contention on the global accumulator stays low. The kernel change is small, and the Python module is unchanged.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <float.h>
#include <limits>
#include <math.h>
////////////////////////////////////////////////////////////////////////////////
// Optimized kernel with block-level reduction and warp-level shuffle reduction.
////////////////////////////////////////////////////////////////////////////////
// Computes the per-sample cross-entropy loss -log softmax(row)[target] and adds
// the block's summed loss to *loss_accum with a single atomicAdd per block.
//
// Layout: 1-D grid, one thread per sample; thread i reads row i as
// predictions[i * num_classes .. i * num_classes + num_classes - 1], so the
// caller must pass a contiguous row-major (batch, num_classes) buffer.
//
// Launch requirements:
//   * blockDim.x must be a power of two and >= 32: the shared-memory loop stops
//     at 32 partial sums and the last stage is a full-warp shuffle over
//     sdata[0..31].
//   * dynamic shared memory of blockDim.x * sizeof(float) bytes.
//   * *loss_accum must be initialised (e.g. zeroed) before launch; it receives
//     the SUM of per-sample losses, not the mean.
//
// A target outside [0, num_classes) contributes NaN instead of triggering an
// out-of-bounds read, so invalid labels surface as a NaN loss.
// The order of float atomicAdds across blocks is unspecified, so the result may
// differ in the last bits between runs.
__global__ void cross_entropy_kernel(const float* __restrict__ predictions,
                                     const int64_t* __restrict__ targets,
                                     float* __restrict__ loss_accum,
                                     int batch,
                                     int num_classes) {
    extern __shared__ float sdata[];
    const int tid = threadIdx.x;
    // 64-bit global index: blockIdx.x * blockDim.x can exceed INT_MAX when batch
    // is close to INT_MAX.
    const int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + tid;

    float loss_sample = 0.0f;  // threads past the end of the batch contribute 0
    if (i < batch) {
        // 64-bit row offset: i * num_classes can exceed INT_MAX for large inputs.
        const float* row = predictions + i * static_cast<int64_t>(num_classes);

        // Row maximum, subtracted before exponentiating for numerical stability.
        float max_val = -FLT_MAX;
        for (int j = 0; j < num_classes; ++j) {
            max_val = fmaxf(max_val, row[j]);
        }

        // log(sum(exp(x))) computed as log(sum(exp(x - max))) + max.
        float sum_exp = 0.0f;
        for (int j = 0; j < num_classes; ++j) {
            sum_exp += expf(row[j] - max_val);
        }
        const float log_sum_exp = logf(sum_exp) + max_val;

        const int64_t target = targets[i];
        if (target >= 0 && target < num_classes) {
            loss_sample = log_sum_exp - row[target];
        } else {
            // Invalid label: skip the out-of-bounds read and poison the loss.
            loss_sample = nanf("");
        }
    }

    // Block-level tree reduction in shared memory down to 32 partial sums.
    sdata[tid] = loss_sample;
    __syncthreads();
    for (int s = static_cast<int>(blockDim.x) / 2; s >= 32; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // The first warp folds the remaining 32 partial sums with shuffles; the
    // *_sync intrinsic with a full mask needs no extra __syncthreads().
    if (tid < 32) {
        float v = sdata[tid];
        for (int lane_offset = 16; lane_offset > 0; lane_offset >>= 1) {
            v += __shfl_down_sync(0xffffffff, v, lane_offset);
        }
        // Lane 0 publishes the block's sum to the global accumulator.
        if (tid == 0) {
            atomicAdd(loss_accum, v);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////
// Host function.
////////////////////////////////////////////////////////////////////////////////
// Host entry point: mean cross-entropy loss over a batch.
//
// predictions: CUDA float32 tensor of shape (batch, num_classes) holding logits.
// targets:     CUDA int64 tensor of shape (batch,) with class indices in
//              [0, num_classes); an out-of-range index makes the result NaN.
// Returns a 0-d float32 tensor on predictions' device equal to
// sum(per-sample loss) / batch. For batch == 0 no kernel is launched and the
// result is 0 / 0 = NaN.
//
// Both inputs are made contiguous because the kernel indexes row i at offset
// i * num_classes. The kernel runs on the current CUDA stream and accumulates
// with float atomicAdd, so the result is not bitwise deterministic across runs.
torch::Tensor forward(torch::Tensor predictions, torch::Tensor targets) {
    // Validate placement first: CUDAGuard and the raw device pointers below
    // require CUDA tensors on the same device.
    TORCH_CHECK(predictions.is_cuda(), "Predictions must be a CUDA tensor");
    TORCH_CHECK(targets.is_cuda(), "Targets must be a CUDA tensor");
    TORCH_CHECK(targets.device() == predictions.device(),
                "Predictions and targets must be on the same device");

    // Ensure we are on the correct CUDA device.
    c10::cuda::CUDAGuard guard(predictions.device());

    TORCH_CHECK(predictions.dim() == 2, "Predictions must be 2D (batch, num_classes)");
    TORCH_CHECK(targets.dim() == 1, "Targets must be 1D (batch)");
    TORCH_CHECK(predictions.scalar_type() == torch::kFloat32, "Predictions must be float32");
    TORCH_CHECK(targets.scalar_type() == torch::kInt64, "Targets must be int64");
    // The kernel reads targets[i] for every row, so a shorter targets tensor
    // would be read out of bounds.
    TORCH_CHECK(targets.size(0) == predictions.size(0),
                "Targets length must equal the predictions batch size");
    // The kernel takes its sizes as int.
    TORCH_CHECK(predictions.size(0) <= std::numeric_limits<int>::max() &&
                    predictions.size(1) <= std::numeric_limits<int>::max(),
                "Predictions dimensions must fit in a 32-bit int");

    // The kernel assumes dense row-major rows.
    const torch::Tensor preds = predictions.contiguous();
    const torch::Tensor tgts = targets.contiguous();
    const int batch = static_cast<int>(preds.size(0));
    const int num_classes = static_cast<int>(preds.size(1));

    // Scalar (0-d) accumulator; the kernel adds per-block partial sums into it.
    auto loss_tensor = torch::zeros({}, preds.options());

    // A zero-block grid is an invalid launch configuration, so skip it.
    if (batch > 0) {
        constexpr int threads = 256;  // power of two >= 32, required by the kernel's reduction
        const int64_t blocks = (static_cast<int64_t>(batch) + threads - 1) / threads;
        const size_t shared_mem_size = threads * sizeof(float);
        cross_entropy_kernel<<<static_cast<unsigned int>(blocks), threads, shared_mem_size,
                               c10::cuda::getCurrentCUDAStream()>>>(
            preds.data_ptr<float>(),
            tgts.data_ptr<int64_t>(),
            loss_tensor.data_ptr<float>(),
            batch,
            num_classes);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // Compute the average loss.
    return loss_tensor / batch;
}
// Python binding: exposes the host entry point as `forward(predictions, targets)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def("forward",
            &forward,
            "Optimized Fused Cross Entropy Loss with Warp-level Reduction (CUDA)");
}