1.24. LogSoftmax Activation Calculation
Calculates the logarithm of the softmax values along a specified dimension, transforming an input tensor into a normalized log-probability tensor with the same shape.
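For reference, a minimal illustration of the baseline operation in plain PyTorch (the shapes here are just an example):

import torch

x = torch.randn(4, 10)                     # input of shape [batch, dim]
log_probs = torch.nn.LogSoftmax(dim=1)(x)  # output has the same shape as x
# exp(log_probs) sums to 1 along dim=1, i.e. each row holds normalized log-probabilities.
print(torch.exp(log_probs).sum(dim=1))     # ~ tensor([1., 1., 1., 1.])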
Optimization Timeline
Step 1 Details
Reasoning
We replace the standard LogSoftmax operation with a custom CUDA kernel that computes log-softmax along the given dimension (we assume dim=1, i.e. along each row). The mathematical formulation is: for each row i and element j, max_val = max_j x[i,j], sum_exp = Σ_j exp(x[i,j] - max_val), and output[i,j] = x[i,j] - max_val - log(sum_exp). Subtracting the row maximum before exponentiating keeps the computation numerically stable.

To compute this efficiently on the GPU, we assign one CUDA block per row. Each thread in the block processes a strided portion of the row, and two parallel reductions in shared memory produce the row maximum and the sum of exponentials. We use c10::cuda::CUDAGuard for proper device management and launch the kernel on the current CUDA stream.

The Python module builds our extension with the flags "-O3" and "--use_fast_math", and the CUDAModel class's forward method calls the fused log-softmax operator (a sketch of this wrapper follows the extension source below). Compared to the native log_softmax, the fused kernel avoids multiple kernel launches and extra memory traffic, thereby improving efficiency.
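As a quick sanity check of this formulation, the following sketch (plain PyTorch on a tiny illustrative tensor, names chosen here for clarity) reproduces the three quantities row-wise and compares against the native torch.log_softmax:

import torch

# Quick numerical check of the row-wise formulation (illustrative values).
x = torch.tensor([[1.0, 2.0, 3.0],
                  [0.5, 0.5, 0.5]])

max_val = x.max(dim=1, keepdim=True).values                # per-row maximum
sum_exp = torch.exp(x - max_val).sum(dim=1, keepdim=True)  # row-wise sum of shifted exponentials
manual = x - max_val - torch.log(sum_exp)                  # x - max - log(sum exp)

assert torch.allclose(manual, torch.log_softmax(x, dim=1))

The CUDA extension that implements the same computation is listed below.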
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>
// Kernel to compute LogSoftmax for each row of the input.
// Assumes input has shape [batch, dim] and dim corresponds to the log_softmax dimension.
__global__ void log_softmax_kernel(const float* __restrict__ input,
                                   float* __restrict__ output,
                                   int dim) {
    // Each block processes one row (sample).
    int row = blockIdx.x;
    int tid = threadIdx.x;
    int blockSize = blockDim.x;

    // Pointers to the current row.
    const float* in_row = input + row * dim;
    float* out_row = output + row * dim;

    extern __shared__ float sdata[];

    // Phase 1: compute the maximum value of the row.
    float local_max = -FLT_MAX;
    for (int j = tid; j < dim; j += blockSize) {
        local_max = fmaxf(local_max, in_row[j]);
    }
    sdata[tid] = local_max;
    __syncthreads();

    // Tree reduction in shared memory for the maximum (blockSize must be a power of two).
    for (int s = blockSize / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }
    float max_val = sdata[0];
    // Make sure every thread has read the row maximum before sdata is reused for the sum.
    __syncthreads();

    // Phase 2: compute the sum of exponentials, shifted by the row maximum for stability.
    float local_sum = 0.0f;
    for (int j = tid; j < dim; j += blockSize) {
        local_sum += expf(in_row[j] - max_val);
    }
    sdata[tid] = local_sum;
    __syncthreads();

    // Tree reduction in shared memory for the sum.
    for (int s = blockSize / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    float sum_exp = sdata[0];
    float log_sum = logf(sum_exp);

    // Phase 3: write the final log softmax values.
    for (int j = tid; j < dim; j += blockSize) {
        out_row[j] = in_row[j] - max_val - log_sum;
    }
}
torch::Tensor log_softmax_forward(torch::Tensor input, int dim) {
    // Run all work on the device that owns the input tensor.
    c10::cuda::CUDAGuard device_guard(input.device());

    // This fused kernel only handles the common [batch, dim] case with dim=1.
    TORCH_CHECK(dim == 1, "This fused log softmax kernel supports only dim=1");
    TORCH_CHECK(input.dim() == 2, "Input must be a 2D tensor");
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32");

    // The kernel indexes each row as a contiguous block of floats.
    input = input.contiguous();

    int batch = input.size(0);
    int ndim = input.size(1);  // size of the dimension along which to compute log softmax

    // Allocate the output tensor.
    auto output = torch::empty({batch, ndim}, input.options());

    // Configure kernel launch: one block per row, fixed power-of-two block size.
    int threads = 256;
    int blocks = batch;
    size_t shared_mem = threads * sizeof(float);

    log_softmax_kernel<<<blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        ndim);

    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("log_softmax_forward", &log_softmax_forward, "Fused LogSoftmax forward (CUDA)");
}