2.21. Convolution with Bias, Scaling, Sigmoid, and Group Normalization

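Performs a sequence of operations on 2D input data. A convolution produces feature maps, and a learnable per-channel bias is added. The result is multiplied by a learnable per-channel scale and passed through a sigmoid activation. Finally, group normalization is computed over groups of channels. The CUDA extension below fuses only the post-convolution steps (bias, scale, sigmoid, group normalization) and takes the convolution output as its input.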

Progress

Step 8 of 8
Reward
1.66x

Optimization Timeline

Step 8 Details

Reasoning

We can further optimize the kernel by replacing the standard expf() call in the sigmoid helper with the __expf() intrinsic. __expf() is a faster but lower-accuracy hardware approximation of the exponential. It therefore reduces the per-element cost of the sigmoid at the price of slightly larger rounding error. The actual end-to-end gain depends on whether the kernel is limited by arithmetic or by memory bandwidth. Other optimizations (such as loop unrolling) are already applied.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <limits>
#include <math.h>

#if !defined(EPS)
// Default epsilon constant (1e-5f); a build may predefine EPS to override it.
// NOTE(review): EPS is not referenced by the code visible in this file — the
// kernel and host wrapper receive eps as a runtime argument. Confirm before relying on it.
#define EPS 1e-5f
#endif

// Logistic sigmoid, 1 / (1 + e^(-x)), evaluated with the __expf fast-math
// intrinsic. __expf trades some accuracy for speed compared with expf.
// Device-only: __expf is not available in host code.
__device__ inline float sigmoidf(float x) {
    const float exp_neg_x = __expf(-x);
    const float denom = exp_neg_x + 1.0f;
    return 1.0f / denom;
}

// Fused post-convolution kernel: per-channel bias add, per-channel scale,
// sigmoid, then group normalization (no affine weight/bias is applied).
//
// Launch layout: one block per sample (blockIdx.x = sample index). Each block
// walks its [C, H, W] slice of `input` / `output` with a thread-strided loop
// (idx = threadIdx.x, threadIdx.x + blockDim.x, ...).
//
// Per element of channel c:
//    y = sigmoid((x + bias[c]) * scale[c])
// Each group covers C / num_groups consecutive channels. The block reduces
// sum(y) and sum(y^2) per group (warp shuffles, then shared memory) and writes
//    out = (y - mean) / sqrt(var + eps)
// where var is the biased variance E[y^2] - mean^2, clamped at 0.
//
// Dynamic shared memory required:
//    (2 * (blockDim.x / 32) * num_groups + 2 * num_groups) floats.
//
// Preconditions: 1 <= num_groups <= 16, C % num_groups == 0,
// blockDim.x % 32 == 0, and C * H * W fits in an int. If any of the first
// three is violated the whole block returns without writing `output`.
extern "C"
__global__ void fused_conv_post_kernel(const float* __restrict__ input,
                                       const float* __restrict__ bias,    // shape: [C]
                                       const float* __restrict__ scale,   // shape: [C]
                                       float* __restrict__ output,
                                       int C, int H, int W,
                                       int num_groups,
                                       float eps) {
    // Capacity of the per-thread / per-warp accumulator arrays below.
    constexpr int kMaxGroups = 16;
    constexpr int kWarpSize = 32;
    constexpr unsigned int kFullMask = 0xffffffffu;

    const int T = blockDim.x;  // total threads per block
    const int tid = threadIdx.x;

    // Reject configurations that would index past the accumulator arrays or
    // the shared-memory regions, or run full-mask shuffles on a partial warp.
    // The condition is identical for every thread in the block, so this early
    // return cannot leave any thread waiting at a __syncthreads().
    if (num_groups < 1 || num_groups > kMaxGroups || C % num_groups != 0 ||
        T % kWarpSize != 0) {
        return;
    }

    const int HW = H * W;
    const int total = C * HW;  // elements per sample
    // 64-bit offset: blockIdx.x * total can exceed INT_MAX for large batches.
    const size_t sample_offset =
        static_cast<size_t>(blockIdx.x) * static_cast<size_t>(total);
    const float* in = input + sample_offset;
    float* out = output + sample_offset;

    const int channels_per_group = C / num_groups;
    const int group_size = channels_per_group * HW;

    // Per-thread partial sums for each group.
    float local_sum[kMaxGroups] = {0.0f};
    float local_sum_sq[kMaxGroups] = {0.0f};

    // First pass: compute y, store it as an intermediate, accumulate sums.
    for (int idx = tid; idx < total; idx += T) {
        const int c = idx / HW;                 // channel index
        const int g = c / channels_per_group;   // < num_groups because C % num_groups == 0
        const float y = sigmoidf((in[idx] + bias[c]) * scale[c]);
        out[idx] = y;
        local_sum[g] += y;
        local_sum_sq[g] += y * y;
    }

    const int lane = tid & (kWarpSize - 1);
    const int warp_id = tid / kWarpSize;
    const int num_warps = T / kWarpSize;

    // Warp-level reduction of each group's partial sums. `g < num_groups` is
    // uniform across the warp, so every lane takes part in each shuffle.
    float warp_sum[kMaxGroups];
    float warp_sum_sq[kMaxGroups];
    #pragma unroll
    for (int g = 0; g < kMaxGroups; g++) {
        if (g < num_groups) {
            float sum = local_sum[g];
            float sum_sq = local_sum_sq[g];
            #pragma unroll
            for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
                sum += __shfl_down_sync(kFullMask, sum, offset);
                sum_sq += __shfl_down_sync(kFullMask, sum_sq, offset);
            }
            warp_sum[g] = sum;       // complete only in lane 0
            warp_sum_sq[g] = sum_sq;
        }
    }

    // Dynamic shared-memory layout:
    // - warp_sums:    num_warps * num_groups floats
    // - warp_sums_sq: num_warps * num_groups floats
    // - group_mean:   num_groups floats
    // - group_invstd: num_groups floats
    extern __shared__ float s_data[];
    float* warp_sums = s_data;
    float* warp_sums_sq = s_data + num_warps * num_groups;
    float* group_mean = s_data + 2 * num_warps * num_groups;
    float* group_invstd = group_mean + num_groups;

    // Each warp's lane 0 publishes that warp's per-group totals.
    if (lane == 0) {
        #pragma unroll
        for (int g = 0; g < kMaxGroups; g++) {
            if (g < num_groups) {
                warp_sums[warp_id * num_groups + g] = warp_sum[g];
                warp_sums_sq[warp_id * num_groups + g] = warp_sum_sq[g];
            }
        }
    }
    __syncthreads();

    // Thread 0 combines the warp totals and derives each group's statistics.
    if (tid == 0) {
        float final_sum[kMaxGroups] = {0.0f};
        float final_sum_sq[kMaxGroups] = {0.0f};
        for (int w = 0; w < num_warps; w++) {
            #pragma unroll
            for (int g = 0; g < kMaxGroups; g++) {
                if (g < num_groups) {
                    final_sum[g] += warp_sums[w * num_groups + g];
                    final_sum_sq[g] += warp_sums_sq[w * num_groups + g];
                }
            }
        }
        #pragma unroll
        for (int g = 0; g < kMaxGroups; g++) {
            if (g < num_groups) {
                const float mean = final_sum[g] / group_size;
                // E[y^2] - mean^2 can dip slightly below zero from float
                // cancellation; clamp so sqrtf never sees a negative argument.
                const float var = fmaxf(final_sum_sq[g] / group_size - mean * mean, 0.0f);
                group_mean[g] = mean;
                group_invstd[g] = 1.0f / sqrtf(var + eps);
            }
        }
    }
    __syncthreads();

    // Second pass: normalize each stored y with its group's statistics.
    for (int idx = tid; idx < total; idx += T) {
        const int c = idx / HW;
        const int g = c / channels_per_group;
        const float y = out[idx];
        out[idx] = (y - group_mean[g]) * group_invstd[g];
    }
}

// Host wrapper: validates inputs and launches fused_conv_post_kernel on the
// current CUDA stream with one 256-thread block per sample.
//
//   x          : [N, C, H, W] float32 CUDA tensor (made contiguous if needed).
//   bias       : float32 tensor with C elements on x's device (per-channel,
//                e.g. [C] or [C, 1, 1]).
//   scale      : float32 tensor with C elements on x's device (per-channel).
//   num_groups : number of groups; must be in [1, 16] (the kernel's
//                accumulator capacity) and must divide C.
//   eps        : added to each group's variance before the inverse sqrt.
// Returns a new contiguous tensor with the same shape as x.
// Throws c10::Error (via TORCH_CHECK) on invalid input or a failed launch.
torch::Tensor fused_forward(torch::Tensor x, torch::Tensor bias, torch::Tensor scale, int num_groups, float eps) {
    TORCH_CHECK(x.is_cuda(), "fused_forward: x must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 4, "fused_forward: x must be 4-D [N, C, H, W], got ", x.dim(), " dims");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "fused_forward: x must be float32");
    TORCH_CHECK(bias.device() == x.device() && scale.device() == x.device(),
                "fused_forward: bias and scale must be on the same device as x");
    TORCH_CHECK(bias.scalar_type() == torch::kFloat32 && scale.scalar_type() == torch::kFloat32,
                "fused_forward: bias and scale must be float32");

    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes all three tensors as dense row-major buffers, so
    // non-contiguous (e.g. channels_last) inputs must be densified first.
    const torch::Tensor x_c = x.contiguous();
    const torch::Tensor bias_c = bias.contiguous();
    const torch::Tensor scale_c = scale.contiguous();

    const int64_t N = x_c.size(0);
    const int64_t C = x_c.size(1);
    const int64_t H = x_c.size(2);
    const int64_t W = x_c.size(3);

    TORCH_CHECK(bias_c.numel() == C && scale_c.numel() == C,
                "fused_forward: bias and scale must each have C = ", C, " elements");
    TORCH_CHECK(num_groups >= 1 && num_groups <= 16,
                "fused_forward: num_groups must be in [1, 16], got ", num_groups);
    TORCH_CHECK(C % num_groups == 0,
                "fused_forward: C (", C, ") must be divisible by num_groups (", num_groups, ")");
    // The kernel uses 32-bit indices within one sample.
    TORCH_CHECK(C * H * W <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "fused_forward: C*H*W = ", C * H * W, " exceeds 32-bit per-sample indexing");

    auto output = torch::empty_like(x_c);
    if (output.numel() == 0) {
        // Nothing to compute; a zero-sized grid would be an invalid launch.
        return output;
    }

    constexpr int kThreads = 256;  // must stay a multiple of 32 (kernel precondition)
    const int num_warps = kThreads / 32;
    // Kernel layout: per-warp sums and sums of squares for every group,
    // followed by per-group mean and inverse std.
    const size_t shared_bytes =
        static_cast<size_t>(2 * num_warps * num_groups + 2 * num_groups) * sizeof(float);

    const dim3 grid(static_cast<unsigned int>(N));
    const dim3 block(kThreads);
    fused_conv_post_kernel<<<grid, block, shared_bytes, c10::cuda::getCurrentCUDAStream()>>>(
        x_c.data_ptr<float>(),
        bias_c.data_ptr<float>(),
        scale_c.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(C), static_cast<int>(H), static_cast<int>(W), num_groups, eps
    );
    // Surface launch-configuration errors here rather than at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

// Python bindings: exposes fused_forward as `forward` in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    ext.def("forward",
            &fused_forward,
            "Fused bias addition, scaling, sigmoid, and group normalization with warp-level reduction and unrolling");
}