2.45. Matrix Multiplication with Sigmoid and LogSumExp
Computes a sequence of operations: a linear matrix multiplication, an element-wise sigmoid activation, a sum over the feature dimension, and finally the log-sum-exp (logsumexp) of the row sums, producing a scalar output.
Optimization Timeline
Step 1 Details
Reasoning
We want to fuse a series of operations: a matrix multiplication (from a linear layer), element-wise sigmoid, row-wise summation, and then logsumexp over the row sums. Mathematically, for each row i:

S[i] = Σ[j=0..(hidden_size-1)] sigmoid( dot(x[i], W[j]) + b[j] )

and the final output is:

output = log( Σ[i=0..(batch-1)] exp( S[i] ) )

To optimize memory and compute efficiency we perform the work in two fused kernels:

1. Kernel "fused_gemm_sigmoid_sum_kernel" computes S for every row in parallel. We assign one block per row and let the threads within the block compute partial sums over the hidden_size dimension. A reduction in shared memory yields S[i].

2. Kernel "fused_logsumexp_kernel" computes the final logsumexp over S. We first compute the maximum value in S, then compute sum_exp = Σ exp(S[i] - max), and finally output = max + log(sum_exp). We launch a single block for this reduction.

Since inter-block synchronization is not allowed, we use two separate kernels. This approach fuses all operations inside our custom extension and should be faster than launching the individual PyTorch operators, because it reduces both global-memory traffic and the number of kernel launches. We use c10::cuda::CUDAGuard and getCurrentCUDAStream to ensure proper device and stream handling. The Python module loads the extension with the "-O3" and "--use_fast_math" extra flags. The CUDAModel class stores the linear1 weight and bias as parameters and calls our fused operator in its forward() method; a sketch of this wrapper appears after the CUDA source below.

Before finishing, I ask: have my improvements genuinely helped? Yes, we are reducing memory traffic and kernel launches compared to separate gemm/sigmoid/sum/logsumexp operations. Now we provide the complete Python and CUDA code.

--- FIX ---

The error indicates that our fused forward function returns a tensor of shape [1] (a 1D tensor with one element) while the expected output is a scalar (a 0D tensor, shape []). We can fix this by squeezing the output tensor before returning it. In the CUDA extension's forward() function, after launching the reduction kernel and filling an output tensor of shape {1}, we call output.squeeze() to remove the singleton dimension. This produces a 0-dimensional tensor, which matches the expected output shape.
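As a sanity check, here is a minimal unfused PyTorch sketch of the same computation. The function name reference_forward is purely illustrative and not part of the extension; it is only useful for validating the fused kernels' output.

import torch

def reference_forward(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Linear layer: [batch, input_size] @ weight.T + bias -> [batch, hidden_size]
    h = torch.nn.functional.linear(x, weight, bias)
    # Element-wise sigmoid, then sum over the hidden (feature) dimension -> S of shape [batch]
    s = torch.sigmoid(h).sum(dim=1)
    # logsumexp over the row sums -> scalar (0-dim tensor)
    return torch.logsumexp(s, dim=0)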
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
#include <math.h>
// Kernel 1: Compute row-wise summation after GEMM and Sigmoid.
// For each row i in [0, batch), compute:
// S[i] = sum_{j=0}^{hidden_size-1} sigmoid( dot(x[i], W[j]) + b[j] )
//
// We launch one block per row and use blockDim.x threads per block.
__global__ void fused_gemm_sigmoid_sum_kernel(const float* __restrict__ x,
const float* __restrict__ weight,
const float* __restrict__ bias,
float* __restrict__ S,
int batch,
int input_size,
int hidden_size) {
// Each block corresponds to one row
int row = blockIdx.x;
int tid = threadIdx.x;
float partial = 0.0f;
// Loop over hidden_size in a strided manner.
for (int j = tid; j < hidden_size; j += blockDim.x) {
float dot = 0.0f;
// Compute dot product: x[row] (length input_size) dot weight[j] (length input_size)
for (int k = 0; k < input_size; k++) {
dot += x[row * input_size + k] * weight[j * input_size + k];
}
dot += bias[j];
// Sigmoid: 1 / (1 + exp(-dot))
float sig = 1.0f / (1.0f + expf(-dot));
partial += sig;
}
// Reduction within block using shared memory.
extern __shared__ float shmem[];
shmem[tid] = partial;
__syncthreads();
// Standard reduction to sum all thread partials.
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
shmem[tid] += shmem[tid + s];
}
__syncthreads();
}
if (tid == 0) {
S[row] = shmem[0];
}
}
// Kernel 2: Compute logsumexp over the S vector of length batch.
// Reduction: result = max_S + log( sum(exp(S[i]-max_S)) )
__global__ void fused_logsumexp_kernel(const float* __restrict__ S,
float* __restrict__ output,
int batch) {
extern __shared__ float shmem[];
int tid = threadIdx.x;
// Load S into shared memory (pad with -INFINITY for threads >= batch).
float val = (tid < batch) ? S[tid] : -INFINITY;
shmem[tid] = val;
__syncthreads();
// Compute maximum of S.
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
shmem[tid] = fmaxf(shmem[tid], shmem[tid + s]);
}
__syncthreads();
}
float max_val = shmem[0];
__syncthreads();
// Compute sum(exp(S[i]-max_val)).
float sum_exp = (tid < batch) ? expf(S[tid] - max_val) : 0.0f;
shmem[tid] = sum_exp;
__syncthreads();
for (int s = blockDim.x/2; s > 0; s >>= 1) {
if (tid < s) {
shmem[tid] += shmem[tid + s];
}
__syncthreads();
}
if (tid == 0) {
output[0] = max_val + logf(shmem[0]);
}
}
// The forward function: fuses the above two kernels.
torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias) {
// Ensure we are on the correct CUDA device.
c10::cuda::CUDAGuard device_guard(x.device());
// Check devices, dtypes, layout, and dimensions (the kernels assume contiguous float32 CUDA tensors).
TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(), "All inputs must be CUDA tensors");
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "Only float32 inputs are supported");
TORCH_CHECK(x.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(), "Inputs must be contiguous");
TORCH_CHECK(x.dim() == 2, "Input x must be 2D: [batch, input_size]");
TORCH_CHECK(weight.dim() == 2, "Weight must be 2D: [hidden_size, input_size]");
TORCH_CHECK(bias.dim() == 1, "Bias must be 1D: [hidden_size]");
int batch = x.size(0);
int input_size = x.size(1);
int hidden_size = weight.size(0);
// Allocate temporary tensor S of shape [batch] to store row-wise sums.
auto options = x.options();
auto S = torch::empty({batch}, options);
// Launch Kernel 1: one block per row; blockDim must be a power of two for the tree reduction.
int threads_per_block = 128;
dim3 grid(batch);
dim3 block(threads_per_block);
size_t shared_mem_bytes = threads_per_block * sizeof(float);
fused_gemm_sigmoid_sum_kernel<<<grid, block, shared_mem_bytes, c10::cuda::getCurrentCUDAStream()>>>(
x.data_ptr<float>(),
weight.data_ptr<float>(),
bias.data_ptr<float>(),
S.data_ptr<float>(),
batch,
input_size,
hidden_size
);
// Launch Kernel 2: Reduction for logsumexp.
// Determine block dimension as a power of two >= batch.
// The single-block reduction requires batch <= 1024 (maximum threads per block).
int threads = 1;
while (threads < batch) threads *= 2;
TORCH_CHECK(threads <= 1024, "fused_logsumexp_kernel uses a single block; batch must be <= 1024");
dim3 grid2(1);
dim3 block2(threads);
size_t shmem2 = threads * sizeof(float);
auto output = torch::empty({1}, options);
fused_logsumexp_kernel<<<grid2, block2, shmem2, c10::cuda::getCurrentCUDAStream()>>>(
S.data_ptr<float>(),
output.data_ptr<float>(),
batch
);
// Squeeze the output to return a 0-dimensional tensor.
return output.squeeze();
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused GEMM (linear1), sigmoid, row-sum, and logsumexp");
}