2.33. GEMM with Scaling and Batch Normalization
Calculates a general matrix multiplication (GEMM) on the input data, then applies element-wise scaling using a learned parameter, and finally normalizes the features with batch normalization.
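As a point of reference, the whole operation can be written as a plain C++ function over row-major `std::vector` buffers. This is a minimal sketch for checking a kernel's output on the CPU, not the original implementation. The function name `reference_gemm_scale_bn` and the flat-vector layout are assumptions. It normalizes with the biased batch variance, which is what BatchNorm uses in training mode, and it does not track running statistics.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for y = BN((x * W^T + b) * s) with training-mode BN.
// x: [batch, in], W: [out, in], b/s/gamma/beta: [out]; returns y: [batch, out].
std::vector<float> reference_gemm_scale_bn(
    const std::vector<float>& x, const std::vector<float>& w,
    const std::vector<float>& b, const std::vector<float>& s,
    const std::vector<float>& gamma, const std::vector<float>& beta,
    std::size_t batch, std::size_t in, std::size_t out, float eps) {
    assert(x.size() == batch * in && w.size() == out * in);
    assert(b.size() == out && s.size() == out && gamma.size() == out && beta.size() == out);
    std::vector<float> y(batch * out);
    for (std::size_t j = 0; j < out; ++j) {
        // GEMM + bias + per-feature scale for output column j (accumulated in double).
        std::vector<double> z(batch);
        double mean = 0.0;
        for (std::size_t i = 0; i < batch; ++i) {
            double dot = 0.0;
            for (std::size_t k = 0; k < in; ++k) dot += double(x[i * in + k]) * w[j * in + k];
            z[i] = (dot + b[j]) * s[j];
            mean += z[i];
        }
        mean /= double(batch);
        // Biased variance over the batch (divide by batch, not batch - 1).
        double var = 0.0;
        for (std::size_t i = 0; i < batch; ++i) var += (z[i] - mean) * (z[i] - mean);
        var /= double(batch);
        // Normalize, then apply the BN affine transform.
        for (std::size_t i = 0; i < batch; ++i)
            y[i * out + j] = float(gamma[j] * (z[i] - mean) / std::sqrt(var + eps) + beta[j]);
    }
    return y;
}
```

For example, with `x = [[1], [3]]`, an identity weight, zero bias, unit scale, `gamma = 2`, `beta = 0.5` and `eps = 0`, the batch values are `z = [1, 3]` (mean 2, variance 1), so the output is `[-1.5, 2.5]`. Comparing a kernel's output against such a reference with a small tolerance is a quick correctness check.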
Step 3 Details
Reasoning
We can squeeze out extra performance by fully fusing the GEMM (matrix multiplication) and BatchNorm into one kernel. Our previous version launched two kernels: one computed GEMM+scaling into an intermediate tensor z, and a second computed the per-feature reduction (mean/variance) and applied BN. That approach is correct, but it writes z to global memory, reads it back, and pays for two kernel launches.

Instead, we fuse everything into one kernel in which a single block handles one output feature j (one column of the GEMM result):
• Each thread computes the value for one sample i: z[i] = (dot(x[i,:], weight[j,:]) + bias[j]) * scale[j].
• All threads in the block cooperate in a shared-memory reduction over the batch to compute the mean and (biased) variance of z.
• Each thread then normalizes its own z and applies BN's affine transformation.

The dot product uses vectorized float4 loads. Dynamic shared memory holds two arrays (for the partial sums of z and z²) plus two extra floats used to broadcast the mean and variance. Because one thread handles one sample, this design assumes the whole batch fits in a single block (at most 1024 threads per block on current GPUs).

This fused approach never writes z to global memory and needs only one kernel launch. The Python module remains largely the same, except that it now calls the new fused kernel. Have these changes genuinely helped? The full fusion should lower memory traffic and remove a kernel launch relative to the previous two-kernel approach, so a speedup is expected, but it still has to be confirmed by benchmarking against the earlier version.
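The block-level reduction described above halves the number of active slots at each step, so it only sums every slot when the number of slots is a power of two. The C++ sketch below replays that loop on the host; the names `tree_reduce`, `next_pow2` and `pad_to_pow2` are illustrative and not part of the original module. With three values, one of them is silently dropped, while padding to four slots with zeros gives the correct total. This is why a one-thread-per-sample launch should round the thread count up to a power of two and let the extra threads contribute zeros.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Smallest power of two >= n (n must be positive).
std::size_t next_pow2(std::size_t n) {
    assert(n > 0);
    std::size_t p = 1;
    while (p < n) p <<= 1;
    return p;
}

// Host replay of the shared-memory tree reduction: at each step, slot t
// (t < stride) accumulates slot t + stride; the total ends up in slot 0.
float tree_reduce(std::vector<float> s) {
    for (std::size_t stride = s.size() / 2; stride > 0; stride >>= 1)
        for (std::size_t t = 0; t < stride; ++t) s[t] += s[t + stride];
    return s.empty() ? 0.0f : s[0];
}

// Pad with zeros (what idle threads would contribute) to a power-of-two length.
std::vector<float> pad_to_pow2(std::vector<float> s) {
    s.resize(next_pow2(s.size()), 0.0f);
    return s;
}
```

On the GPU, the same loop runs with one thread per slot and a `__syncthreads()` between steps; the sequential replay gives the same result because each step only reads slots that are not written during that step.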
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>

// Fully fused kernel that computes GEMM+scaling and BatchNorm normalization in one pass.
// Each block corresponds to one output feature (column j).
// Each thread in the block (with thread index i) computes the dot product for sample i:
//   z = (dot(x[i,:], gemm_weight[j,:]) + gemm_bias[j]) * scale[j]
// Then, the block performs a reduction (using dynamic shared memory) to compute the
// mean and variance over the batch.
// Finally, each thread normalizes its computed z and applies the BN affine transform.
// Requirement: blockDim.x is a power of two >= batch (the host rounds it up);
// threads with index >= batch are padding and contribute zeros to the reduction.
extern "C"
__global__ void fused_gemm_bn_kernel(
    const float* __restrict__ x,            // [batch, in_features]
    const float* __restrict__ gemm_weight,  // [out_features, in_features]
    const float* __restrict__ gemm_bias,    // [out_features]
    const float* __restrict__ scale,        // [out_features]
    int batch,
    int in_features,
    int out_features,
    float eps,
    const float* __restrict__ bn_weight,    // [out_features]
    const float* __restrict__ bn_bias,      // [out_features]
    float* __restrict__ y)                  // [batch, out_features]
{
    // blockIdx.x corresponds to output feature index j.
    const int j = blockIdx.x;
    // threadIdx.x corresponds to sample index i.
    const int tid = threadIdx.x;
    // Padding threads must not return early: every thread has to reach the
    // __syncthreads() barriers below, so they only skip the compute and the store.
    const bool active = tid < batch;

    float my_z = 0.0f;
    if (active) {
        // Dot product of x[i, :] and gemm_weight[j, :].
        const float* x_row = x + static_cast<size_t>(tid) * in_features;
        const float* w_row = gemm_weight + static_cast<size_t>(j) * in_features;
        float dot = 0.0f;
        int start = 0;
        // float4 loads need 16-byte-aligned addresses. Row i of x is only aligned
        // when in_features is a multiple of 4, so check before taking the vector path.
        const bool aligned =
            ((reinterpret_cast<std::uintptr_t>(x_row) |
              reinterpret_cast<std::uintptr_t>(w_row)) & 15) == 0;
        if (aligned) {
            const int nvec = in_features / 4;
            const float4* x_vec = reinterpret_cast<const float4*>(x_row);
            const float4* w_vec = reinterpret_cast<const float4*>(w_row);
            for (int v = 0; v < nvec; v++) {
                float4 xv = x_vec[v];
                float4 wv = w_vec[v];
                dot += xv.x * wv.x + xv.y * wv.y + xv.z * wv.z + xv.w * wv.w;
            }
            start = nvec * 4;
        }
        // Scalar tail (or the whole row when the vector path was not safe).
        for (int k = start; k < in_features; k++) {
            dot += x_row[k] * w_row[k];
        }
        // Load gemm_bias and scale through the read-only cache (__ldg).
        const float bias = __ldg(&gemm_bias[j]);
        const float sc = __ldg(&scale[j]);
        // Compute z for this sample.
        my_z = (dot + bias) * sc;
    }

    // Dynamic shared memory layout: blockDim.x floats for s_sum, blockDim.x floats
    // for s_sumsq, then one float for the mean and one float for the variance.
    extern __shared__ float shared[];
    float* s_sum = shared;                 // size: blockDim.x
    float* s_sumsq = s_sum + blockDim.x;   // size: blockDim.x
    float* p_mean = s_sumsq + blockDim.x;  // 1 float
    float* p_var = p_mean + 1;             // 1 float

    // Each thread writes its z and z*z (zeros for padding threads).
    s_sum[tid] = my_z;
    s_sumsq[tid] = my_z * my_z;
    __syncthreads();

    // Tree reduction of the sum and sum of squares (relies on blockDim.x being a power of two).
    for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            s_sum[tid] += s_sum[tid + stride];
            s_sumsq[tid] += s_sumsq[tid + stride];
        }
        __syncthreads();
    }

    // Thread 0 computes the mean and the biased variance (what BatchNorm uses to
    // normalize in training mode). E[z^2] - mean^2 can come out slightly negative
    // from rounding, so it is clamped at zero before the square root.
    if (tid == 0) {
        const float mean = s_sum[0] / batch;
        *p_mean = mean;
        *p_var = fmaxf(s_sumsq[0] / batch - mean * mean, 0.0f);
    }
    __syncthreads();

    // Broadcast the computed mean and variance.
    const float mean = *p_mean;
    const float var = *p_var;

    if (active) {
        // Apply BatchNorm: normalize my_z and apply the affine transformation.
        const float norm = (my_z - mean) / sqrtf(var + eps);
        const float result = __ldg(&bn_weight[j]) * norm + __ldg(&bn_bias[j]);
        // Write the result to y, which is stored row-major as [batch, out_features].
        y[static_cast<size_t>(tid) * out_features + j] = result;
    }
}

// Fused forward function: launches one block per output feature.
// Threads per block = batch rounded up to a power of two, so the whole batch
// must fit in one block (batch <= 1024).
// Only training-mode batch statistics are used; running_mean/running_var are
// not updated by this kernel.
torch::Tensor fused_gemm_scale_bn_forward(torch::Tensor x,
                                          torch::Tensor gemm_weight,
                                          torch::Tensor gemm_bias,
                                          torch::Tensor scale,
                                          torch::Tensor bn_weight,
                                          torch::Tensor bn_bias,
                                          float eps) {
    // The kernel reads raw float pointers, so every tensor must be a contiguous
    // float32 tensor on the same CUDA device as x.
    auto check_tensor = [&x](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.is_cuda() && t.device() == x.device(),
                    name, " must be a CUDA tensor on the same device as x");
        TORCH_CHECK(t.scalar_type() == torch::kFloat, name, " must be float32");
        TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
    };
    check_tensor(x, "x");
    check_tensor(gemm_weight, "gemm_weight");
    check_tensor(gemm_bias, "gemm_bias");
    check_tensor(scale, "scale");
    check_tensor(bn_weight, "bn_weight");
    check_tensor(bn_bias, "bn_bias");
    c10::cuda::CUDAGuard device_guard(x.device());

    // Dimensions: x is [batch, in_features], gemm_weight is [out_features, in_features].
    TORCH_CHECK(x.dim() == 2, "Input x must be a 2D tensor");
    TORCH_CHECK(gemm_weight.dim() == 2, "gemm_weight must be 2D");
    const int batch = x.size(0);
    const int in_features = x.size(1);
    const int out_features = gemm_weight.size(0);
    TORCH_CHECK(gemm_weight.size(1) == in_features, "Mismatch in in_features");
    TORCH_CHECK(gemm_bias.numel() == out_features && scale.numel() == out_features &&
                bn_weight.numel() == out_features && bn_bias.numel() == out_features,
                "gemm_bias, scale, bn_weight and bn_bias must each have out_features elements");
    TORCH_CHECK(batch >= 1 && batch <= 1024,
                "one block handles the whole batch, so batch must be in [1, 1024]");
    TORCH_CHECK(out_features >= 1, "out_features must be positive");

    // Allocate output tensor y of shape [batch, out_features].
    auto y = torch::empty({batch, out_features}, x.options());

    // Launch the fused kernel.
    // Grid: one block per output feature.
    // Block: batch rounded up to the next power of two (required by the tree reduction).
    int threads = 1;
    while (threads < batch) threads <<= 1;
    const int blocks = out_features;
    // Shared memory size: (2*threads + 2) * sizeof(float).
    const size_t shared_mem_bytes = (2 * threads + 2) * sizeof(float);
    fused_gemm_bn_kernel<<<blocks, threads, shared_mem_bytes, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        gemm_weight.data_ptr<float>(),
        gemm_bias.data_ptr<float>(),
        scale.data_ptr<float>(),
        batch,
        in_features,
        out_features,
        eps,
        bn_weight.data_ptr<float>(),
        bn_bias.data_ptr<float>(),
        y.data_ptr<float>());
    // Report launch errors (e.g. an invalid configuration) immediately.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return y;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_gemm_scale_bn_forward", &fused_gemm_scale_bn_forward,
          "Fully fused GEMM with scaling and BatchNorm (training mode) in one kernel");
}
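The kernel computes the variance in a single pass as E[z²] - mean². In float32 this subtracts two nearly equal numbers whenever a feature's mean is large relative to its spread, so most of the significant digits cancel, and the result can even come out negative; once that error exceeds eps, `sqrtf(var + eps)` returns NaN. The C++ sketch below compares that formula with a two-pass formula (mean first, then the average squared deviation). The two-pass variant is a swapped-in technique, not what the original kernel does, and the function names are illustrative. Since each thread already holds its z in a register, a two-pass version inside the kernel would cost one extra block reduction rather than extra global-memory reads.

```cpp
#include <cassert>
#include <vector>

// Single-pass (biased) variance as in the kernel: E[z^2] - (E[z])^2, in float.
float one_pass_variance(const std::vector<float>& z) {
    assert(!z.empty());
    float sum = 0.0f, sumsq = 0.0f;
    for (float v : z) { sum += v; sumsq += v * v; }
    const float n = static_cast<float>(z.size());
    const float mean = sum / n;
    return sumsq / n - mean * mean;
}

// Two-pass (biased) variance: compute the mean, then average squared deviations.
float two_pass_variance(const std::vector<float>& z) {
    assert(!z.empty());
    float sum = 0.0f;
    for (float v : z) sum += v;
    const float n = static_cast<float>(z.size());
    const float mean = sum / n;
    float acc = 0.0f;
    for (float v : z) acc += (v - mean) * (v - mean);
    return acc / n;
}
```

For `z = {10000, 10001, 10002}` the exact biased variance is 2/3. The two-pass version recovers it, while the single-pass version loses essentially all of it to rounding in the sum of squares.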