2.56. Matmul Sigmoid Sum Operation
Performs a linear transformation on input data via matrix multiplication, applies the sigmoid activation to the result, and then sums the activated values along the feature dimension to produce a single output for each instance.
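In plain PyTorch the same computation is a short expression; the sketch below is only a reference for checking the fused kernel, with the tensor names x, weight, and bias (shapes [batch_size, input_size], [hidden_size, input_size], and [hidden_size]) chosen for illustration.

import torch

def reference_forward(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Linear transformation: (batch_size, input_size) @ (input_size, hidden_size) + bias.
    z = x @ weight.t() + bias
    # Sigmoid activation, then sum over the hidden (feature) dimension.
    return torch.sigmoid(z).sum(dim=1)  # shape: (batch_size,)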
Optimization Timeline
Step 2 Details
Reasoning
We can further optimize our fused kernel by reducing the overhead of the block-level reduction. Instead of running a shared-memory reduction loop over all threads, we can use warp-level reduction via __shfl_down_sync to reduce each warp's partial results, store only one value per warp in shared memory, and then perform a small final reduction among the warp leaders to obtain the block's sum. This cuts the number of __syncthreads() calls and relies on the faster warp shuffle operations. Since our block size is 128 (a power of two), there are blockDim.x / warpSize = 4 warp leaders. Each thread accumulates a local sum over its assigned neurons (computing the linear transformation, applying the sigmoid, and summing), the partial sums are then reduced within each warp using __shfl_down_sync, and finally the first warp reduces the per-warp partial sums stored in shared memory. This should further lower latency compared to our previous shared-memory reduction. The Python module remains nearly unchanged, while the CUDA kernel is modified to use warp shuffle reduction. We now provide the updated code.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // declares c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
#include <math.h>
// Fused kernel: for each sample i, compute for each neuron j
//   dot = <x[i,:], weight[j,:]> + bias[j], then sigmoid(dot),
// and sum the activated values over j.
// Each block corresponds to one sample; each thread processes multiple neurons.
// We use warp shuffle to reduce partial sums quickly.
__global__ void fused_linear_sigmoid_sum_kernel(const float* __restrict__ x,
                                                const float* __restrict__ weight,
                                                const float* __restrict__ bias,
                                                float* __restrict__ out,
                                                int input_size,
                                                int hidden_size) {
    // Each block corresponds to one sample.
    int sample = blockIdx.x;
    const float* x_ptr = x + sample * input_size;
    int tid = threadIdx.x;
    float local_sum = 0.0f;
    // Process hidden neurons in a strided manner.
    for (int j = tid; j < hidden_size; j += blockDim.x) {
        float dot = bias[j];  // start with the bias for neuron j.
        const float* weight_row = weight + j * input_size;  // weight row for neuron j.
        for (int k = 0; k < input_size; k++) {
            dot += x_ptr[k] * weight_row[k];
        }
        // Apply sigmoid activation.
        float sig = 1.0f / (1.0f + expf(-dot));
        local_sum += sig;
    }
    // Warp-level reduction using __shfl_down_sync.
    unsigned int mask = 0xffffffff;
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        local_sum += __shfl_down_sync(mask, local_sum, offset);
    }
    // The first thread of each warp writes its partial sum to shared memory.
    int lane = tid & (warpSize - 1);
    int warp_id = tid / warpSize;
    __shared__ float warp_sum[32];  // sized for the maximum 32 warps per block; with 128 threads only 4 entries are used.
    if (lane == 0) {
        warp_sum[warp_id] = local_sum;
    }
    __syncthreads();
    // Final reduction across the per-warp sums, performed by the first warp.
    // All 32 lanes of warp 0 participate so the full shuffle mask stays valid;
    // lanes beyond the number of warps contribute zero.
    int num_warps = (blockDim.x + warpSize - 1) / warpSize;
    if (warp_id == 0) {
        float final_sum = (lane < num_warps) ? warp_sum[lane] : 0.0f;
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            final_sum += __shfl_down_sync(mask, final_sum, offset);
        }
        if (lane == 0) {
            out[sample] = final_sum;
        }
    }
}
torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int input_size, int hidden_size) {
    c10::cuda::CUDAGuard device_guard(x.device());
    int batch_size = x.size(0);
    // Allocate output tensor of shape [batch_size].
    auto output = torch::empty({batch_size}, x.options());
    int threads = 128;        // blockDim.x; a power of two and a multiple of warpSize.
    int blocks = batch_size;  // one block per sample.
    // The per-warp partial sums live in a statically sized __shared__ array inside
    // the kernel, so no dynamic shared memory is requested at launch.
    fused_linear_sigmoid_sum_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        input_size,
        hidden_size
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused linear, sigmoid, and sum kernel with warp shuffle reduction");
}
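For completeness, the extension can be built and checked against the PyTorch reference roughly as follows. This is a hedged usage sketch: the file name fused_kernel.cu, the extension name, and the tensor sizes are assumptions, not part of the original code.

import torch
from torch.utils.cpp_extension import load

# Compile the CUDA source above into a loadable extension (file name assumed).
fused = load(name="fused_linear_sigmoid_sum", sources=["fused_kernel.cu"], verbose=True)

batch_size, input_size, hidden_size = 16, 64, 32
x = torch.randn(batch_size, input_size, device="cuda", dtype=torch.float32)
weight = torch.randn(hidden_size, input_size, device="cuda", dtype=torch.float32)
bias = torch.randn(hidden_size, device="cuda", dtype=torch.float32)

# Fused kernel vs. the plain PyTorch expression of the same operation.
out = fused.forward(x, weight, bias, input_size, hidden_size)
ref = torch.sigmoid(x @ weight.t() + bias).sum(dim=1)
print(torch.allclose(out, ref, atol=1e-5))  # expected: True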