2.56. Matmul Sigmoid Sum Operation

Performs a linear transformation on input data via matrix multiplication, applies the sigmoid activation to the result, and then sums the activated values along the feature dimension to produce a single output for each instance.
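Written out with notation of our own choosing (B samples, K = input_size, H = hidden_size, and the weight stored one row per hidden neuron, which is how the kernel below indexes it), each sample's output is:

```latex
y_i = \sum_{j=0}^{H-1} \sigma\!\left( b_j + \sum_{k=0}^{K-1} x_{ik} W_{jk} \right),
\qquad \sigma(z) = \frac{1}{1 + e^{-z}},
\qquad x \in \mathbb{R}^{B \times K},\; W \in \mathbb{R}^{H \times K},\; b \in \mathbb{R}^{H},\; y \in \mathbb{R}^{B}.
```

The inner sum over k is the matrix multiplication, sigma is the activation, and the outer sum over j is the reduction along the feature dimension.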

Progress: step 2 of 2
Reward: 3.56x

Step 2 Details

Reasoning

We can further optimize our fused kernel by reducing the overhead of the block-level reduction. Instead of a shared-memory reduction loop over all threads, each warp first reduces its own values with __shfl_down_sync, and only one value per warp (written by the warp's lane 0) is stored in shared memory; a small final reduction over those per-warp values then yields the sample's sum. This leaves a single __syncthreads() barrier and replaces most shared-memory traffic with register-to-register warp shuffles.

Since our block size is 128 (a power of two, and therefore a multiple of the 32-thread warp size), there are blockDim.x / warpSize = 4 warp leaders. Each thread accumulates a local sum over its assigned neurons (computing the linear transformation, applying the sigmoid, and summing). Each warp then reduces those local sums with __shfl_down_sync, and finally the first warp reduces the per-warp partial sums stored in shared memory. This should lower latency further compared with our previous shared-memory reduction. The Python module remains nearly unchanged; only the CUDA kernel switches to the warp-shuffle reduction. The updated code follows.

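#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cuda_runtime.h>
#include <math.h>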

// Fused kernel: for each sample i compute
//   out[i] = sum over j of sigmoid(dot(x[i,:], weight[j,:]) + bias[j]).
// Each block handles one sample; each thread processes a strided subset of neurons.
// Partial sums are reduced with warp shuffles, so blockDim.x should be a power
// of two no smaller than 32 (warpSize), e.g. 128.
__global__ void fused_linear_sigmoid_sum_kernel(const float* __restrict__ x,
                                                const float* __restrict__ weight,
                                                const float* __restrict__ bias,
                                                float* __restrict__ out,
                                                int input_size,
                                                int hidden_size) {
    // Each block corresponds to one sample.
    int sample = blockIdx.x;
    const float* x_ptr = x + sample * input_size;
    int tid = threadIdx.x;
    float local_sum = 0.0f;

    // Process hidden neurons in a strided manner.
    for (int j = tid; j < hidden_size; j += blockDim.x) {
        float dot = bias[j];  // start with the bias for neuron j.
        const float* weight_row = weight + j * input_size;  // weight row for neuron j.
        for (int k = 0; k < input_size; k++) {
            dot += x_ptr[k] * weight_row[k];
        }
        // Apply sigmoid activation.
        float sig = 1.0f / (1.0f + expf(-dot));
        local_sum += sig;
    }

    // Warp-level reduction using __shfl_down_sync.
    unsigned int mask = 0xffffffff;
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        local_sum += __shfl_down_sync(mask, local_sum, offset);
    }

    // Now, the first thread of each warp writes its partial sum to shared memory.
    int lane = tid & (warpSize - 1);
    int warp_id = tid / warpSize;
    __shared__ float warp_sum[32];  // max number of warps per block; 128/32 = 4.
    if (lane == 0) {
        warp_sum[warp_id] = local_sum;
    }
    __syncthreads();

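    // Final reduction across warp sums by the first warp. All 32 lanes of warp 0
    // must execute the shuffle because the full mask names every lane; lanes
    // without a partial sum contribute 0, which also keeps the result correct
    // when the number of warps is not a power of two.
    if (warp_id == 0) {
        int num_warps = (blockDim.x + warpSize - 1) / warpSize;
        float final_sum = (lane < num_warps) ? warp_sum[lane] : 0.0f;
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            final_sum += __shfl_down_sync(mask, final_sum, offset);
        }
        if (lane == 0) {
            out[sample] = final_sum;
        }
    }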
}

torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int input_size, int hidden_size) {
    c10::cuda::CUDAGuard device_guard(x.device());
    
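    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(),
                "inputs must be contiguous");
    TORCH_CHECK(x.dim() == 2 && x.size(1) == input_size, "x must have shape [batch, input_size]");
    TORCH_CHECK(weight.dim() == 2 && weight.size(0) == hidden_size && weight.size(1) == input_size,
                "weight must have shape [hidden_size, input_size]");
    TORCH_CHECK(bias.dim() == 1 && bias.size(0) == hidden_size, "bias must have shape [hidden_size]");

    int batch_size = x.size(0);
    // Allocate output tensor of shape [batch_size].
    auto output = torch::empty({batch_size}, x.options());
    if (batch_size == 0) {
        return output;  // a launch with zero blocks would be an invalid configuration
    }

    int threads = 128;  // blockDim.x; a power of 2, at least 32 (warpSize)
    int blocks = batch_size;  // one block per sample
    // warp_sum is a static __shared__ array, so no dynamic shared memory is requested.
    size_t shared_mem_size = 0;

    fused_linear_sigmoid_sum_kernel<<<blocks, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        input_size,
        hidden_size
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;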
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused linear, sigmoid, and sum kernel with warp shuffle reduction");
}