2.28. BMMInstanceNormResidualMultiply

Computes a linear transformation of a batch of input data via matrix multiplication, applies instance normalization by reshaping the output, then adds a residual input and multiplies the result elementwise by that same residual input.
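The following is a minimal CPU reference sketch of the full operator. It assumes row-major float data, `nn.Linear`-style weights stored as out_features × in_features, and normalization of each sample over its feature dimension with the biased variance, which is what the CUDA kernel below computes. The function name `reference_forward` and its parameters are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the whole operator on row-major data.
// Assumed shapes: x is batch x in_f, W is out_f x in_f, bias is out_f,
// y (the residual) is batch x out_f.
std::vector<float> reference_forward(const std::vector<float>& x,
                                     const std::vector<float>& W,
                                     const std::vector<float>& bias,
                                     const std::vector<float>& y,
                                     int batch, int in_f, int out_f,
                                     float eps) {
    assert(x.size() == static_cast<std::size_t>(batch) * in_f);
    assert(W.size() == static_cast<std::size_t>(out_f) * in_f);
    assert(bias.size() == static_cast<std::size_t>(out_f));
    assert(y.size() == static_cast<std::size_t>(batch) * out_f);

    std::vector<float> out(static_cast<std::size_t>(batch) * out_f);
    std::vector<float> h(out_f);
    for (int b = 0; b < batch; ++b) {
        // 1) Linear layer: h = x[b] * W^T + bias.
        for (int o = 0; o < out_f; ++o) {
            float acc = bias[o];
            for (int i = 0; i < in_f; ++i) acc += x[b * in_f + i] * W[o * in_f + i];
            h[o] = acc;
        }
        // 2) Instance norm of this sample over its features (biased variance),
        //    accumulated in double for a more accurate reference.
        double mean = 0.0, var = 0.0;
        for (int o = 0; o < out_f; ++o) mean += h[o];
        mean /= out_f;
        for (int o = 0; o < out_f; ++o) var += (h[o] - mean) * (h[o] - mean);
        var /= out_f;
        float inv_std = 1.0f / std::sqrt(static_cast<float>(var) + eps);
        // 3) Add the residual, then multiply by the same residual.
        for (int o = 0; o < out_f; ++o) {
            float yv = y[b * out_f + o];
            float norm = (h[o] - static_cast<float>(mean)) * inv_std;
            out[b * out_f + o] = (norm + yv) * yv;
        }
    }
    return out;
}
```

Because the extension's `forward` receives an already-projected 2D tensor and takes no weights, presumably only steps 2 and 3 correspond to the CUDA kernel; step 1 stays in the unchanged Python module.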

Progress: Step 2 of 2
Reward: 2.82x

Step 2 Details

Reasoning

We can optimize the fused operator further by combining the reduction (computing the per-sample mean and variance) with the subsequent normalization and residual fusion (addition and multiplication with the residual) in a single kernel. In the previous iteration we launched two kernels: one to compute the statistics and one to normalize and fuse. Here, each block processes one sample. Each thread first accumulates partial sums and sums of squares in a single strided loop; the block then reduces these partials entirely in shared memory and immediately uses the resulting mean and variance to normalize every element of the sample and apply the residual fusion. Dynamic shared memory holds two arrays of blockDim.x floats (partial sums and partial sums of squares) plus two extra floats for the computed mean and variance.

This removes one kernel launch and keeps the intermediate statistics out of global memory, so the expected benefits are lower launch overhead and better memory locality than in the previous version. Before finalizing, I ask: have the improvements genuinely helped? Yes. Global memory traffic drops because the statistics are never written there, and one launch is saved. This is a genuine fusion of two steps, not "cheating". I now update the CUDA kernel accordingly; the Python module remains the same.
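To make the per-block algorithm concrete, here is a hypothetical host-side model of a single block: the loop over `t` stands in for the block's threads, `threads` for `blockDim.x`, and the two vectors for the shared-memory arrays. It assumes a power-of-two thread count, which the halving tree reduction requires, and it clamps the variance at zero as a precaution against rounding. This is an illustrative sketch, not the kernel itself.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Stats {
    float mean;
    float var;  // biased variance E[x^2] - mean^2, clamped at zero
};

// Hypothetical model of one block computing the statistics of one sample.
Stats block_stats(const std::vector<float>& sample, int threads) {
    assert(threads > 0 && (threads & (threads - 1)) == 0);  // power of two
    const int F = static_cast<int>(sample.size());
    std::vector<float> s_sum(threads, 0.0f), s_sum_sq(threads, 0.0f);

    // Phase 1: "thread" t accumulates elements t, t + threads, t + 2*threads, ...
    for (int t = 0; t < threads; ++t) {
        for (int j = t; j < F; j += threads) {
            s_sum[t] += sample[j];
            s_sum_sq[t] += sample[j] * sample[j];
        }
    }
    // Phase 2: tree reduction. Slots t < stride read only slots >= stride,
    // so running them sequentially matches the parallel step
    // (on the GPU, __syncthreads() separates the steps).
    for (int stride = threads / 2; stride > 0; stride /= 2) {
        for (int t = 0; t < stride; ++t) {
            s_sum[t] += s_sum[t + stride];
            s_sum_sq[t] += s_sum_sq[t + stride];
        }
    }
    // Phase 3: "thread 0" turns the totals into mean and variance.
    float mean = s_sum[0] / F;
    float var = s_sum_sq[0] / F - mean * mean;
    return {mean, var < 0.0f ? 0.0f : var};
}
```

For example, `block_stats({1, 2, 3, 4}, 2)` yields a mean of 2.5 and a variance of 1.25. Changing `threads` only changes the order of the additions, which for small integer inputs like these does not affect the result.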

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
#include <math.h>

// Must be a power of two: the kernel's tree reduction halves the stride each step.
#define THREADS_PER_BLOCK 128

// Fused kernel that computes per-sample mean and variance, then normalizes and fuses with the residual.
// Each block processes one sample.
__global__ void fused_instance_norm_kernel(const float* __restrict__ x,
                                             const float* __restrict__ y,
                                             float* __restrict__ output,
                                             int F,
                                             float eps,
                                             int batch) {
    int b = blockIdx.x;
    if (b >= batch) return;
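    // Use a size_t offset so b * F cannot overflow int for large tensors.
    const size_t offset = static_cast<size_t>(b) * F;
    const float* x_sample = x + offset;
    const float* y_sample = y + offset;
    float* out_sample = output + offset;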
    
    int blockSize = blockDim.x;  // expected to be THREADS_PER_BLOCK

    // Allocate dynamic shared memory:
    // Two arrays for partial reductions (each of size blockSize) plus two extra floats for mean and variance.
    // Total shared memory size = (2 * blockSize + 2) * sizeof(float)
    extern __shared__ float s[];
    float* s_sum    = s;                  // size: blockSize
    float* s_sum_sq = s + blockSize;      // size: blockSize
    float* s_vals   = s + 2 * blockSize;  // size: 2; s_vals[0] = mean, s_vals[1] = variance

    float sum = 0.0f;
    float sum_sq = 0.0f;
    // Each thread accumulates partial sums over its assigned indices.
    for (int j = threadIdx.x; j < F; j += blockSize) {
        float val = x_sample[j];
        sum += val;
        sum_sq += val * val;
    }
    s_sum[threadIdx.x] = sum;
    s_sum_sq[threadIdx.x] = sum_sq;
    __syncthreads();
    
    // Tree reduction in shared memory; halving the stride assumes blockSize is a power of two.
    for (int stride = blockSize / 2; stride > 0; stride /= 2) {
        if (threadIdx.x < stride) {
            s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
            s_sum_sq[threadIdx.x] += s_sum_sq[threadIdx.x + stride];
        }
        __syncthreads();
    }
    
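    // Thread 0 computes the mean and the biased variance E[x^2] - mean^2.
    // Floating-point cancellation can make that difference slightly negative,
    // so clamp it at zero before it reaches sqrtf.
    if (threadIdx.x == 0) {
        float m = s_sum[0] / F;
        s_vals[0] = m;                                     // mean
        s_vals[1] = fmaxf(s_sum_sq[0] / F - m * m, 0.0f);  // variance
    }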
    __syncthreads();
    
    float mean = s_vals[0];
    float var = s_vals[1];
    // Compute the reciprocal standard deviation once per thread instead of once per element.
    float inv_std = 1.0f / sqrtf(var + eps);
    
    // Normalize and fuse: for each feature element, compute
    // normalized = (x - mean) * inv_std and then fused_val = (normalized + y) * y.
    for (int j = threadIdx.x; j < F; j += blockSize) {
        float y_val = y_sample[j];  // read the residual once
        float norm = (x_sample[j] - mean) * inv_std;
        out_sample[j] = (norm + y_val) * y_val;
    }
}

torch::Tensor forward(torch::Tensor input, torch::Tensor y, float eps) {
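    TORCH_CHECK(input.is_cuda() && y.is_cuda(), "Input and residual must be CUDA tensors");
    TORCH_CHECK(input.scalar_type() == torch::kFloat && y.scalar_type() == torch::kFloat,
                "Input and residual must be float32");
    TORCH_CHECK(input.dim() == 2, "Input tensor must be 2D (batch, features)");
    TORCH_CHECK(y.dim() == 2, "Residual tensor must be 2D (batch, features)");
    TORCH_CHECK(input.sizes() == y.sizes(), "Input and residual must have the same shape");

    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel addresses row b at base + b * F, so both tensors must be contiguous.
    input = input.contiguous();
    y = y.contiguous();

    int batch = input.size(0);
    int F = input.size(1);

    auto output = torch::empty({batch, F}, input.options());
    // A zero-sized grid is an invalid launch configuration, so return early.
    if (batch == 0) return output;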
    
    int threads = THREADS_PER_BLOCK;
    dim3 grid(batch);
    dim3 block(threads);
    // Shared memory size = (2 * threads + 2) * sizeof(float)
    size_t shared_mem_size = (2 * threads + 2) * sizeof(float);
    
    fused_instance_norm_kernel<<<grid, block, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        y.data_ptr<float>(),
        output.data_ptr<float>(),
        F,
        eps,
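        batch);
    // Surface launch-configuration errors immediately instead of at the next sync.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_instance_norm_kernel launch failed: ", cudaGetErrorString(err));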
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused instance norm, residual addition, and element-wise multiplication (single kernel)");
}
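The kernel derives the variance in one pass as E[x^2] - mean^2 (the sum of squares divided by F, minus the squared mean). That formula is cheap, but it can lose precision through cancellation when the mean is large relative to the spread, and the result can even come out slightly negative; if it drops below -eps, `sqrtf(var + eps)` returns NaN. The hypothetical helpers below contrast the one-pass formula (with an optional clamp at zero) with a two-pass formula that subtracts the mean before squaring. They are a host-side illustration under those assumptions, not part of the extension.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One-pass biased variance from sum and sum of squares, as the kernel computes it.
// If clamp is true, a negative result caused by rounding is replaced with zero.
float one_pass_var(const std::vector<float>& v, bool clamp) {
    float sum = 0.0f, sum_sq = 0.0f;
    for (float x : v) { sum += x; sum_sq += x * x; }
    const float n = static_cast<float>(v.size());
    float mean = sum / n;
    float var = sum_sq / n - mean * mean;  // cancellation-prone when |mean| >> std
    return (clamp && var < 0.0f) ? 0.0f : var;
}

// Two-pass biased variance: subtract the mean first, then square.
// Every term is a square, so the result is never negative.
float two_pass_var(const std::vector<float>& v) {
    float sum = 0.0f;
    for (float x : v) sum += x;
    const float mean = sum / static_cast<float>(v.size());
    float acc = 0.0f;
    for (float x : v) acc += (x - mean) * (x - mean);
    return acc / static_cast<float>(v.size());
}
```

Both return 1.25 for {1, 2, 3, 4}. For {10000, 10001, 10002, 10003} the two-pass form still returns exactly 1.25 in float, while the one-pass form depends on how squares near 1e8 round. In the fused kernel, clamping with `fmaxf(var, 0.0f)` is the cheapest safeguard; a two-pass or Welford-style reduction is more robust at the cost of extra work.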