2.75. GEMM-GroupNorm-Min-BiasAdd Transformation
Performs a matrix multiplication on the input data, applies group normalization (normalizing each group of features to zero mean and unit variance), computes the minimum value along the feature dimension, and then adds a bias parameter to produce the final output.
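For reference, the pipeline described above can be written out as a plain CPU computation. The sketch below is a hypothetical C++ oracle, not code from the original solution. Several details are assumptions chosen to mirror the nn.Linear and nn.GroupNorm defaults:

- the helper name `reference_forward`;
- the row-major layouts;
- the group count `G`;
- the GroupNorm affine parameters `gamma`/`beta`;
- `eps = 1e-5`.

It returns the final output flattened in the [1, C, N, 1] order that the original PyTorch operations produce.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU reference (an assumption, not the original code).
// Row-major layouts: x [N, K], W [C, K], b_lin [C]   -> y = x W^T + b_lin
//                    gamma, beta [C], G groups      -> GroupNorm per sample
//                    bias [C]                       -> the (1, C, 1, 1) bias
// Returns the output flattened as [1, C, N, 1]: element (j, i) at j * N + i.
std::vector<float> reference_forward(const std::vector<float>& x, int N, int K,
                                     const std::vector<float>& W,
                                     const std::vector<float>& b_lin, int C, int G,
                                     const std::vector<float>& gamma,
                                     const std::vector<float>& beta,
                                     const std::vector<float>& bias,
                                     float eps = 1e-5f) {
  assert(G > 0 && C % G == 0);  // GroupNorm needs C divisible by G
  std::vector<float> y(static_cast<std::size_t>(N) * C);
  // GEMM (nn.Linear semantics): y[i, j] = sum_k x[i, k] * W[j, k] + b_lin[j]
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < C; ++j) {
      float acc = b_lin[j];
      for (int k = 0; k < K; ++k) acc += x[i * K + k] * W[j * K + k];
      y[i * C + j] = acc;
    }
  // GroupNorm: normalize each group of C / G channels by its own mean and
  // biased variance, then apply the per-channel affine parameters.
  const int cpg = C / G;
  for (int i = 0; i < N; ++i)
    for (int g = 0; g < G; ++g) {
      float* grp = &y[i * C + g * cpg];
      float mean = 0.f, var = 0.f;
      for (int c = 0; c < cpg; ++c) mean += grp[c];
      mean /= cpg;
      for (int c = 0; c < cpg; ++c) var += (grp[c] - mean) * (grp[c] - mean);
      var /= cpg;
      for (int c = 0; c < cpg; ++c) {
        const int j = g * cpg + c;
        grp[c] = (grp[c] - mean) / std::sqrt(var + eps) * gamma[j] + beta[j];
      }
    }
  // Min over channels gives [N, 1]; broadcasting with the (1, C, 1, 1) bias
  // gives [1, C, N, 1], so element [0, j, i, 0] = min_i + bias[j].
  std::vector<float> out(static_cast<std::size_t>(C) * N);
  for (int i = 0; i < N; ++i) {
    float m = std::numeric_limits<float>::max();
    for (int j = 0; j < C; ++j) m = std::min(m, y[i * C + j]);
    for (int j = 0; j < C; ++j) out[j * N + i] = m + bias[j];
  }
  return out;
}
```

An oracle like this can be compared element-wise against the fused CUDA path on small shapes.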
Optimization Timeline
Step 1 Details
Reasoning
We wish to fuse the last two operations of the model, the minimum reduction and the bias addition, into a single custom CUDA kernel. In the original model, after executing the GEMM (using nn.Linear) and Group Normalization (using nn.GroupNorm), the code computes:
temp = torch.min(x, dim=1, keepdim=True)[0]
and then adds a learnable bias:
output = temp + self.bias
Since torch.min(x, dim=1, keepdim=True)[0] reduces each row (sample) of the 2D tensor x (shape [N, C]) to a single value (shape [N, 1]), adding a bias treated as shape [1, C] broadcasts that value across the channel dimension. In effect, for each sample i and every channel j, the final output is:
output[i, j] = min_k(x[i, k]) + bias[0, j]
where k ranges over all C channels. We can fuse this into one kernel that, for each sample i:
1. Computes the minimum value m_i across the C channels.
2. Writes output[i, j] = m_i + bias[j] for all channels j.
This eliminates an extra kernel launch and the round trip of the intermediate [N, 1] tensor through global memory.
Our plan is as follows:
• In the Python module, we use the built-in nn.Linear (for the GEMM) and nn.GroupNorm, and then call our fused custom CUDA operator "min_bias" on the result. Because the fused kernel expects a 2D tensor x of shape (N, C) and a bias of shape (1, C), we reshape the provided bias (whose initial shape is (1, out_features, 1, 1)) to (1, out_features).
• In the CUDA kernel, we launch one block per sample (N blocks) with a block size equal to the smallest power of 2 that is greater than or equal to C (the number of channels). Each thread loads one element of x, and an in-block reduction in shared memory computes the minimum for that sample. Finally, each thread writes its output element as the broadcast minimum plus the corresponding bias value. Since CUDA limits a block to 1024 threads, this one-thread-per-channel layout requires C to be at most 1024. We use c10::cuda::CUDAGuard and c10::cuda::getCurrentCUDAStream() to manage the device and stream, and we compile with -O3 and --use_fast_math.
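To make the two-step plan concrete, here is a CPU simulation of what one block does. It is a sketch under stated assumptions:

- `block_min` and `fused_min_bias_cpu` are hypothetical names;
- "threads" are plain loop indices;
- the sequential loop inside each round stands in for the __syncthreads() barrier.

The simulation pads the shared array to the next power of 2 with FLT_MAX and halves the active range each round. It then writes the broadcast minimum plus bias. Because the padding already holds FLT_MAX, it omits the kernel's extra `(tid + s) < C` guard, which does not change the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <vector>

// Simulates one block's shared-memory tree reduction (hypothetical helper).
// Block size = smallest power of 2 >= C; slots past C hold FLT_MAX so they
// never win the min.
float block_min(const float* row, int C) {
  assert(C >= 1);
  int block = 1;
  while (block < C) block *= 2;
  std::vector<float> sdata(block);
  for (int tid = 0; tid < block; ++tid) sdata[tid] = (tid < C) ? row[tid] : FLT_MAX;
  // Each round, threads tid < s read slot tid + s (in [s, 2s)) and write slot
  // tid (< s); reads and writes never overlap, so loop order is irrelevant.
  for (int s = block / 2; s > 0; s >>= 1)
    for (int tid = 0; tid < s; ++tid) sdata[tid] = std::min(sdata[tid], sdata[tid + s]);
  return sdata[0];
}

// Fused min + bias on a row-major [N, C] input:
// out[i, j] = min_k x[i, k] + bias[j]  (one simulated block per sample).
std::vector<float> fused_min_bias_cpu(const std::vector<float>& x, int N, int C,
                                      const std::vector<float>& bias) {
  std::vector<float> out(x.size());
  for (int i = 0; i < N; ++i) {
    const float m = block_min(&x[i * C], C);
    for (int j = 0; j < C; ++j) out[i * C + j] = m + bias[j];
  }
  return out;
}
```

For C = 5 the block size is 8, so three rounds (s = 4, 2, 1) reduce eight slots to one.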
These changes should reduce kernel-launch overhead and memory traffic, speeding up the fused computation.
Before the end of my reasoning, I check: am I genuinely optimizing? Yes. Instead of launching two separate kernels for the min and the bias addition (with an extra global-memory write and read of the intermediate), we perform a single reduction per sample followed by a single pass that adds the bias. This is efficient for moderate channel counts (e.g. 256). It is a genuine fusion and should be faster than executing torch.min and the addition separately, although the gain still has to be confirmed by timing. Now, I will provide the final Python module and CUDA kernel code.
--- FIX ---
The error indicates an output shape mismatch. Our fused CUDA operator currently returns a tensor of shape [batch, out_features] (i.e. [128, 256]), but the original operations produce [1, out_features, batch, 1] (i.e. [1, 256, 128, 1]). The reason is broadcasting. torch.min with keepdim=True on the 2D tensor [batch, out_features] yields [batch, 1]. Adding the bias of shape [1, out_features, 1, 1] right-aligns the two shapes ([1, 1, batch, 1] against [1, out_features, 1, 1]) and gives [1, out_features, batch, 1]. Element [0, j, i, 0] of that result is min_k(x[i, k]) + bias[j], which is exactly element [i, j] of our kernel's output, so only the layout differs.
To fix this, we perform a simple tensor rearrangement in our C++ binding function. After the kernel produces its [batch, out_features] output, we transpose it to [out_features, batch] and then unsqueeze one dimension at the front and one at the end to obtain [1, out_features, batch, 1]. Transpose and unsqueeze return views, so this adds no copy and leaves the fused computation unchanged; the returned tensor is simply non-contiguous. I'll update the fused_min_bias function in our CUDA kernel file accordingly.
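The shape bookkeeping in the fix follows the standard broadcasting rule shared by PyTorch and NumPy. Right-align the shapes, pad the shorter one with leading 1s, and in each dimension take the size that is not 1. The hypothetical helper `broadcast_shape` below is a minimal sketch of that rule; it does not reproduce PyTorch's error reporting. Applied to [128, 1] and [1, 256, 1, 1], it gives [1, 256, 128, 1]. Against the reshaped [1, 256] bias, it gives [128, 256].

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Broadcast two shapes (hypothetical helper sketching the PyTorch/NumPy rule).
std::vector<long> broadcast_shape(std::vector<long> a, std::vector<long> b) {
  const std::size_t n = std::max(a.size(), b.size());
  // Right-align by padding the shorter shape with leading 1s.
  a.insert(a.begin(), n - a.size(), 1L);
  b.insert(b.begin(), n - b.size(), 1L);
  std::vector<long> out(n);
  for (std::size_t d = 0; d < n; ++d) {
    // Sizes must match, or one of them must be 1 (which then stretches).
    assert(a[d] == b[d] || a[d] == 1 || b[d] == 1);
    out[d] = (a[d] == 1) ? b[d] : a[d];
  }
  return out;
}
```

The first result is why the kernel's [N, C] output must be transposed and unsqueezed: the batch index ends up in dimension 2 and the channel index in dimension 1.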
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <float.h>
// This fused kernel performs a reduction (min over channels) and adds bias.
// Input x is a contiguous float32 2D tensor of shape [N, C] (row-major).
// Each block processes one sample (row) and computes:
//   m = min_{j=0}^{C-1} x[i,j]
// Then, for each channel j, it computes:
//   output[i, j] = m + bias[j]
// Thus, the kernel writes to an output tensor of shape [N, C].
// Assumes blockDim.x is a power of 2 with C <= blockDim.x <= 1024.
__global__ void fused_min_bias_kernel(const float* __restrict__ x,
                                      const float* __restrict__ bias,
                                      float* __restrict__ output,
                                      int C) {
    // Each block handles one sample (row).
    int sample = blockIdx.x;
    int tid = threadIdx.x;
    // Shared memory for the reduction (blockDim.x floats, sized at launch).
    extern __shared__ float sdata[];
    // Load one element from x; padding threads use FLT_MAX so they never win the min.
    float val = (tid < C) ? x[sample * C + tid] : FLT_MAX;
    sdata[tid] = val;
    __syncthreads();
    // Tree reduction in shared memory: halve the number of active threads each round.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s && (tid + s) < C) {
            sdata[tid] = fminf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }
    float min_val = sdata[0];
    // Broadcast the computed minimum and add bias.
    if (tid < C) {
        output[sample * C + tid] = min_val + bias[tid];
    }
}
torch::Tensor fused_min_bias(torch::Tensor x, torch::Tensor bias) {
    // Validate inputs: the kernel reads raw float pointers on x's device.
    TORCH_CHECK(x.is_cuda() && bias.is_cuda(), "x and bias must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && bias.scalar_type() == torch::kFloat32,
                "x and bias must be float32");
    TORCH_CHECK(x.dim() == 2, "Input x must be a 2D tensor");
    TORCH_CHECK(bias.dim() == 2 && bias.size(0) == 1, "Bias must have shape [1, C]");
    // Enforce device correctness.
    c10::cuda::CUDAGuard device_guard(x.device());
    // The kernel indexes with row-major offsets, so both inputs must be contiguous.
    x = x.contiguous();
    bias = bias.contiguous();
    int N = x.size(0); // batch size
    int C = x.size(1); // number of channels
    TORCH_CHECK(bias.size(1) == C, "Bias must have C = x.size(1) entries");
    // One thread per channel, and a CUDA block holds at most 1024 threads.
    TORCH_CHECK(C >= 1 && C <= 1024, "fused_min_bias supports 1 <= C <= 1024 channels");
    // Allocate output tensor with shape [N, C].
    auto output = torch::empty({N, C}, x.options());
    if (N > 0) {  // a launch with zero blocks is an invalid configuration
        // Choose block size: next power of 2 >= C.
        int block_size = 1;
        while (block_size < C) {
            block_size *= 2;
        }
        int threads = block_size;
        int blocks = N;
        size_t shared_mem = threads * sizeof(float);
        fused_min_bias_kernel<<<blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(),
            bias.data_ptr<float>(),
            output.data_ptr<float>(),
            C
        );
        // Surface launch errors (e.g. an invalid configuration) immediately.
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
    // Rearrange the output to match the expected shape:
    // the kernel produces [N, C], but the original model yields [1, C, N, 1].
    // transpose/unsqueeze return a (non-contiguous) view, so no copy is made.
    auto result = output.transpose(0, 1).unsqueeze(0).unsqueeze(-1);
    return result;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("min_bias", &fused_min_bias, "Fused kernel for computing min reduction along channels and adding bias, with output shape rearrangement");
}
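Because the kernel maps one thread to one channel, C is bounded by the 1024-thread limit of a CUDA block. One common way to lift that limit is sketched below as a hypothetical extension, not part of the original solution. The block size is capped, and each thread first folds a strided subset of channels (tid, tid + T, tid + 2T, ...) into its shared-memory slot before the same tree reduction. The CPU simulation models this with loops. In CUDA, the per-thread strided loop would run in parallel and be followed by __syncthreads(), and the bias-writing pass would use the same stride. `strided_block_min` and `kMaxThreads` are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <vector>

constexpr int kMaxThreads = 1024;  // CUDA's per-block thread limit

// Hypothetical strided variant: T = min(next power of 2 >= C, kMaxThreads).
float strided_block_min(const float* row, int C) {
  assert(C > 0);
  int T = 1;
  while (T < C && T < kMaxThreads) T *= 2;
  std::vector<float> sdata(T, FLT_MAX);
  // Phase 1: each simulated thread reduces channels tid, tid + T, tid + 2T, ...
  for (int tid = 0; tid < T; ++tid)
    for (int j = tid; j < C; j += T) sdata[tid] = std::min(sdata[tid], row[j]);
  // Phase 2: the usual power-of-2 tree reduction over the T partial minima.
  for (int s = T / 2; s > 0; s >>= 1)
    for (int tid = 0; tid < s; ++tid) sdata[tid] = std::min(sdata[tid], sdata[tid + s]);
  return sdata[0];
}
```

With C = 3000 the block stays at 1024 threads and each thread visits at most three channels, while the shared-memory footprint stays at 1024 floats.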