2.40. Matrix Multiplication Scaling Residual Addition

Performs a matrix multiplication on the input data, scales the result by a constant factor, and then adds the unscaled result back as a residual connection.
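
As implemented in the kernel below, the matrix multiplication also includes a bias term, so the operation is effectively a linear layer followed by a scaled residual add. A minimal PyTorch reference sketch of the same computation (the function name is illustrative):

import torch
import torch.nn.functional as F

def matmul_scale_residual_reference(x, weight, bias, scaling_factor):
    # y = x @ weight^T + bias, with weight of shape [out_features, in_features]
    y = F.linear(x, weight, bias)
    # Scale the result and add the original (unscaled) result back as a residual;
    # this equals (1 + scaling_factor) * y.
    return y + scaling_factor * y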

Progress: Step 5 of 5
Reward: 5.87x

Optimization Timeline

Step 5 Details

Reasoning

We can further optimize the kernel by using fused multiply-add (FMA) operations via the CUDA intrinsic __fmaf_rn. Replacing the separate multiplication and addition with __fmaf_rn lets the hardware combine both into a single instruction, which reduces the instruction count in the innermost loop, improves throughput, and incurs only a single rounding step per operation. We apply this in both the vectorized loop (when in_features is divisible by 4) and the scalar fallback. The rest of the kernel is unchanged, including the #pragma unroll hint on the vectorized loop. The Python module needs no changes, since it already wraps the fused operator correctly.
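
For context, the unchanged Python module mentioned above is assumed to be a thin wrapper that forwards its parameters and scaling factor to the fused operator; a minimal sketch under that assumption (the class name and the fused_ops handle are illustrative, not taken from the source):

import torch
import torch.nn as nn

class FusedMatmulScaleResidual(nn.Module):
    # Thin wrapper around the compiled extension; fused_ops is assumed to be the
    # module produced by torch.utils.cpp_extension.load (see the usage sketch at the end).
    def __init__(self, in_features, out_features, scaling_factor, fused_ops):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.scaling_factor = scaling_factor
        self.fused_ops = fused_ops

    def forward(self, x):
        # Dispatches to the CUDA kernel below: forward(x, weight, bias, scaling_factor).
        return self.fused_ops.forward(x, self.weight, self.bias, self.scaling_factor)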

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Fused kernel to perform:
//   y = x * weight^T + bias   (with weight: [out_features, in_features])
//   output = y + scaling_factor * y = (1 + scaling_factor) * y
//
// Each thread computes one element of the output matrix.
__global__ void fused_gemm_scale_kernel(const float* __restrict__ x,        // [batch, in_features]
                                        const float* __restrict__ weight,   // [out_features, in_features]
                                        const float* __restrict__ bias,     // [out_features]
                                        float* __restrict__ output,         // [batch, out_features]
                                        int batch, int in_features, int out_features,
                                        float scale) {  // scale = 1 + scaling_factor
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch * out_features;
    if (index < total) {
        int i = index / out_features; // row index of x
        int j = index % out_features; // corresponds to weight row j and bias element j
        
        float sum = 0.0f;
        // Use vectorized computation if possible.
        if ((in_features & 3) == 0) {  // in_features divisible by 4
            int vec_len = in_features >> 2; // in_features / 4
            const float4* __restrict__ x_vec = reinterpret_cast<const float4*>(x + i * in_features);
            const float4* __restrict__ w_vec = reinterpret_cast<const float4*>(weight + j * in_features);
            #pragma unroll
            for (int k = 0; k < vec_len; k++) {
                float4 xv = x_vec[k];
                float4 wv = w_vec[k];
                // Use fused multiply-add for each component.
                sum = __fmaf_rn(xv.x, wv.x, sum);
                sum = __fmaf_rn(xv.y, wv.y, sum);
                sum = __fmaf_rn(xv.z, wv.z, sum);
                sum = __fmaf_rn(xv.w, wv.w, sum);
            }
        } else {
            // Fallback to scalar loop using FMA operations.
            for (int k = 0; k < in_features; k++) {
                sum = __fmaf_rn(x[i * in_features + k], weight[j * in_features + k], sum);
            }
        }
        sum = __fmaf_rn(1.0f, bias[j], sum); // sum += bias[j] using FMA.
        output[index] = sum * scale;
    }
}

torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, float scaling_factor) {
    c10::cuda::CUDAGuard device_guard(x.device());
    
    // The kernel indexes raw float pointers, so require contiguous float32 CUDA tensors.
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(), "x, weight and bias must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(), "x, weight and bias must be contiguous");
    TORCH_CHECK(x.scalar_type() == at::kFloat, "only float32 tensors are supported");
    
    int batch = x.size(0);
    int in_features = x.size(1);
    int out_features = weight.size(0);
    
    auto output = torch::empty({batch, out_features}, x.options());
    
    int total = batch * out_features;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    float scale = scaling_factor + 1.0f;
    
    fused_gemm_scale_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        batch, in_features, out_features,
        scale
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused matrix multiplication, scaling and residual addition (optimized CUDA with FMA)");
}
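
One possible way to build and sanity-check the extension, assuming the source above is saved as fused_gemm_scale.cu (the file and module names are illustrative):

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# JIT-compile the C++/CUDA source above into a loadable module.
fused_ops = load(name="fused_gemm_scale", sources=["fused_gemm_scale.cu"], verbose=True)

x = torch.randn(128, 256, device="cuda")       # in_features divisible by 4 exercises the vectorized path
weight = torch.randn(512, 256, device="cuda")
bias = torch.randn(512, device="cuda")
scaling_factor = 0.5

out = fused_ops.forward(x, weight, bias, scaling_factor)

# Compare against a plain PyTorch reference of the same operation.
y = F.linear(x, weight, bias)
print(torch.allclose(out, y + scaling_factor * y, atol=1e-4))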