2.88. GEMM, GroupNorm, and Dual Swish Multiplication

Computes a sequence of operations that begins with a general matrix multiplication (GEMM) and then normalizes the intermediate output with GroupNorm over feature groups. It then applies a Swish activation, scales the result element-wise by a learnable weight, and finally applies a second Swish activation to produce the final output.
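
For reference, the full operation sequence can be sketched with libtorch calls. This is a minimal sketch under assumptions: the parameter names (num_groups, the GroupNorm affine parameters gn_gamma/gn_beta, and the learnable scale) are illustrative, not necessarily the benchmark's exact signature.

// Hedged reference sketch of the full pipeline (illustrative names):
torch::Tensor reference_forward(const torch::Tensor& x,
                                const torch::Tensor& gemm_weight,
                                const torch::Tensor& gemm_bias,
                                const torch::Tensor& gn_gamma,
                                const torch::Tensor& gn_beta,
                                const torch::Tensor& scale,
                                int64_t num_groups) {
    auto t = torch::linear(x, gemm_weight, gemm_bias);        // GEMM
    t = torch::group_norm(t, num_groups, gn_gamma, gn_beta);  // GroupNorm
    t = torch::silu(t);                                       // first Swish: t * sigmoid(t)
    t = t * scale;                                            // learnable element-wise scale
    return torch::silu(t);                                    // second Swish
}

The fused kernel below covers only the last three stages (Swish, scale, Swish); the GEMM and GroupNorm remain separate operations.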

Progress: Step 6 of 6
Reward: 1.46x

Optimization Timeline: Step 6 Details

Reasoning

We can further optimize our fused kernel by using the __ldg intrinsic in the non-vectorized kernel when loading the weight value. Although the vectorized kernel already uses __ldg to fetch the weight vector, the non-vectorized kernel currently fetches the weight with a plain global load. Using __ldg routes the access through the read-only data cache, which improves cache efficiency because every row rereads the same per-column weight values. This is a genuine micro-optimization and not a hack: it simply uses the read-only cache on devices where that is beneficial. The rest of the code is already optimized, employing an inlined sigmoid function, vectorized processing when possible, and the c10::cuda utilities for stream and device management. The updated code below adds this improvement to the non-vectorized kernel.
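
Concretely, the change is a one-line substitution in the scalar kernel's inner loop:

// Before: a plain global load through the normal cache hierarchy.
float weight_val = weight[col];
// After: __ldg routes the load through the read-only data cache, which helps
// because threads across many rows reread the same per-column weights.
float weight_val = __ldg(&weight[col]);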

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>  // std::uintptr_t for the alignment check below

// Forced inline fast sigmoid function.
__device__ __forceinline__ float sigmoidf(float z) {
    return 1.0f / (1.0f + __expf(-z));
}

// Non-vectorized fused elementwise kernel: each thread processes one element.
__global__ void fused_elementwise_kernel(const float* __restrict__ x,
                                           const float* __restrict__ weight,
                                           float* __restrict__ y,
                                           int total,
                                           int out_features) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = idx; i < total; i += stride) {
        float a = x[i];
        float t = a * sigmoidf(a);
        int col = i % out_features;
        // Use __ldg to load the weight from read-only cache.
        float weight_val = __ldg(&weight[col]);
        float t2 = t * weight_val;
        y[i] = t2 * sigmoidf(t2);
    }
}

// Vectorized fused elementwise kernel using float4.
// Assumes out_features is divisible by 4.
__global__ void fused_elementwise_vec_kernel(const float4* __restrict__ x,
                                               const float4* __restrict__ weight,
                                               float4* __restrict__ y,
                                               int total_vec,
                                               int out_features) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    int vec_out = out_features / 4; // Number of float4 elements per row.
    for (int i = idx; i < total_vec; i += stride) {
        int vec_col = i % vec_out;  // covers columns vec_col*4 .. vec_col*4+3
        
        // Load vectorized data using read-only cache.
        float4 a_vec = __ldg(&x[i]);
        
        // Compute first swish activation: t = a * sigmoid(a)
        float t0 = a_vec.x * sigmoidf(a_vec.x);
        float t1 = a_vec.y * sigmoidf(a_vec.y);
        float t2 = a_vec.z * sigmoidf(a_vec.z);
        float t3 = a_vec.w * sigmoidf(a_vec.w);
        
        // Load the corresponding weight vector through the read-only cache.
        float4 w_vec = __ldg(&weight[vec_col]);
        
        // Multiply elementwise.
        float r0 = t0 * w_vec.x;
        float r1 = t1 * w_vec.y;
        float r2 = t2 * w_vec.z;
        float r3 = t3 * w_vec.w;
        
        // Second swish activation: y = r * sigmoid(r)
        float4 res;
        res.x = r0 * sigmoidf(r0);
        res.y = r1 * sigmoidf(r1);
        res.z = r2 * sigmoidf(r2);
        res.w = r3 * sigmoidf(r3);
        y[i] = res;
    }
}

torch::Tensor fused(torch::Tensor x, torch::Tensor weight) {
    // x: [batch, out_features]
    // weight: [out_features]
    c10::cuda::CUDAGuard device_guard(x.device());
    TORCH_CHECK(x.is_cuda() && weight.is_cuda(), "x and weight must be CUDA tensors");
    TORCH_CHECK(x.dim() == 2, "x must be a 2D tensor");
    TORCH_CHECK(weight.dim() == 1, "weight must be a 1D tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32,
                "x and weight must be float32 tensors");
    TORCH_CHECK(x.is_contiguous() && weight.is_contiguous(),
                "x and weight must be contiguous");
    
    int batch = x.size(0);
    int out_features = x.size(1);
    TORCH_CHECK(weight.size(0) == out_features, "weight size must equal x's second dimension");
    
    int total = batch * out_features;
    auto y = torch::empty({batch, out_features}, x.options());
    
    // Use the vectorized kernel when out_features is divisible by 4 and both
    // input pointers are 16-byte aligned, as required for float4 loads; the
    // freshly allocated output y always satisfies the alignment requirement.
    bool aligned = reinterpret_cast<std::uintptr_t>(x.data_ptr<float>()) % 16 == 0 &&
                   reinterpret_cast<std::uintptr_t>(weight.data_ptr<float>()) % 16 == 0;
    if (out_features % 4 == 0 && aligned) {
        int total_vec = total / 4;
        const float4* x_vec = reinterpret_cast<const float4*>(x.data_ptr<float>());
        float4* y_vec = reinterpret_cast<float4*>(y.data_ptr<float>());
        const float4* weight_vec = reinterpret_cast<const float4*>(weight.data_ptr<float>());
        int threads = 256;
        int blocks = (total_vec + threads - 1) / threads;
        fused_elementwise_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x_vec, weight_vec, y_vec, total_vec, out_features);
    } else {
        int threads = 256;
        int blocks = (total + threads - 1) / threads;
        fused_elementwise_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(), weight.data_ptr<float>(), y.data_ptr<float>(), total, out_features);
    }
    
    return y;
}
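
// As a quick sanity check, the fused path can be compared against the
// equivalent composed ATen ops. This harness is a hedged sketch for
// illustration, not part of the original extension.
torch::Tensor reference_elementwise(const torch::Tensor& x, const torch::Tensor& w) {
    auto t = torch::silu(x);   // first Swish: x * sigmoid(x)
    t = t * w;                 // element-wise scale, broadcast over rows
    return torch::silu(t);     // second Swish
}
// Example usage, with loose tolerances because __expf is a fast approximation:
//   bool ok = torch::allclose(fused(x, w), reference_elementwise(x, w), 1e-4, 1e-5);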

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused", &fused, "Fused elementwise operations: first swish, multiply, second swish (vectorized version included)");
}