1.39. L2 Normalization
Performs L2 normalization by dividing the input tensor by its L2 norm along a specified dimension, so that the vectors along that dimension have unit norm while the output keeps the input's shape.
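Written out for a 2D input x of shape [batch, D] normalized along dimension 1 (the layout the CUDA kernel below assumes), row b is mapped as follows. The second line reflects the kernel's near-zero guard: when the row norm is below 1e-6 it divides by 1 instead, so an all-zero row is passed through unchanged rather than becoming NaN.

```latex
n_b = \sqrt{\sum_{j=0}^{D-1} x_{b,j}^2}, \qquad
y_{b,i} = \frac{x_{b,i}}{\tilde{n}_b}, \quad i = 0, \dots, D-1,
\qquad
\tilde{n}_b = \begin{cases} n_b & \text{if } n_b \ge 10^{-6} \\ 1 & \text{otherwise} \end{cases}
```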
Step 3 Details
Reasoning
We can improve on the previous two-kernel L2 normalization by fusing the norm reduction and the division into a single kernel and by using vectorized loads and stores when possible. In the updated kernel, each block processes one row of the input. First, each thread accumulates a partial sum of squares, using float4 loads when the dimension is divisible by 4 (a common case) and scalar loads otherwise. A shared-memory reduction then combines the partial sums, and thread 0 takes the square root and broadcasts the resulting norm to the whole block through a shared variable. Finally, the threads divide the row elements by the norm, again with vectorized or scalar accesses. Fusing the two phases saves one kernel launch per call and avoids writing the per-row norms to global memory only to read them back in a second kernel, so the fused version should run faster than the previous two-kernel design; the actual gain depends on the input shape and the GPU. The host function checks whether the dimension is divisible by 4 and passes the result to the kernel as a flag. The updated CUDA extension code is listed below.
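To make the four phases concrete, here is a CPU-only C++ sketch that emulates the work of one block on a single row, with the block's threads replaced by loop indices. The function name `l2norm_row_emulated` and the 8-thread block size `kThreads` are illustrative assumptions (the real kernel is launched with 256 threads), and the float4 path is omitted because it changes only how elements are loaded, not the arithmetic.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical block size for the emulation; stands in for blockDim.x.
// The tree reduction below requires it to be a power of two.
constexpr int kThreads = 8;
static_assert((kThreads & (kThreads - 1)) == 0, "kThreads must be a power of two");

// CPU emulation of one block of the fused kernel processing one row.
std::vector<float> l2norm_row_emulated(const std::vector<float>& row) {
    const int dim = static_cast<int>(row.size());

    // Phase 1: each simulated thread accumulates a strided partial sum of squares
    // (thread tid handles elements tid, tid + kThreads, tid + 2 * kThreads, ...).
    float sdata[kThreads];
    for (int tid = 0; tid < kThreads; ++tid) {
        float sum = 0.0f;
        for (int i = tid; i < dim; i += kThreads) {
            sum += row[i] * row[i];
        }
        sdata[tid] = sum;
    }

    // Phase 2: shared-memory style tree reduction. In each step, threads tid < s
    // read only indices >= s, so running them sequentially matches the GPU result.
    for (int s = kThreads / 2; s > 0; s >>= 1) {
        for (int tid = 0; tid < s; ++tid) {
            sdata[tid] += sdata[tid + s];
        }
    }

    // Phase 3: "thread 0" computes the norm, with the same near-zero guard as the kernel.
    float norm = std::sqrt(sdata[0]);
    if (norm < 1e-6f) {
        norm = 1.0f;
    }

    // Phase 4: every element is divided by the broadcast norm.
    std::vector<float> out(dim);
    for (int i = 0; i < dim; ++i) {
        out[i] = row[i] / norm;
    }
    return out;
}
```

For example, `l2norm_row_emulated({3.0f, 0.0f, 4.0f})` yields approximately {0.6, 0, 0.8}, and an all-zero row comes back unchanged because of the guard. The same function can serve as a slow reference when checking the GPU output row by row.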
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <math.h>
// Fused kernel: each block processes one row of a row-major input tensor of shape [batch, dim].
// use_vectorized == 1 selects float4 loads/stores; this is only valid when dim is divisible by 4
// and the input/output base pointers are 16-byte aligned.
// The kernel computes the sum of squares, reduces it to the row's L2 norm, and normalizes the row.
// blockDim.x must be a power of two, and the dynamic shared memory must hold blockDim.x floats.
__global__ void l2norm_fused_kernel(const float* __restrict__ input,
                                    float* __restrict__ output,
                                    int dim,
                                    int use_vectorized) {
    int row = blockIdx.x;   // one row per block
    int tid = threadIdx.x;
    extern __shared__ float sdata[];   // one partial sum per thread

    // 64-bit row offset so that row * dim cannot overflow for large tensors.
    const float* row_in = input + static_cast<size_t>(row) * dim;
    float* row_out = output + static_cast<size_t>(row) * dim;
    const bool vec = use_vectorized && ((dim & 3) == 0);

    // Phase 1: each thread accumulates a strided partial sum of squares.
    float sum = 0.0f;
    if (vec) {
        int vecCount = dim >> 2;   // dim / 4
        const float4* in4 = reinterpret_cast<const float4*>(row_in);
        for (int i = tid; i < vecCount; i += blockDim.x) {
            float4 v = in4[i];
            sum += v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
        }
    } else {
        for (int i = tid; i < dim; i += blockDim.x) {
            float v = row_in[i];
            sum += v * v;
        }
    }
    sdata[tid] = sum;
    __syncthreads();

    // Phase 2: tree reduction in shared memory (assumes blockDim.x is a power of two).
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Phase 3: thread 0 computes the norm and broadcasts it through shared memory.
    __shared__ float norm_shared;
    if (tid == 0) {
        float n = sqrtf(sdata[0]);
        // Guard against division by zero: rows with a near-zero norm are divided by 1,
        // i.e. passed through unchanged.
        norm_shared = (n < 1e-6f) ? 1.0f : n;
    }
    __syncthreads();
    const float norm = norm_shared;

    // Phase 4: normalize the row, re-reading it from global memory.
    if (vec) {
        int vecCount = dim >> 2;
        const float4* in4 = reinterpret_cast<const float4*>(row_in);
        float4* out4 = reinterpret_cast<float4*>(row_out);
        for (int i = tid; i < vecCount; i += blockDim.x) {
            float4 v = in4[i];
            float4 res;
            res.x = v.x / norm;
            res.y = v.y / norm;
            res.z = v.z / norm;
            res.w = v.w / norm;
            out4[i] = res;
        }
    } else {
        for (int i = tid; i < dim; i += blockDim.x) {
            row_out[i] = row_in[i] / norm;
        }
    }
}
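torch::Tensor l2norm_forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "l2norm_forward: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "l2norm_forward: input must be 2D [batch, dim]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "l2norm_forward: input must be float32");
    // Ensure the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes rows as row * dim, so the data must be contiguous.
    input = input.contiguous();
    int batch = static_cast<int>(input.size(0));
    int dim = static_cast<int>(input.size(1));
    // Allocate output tensor.
    auto output = torch::empty({batch, dim}, input.options());
    if (batch == 0 || dim == 0) {
        return output;   // nothing to do; also avoids launching an empty grid
    }
    int threads = 256;   // must be a power of two for the shared-memory reduction
    dim3 grid(batch);
    dim3 block(threads);
    size_t shared_mem = threads * sizeof(float);
    // float4 accesses need dim % 4 == 0 and 16-byte-aligned base pointers
    // (a tensor view with a storage offset is not necessarily aligned).
    const float* in_ptr = input.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    bool aligned = (reinterpret_cast<std::uintptr_t>(in_ptr) % 16 == 0) &&
                   (reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0);
    int use_vectorized = (dim % 4 == 0 && aligned) ? 1 : 0;
    l2norm_fused_kernel<<<grid, block, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
        in_ptr,
        out_ptr,
        dim,
        use_vectorized);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}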
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("l2norm_forward", &l2norm_forward, "Fused L2 normalization (CUDA) with vectorized loads/stores");
}
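float4 loads and stores are only valid on 16-byte-aligned addresses, so deciding whether the vectorized path is safe needs to look at the data pointers as well as at dim. Below is a minimal host-side sketch of such a check; `can_use_float4` is a hypothetical helper name, not part of PyTorch or of the extension above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: returns true when float4 accesses are safe for a
// row-major [batch, dim] float buffer. dim % 4 == 0 keeps every row start
// 16-byte aligned relative to the base, and both base pointers must
// themselves be 16-byte aligned.
inline bool can_use_float4(const float* in, const float* out, int dim) {
    auto aligned16 = [](const void* p) {
        return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
    };
    return dim % 4 == 0 && aligned16(in) && aligned16(out);
}
```

A pointer that starts one float into an aligned allocation, as a tensor view with a storage offset of 1 would, fails the check even when dim is a multiple of 4, and the kernel should then take the scalar path.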