3.49. Mamba2 Final State Computation
Splits the input sequence into fixed-size blocks and evaluates a structured state-space recurrence: exponentiated segment sums provide both the diagonal (intra-block) output terms and the decay factors for the intra- and inter-block state updates, and the final computed state is returned.
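For orientation, the sketch below spells out that block-wise recurrence in PyTorch under simplifying assumptions: single-head shapes, hypothetical segsum/final_state names, and the diagonal (intra-block) output path omitted since only the final state is needed. It illustrates the computation described above; it is not the module's actual code.

    import torch
    import torch.nn.functional as F

    def segsum(x):
        # x: (..., T) per-step log-decays. Returns (..., T, T) with
        # out[..., i, j] = x[..., j+1] + ... + x[..., i] for i >= j and -inf
        # above the diagonal, so exp() yields a lower-triangular decay matrix.
        T = x.size(-1)
        cs = torch.cumsum(x, dim=-1)
        out = cs[..., :, None] - cs[..., None, :]
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        return out.masked_fill(~mask, float("-inf"))

    def final_state(X, A, B, block_len):
        # X: (batch, seqlen, head_dim), A: (batch, seqlen) log-decays,
        # B: (batch, seqlen, state_dim); seqlen must be a multiple of block_len.
        b, l, p = X.shape
        c = l // block_len
        X = X.view(b, c, block_len, p)
        A = A.view(b, c, block_len)
        B = B.view(b, c, block_len, -1)
        A_cumsum = torch.cumsum(A, dim=-1)                        # (b, c, block_len)
        # Intra-block: each position's contribution decays to the end of its block.
        decay_to_end = torch.exp(A_cumsum[..., -1:] - A_cumsum)   # (b, c, block_len)
        states = torch.einsum("bcln,bcl,bclp->bcpn", B, decay_to_end, X)
        # Inter-block: block-boundary states decay by exp(segsum of block totals).
        decay_chunk = torch.exp(segsum(F.pad(A_cumsum[..., -1], (1, 0))))
        states = torch.cat([torch.zeros_like(states[:, :1]), states], dim=1)
        return torch.einsum("bzc,bcpn->bzpn", decay_chunk, states)[:, -1]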
Optimization Timeline
Step 3 Details
Reasoning
We can further optimize the fused segsum exponential kernel by replacing the single-thread sequential cumulative sum (previously used to fill shared memory) with a parallel shared-memory scan. Since T (the segment length) is small (typically ≤ 128), the block's threads can cooperate on a simple doubling-offset inclusive scan, which should reduce the latency of computing the cumulative sum for each segment. After the scan, the kernel proceeds as before with the nested, unrolled loops that compute exp(cumsum[i] - cumsum[j]) for i ≥ j. All threads participate in the __syncthreads calls so that the parallel scan executes correctly. These changes should yield a modest throughput improvement without altering the operator's semantics. The Python module remains unchanged.
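To make the scan step concrete, here is a small Python illustration (for exposition only, using a hypothetical inclusive_scan helper) of the doubling-offset (Hillis–Steele) inclusive scan that the kernel performs in shared memory:

    def inclusive_scan(a):
        a = list(a)
        offset = 1
        while offset < len(a):
            # Each pass reads the neighbour at distance `offset`; all reads finish
            # before any writes, which the kernel enforces with __syncthreads.
            prev = [a[i - offset] if i >= offset else 0 for i in range(len(a))]
            a = [a[i] + prev[i] for i in range(len(a))]
            offset *= 2
        return a

    assert inclusive_scan([1, 2, 3, 4]) == [1, 3, 6, 10]

The updated CUDA source below applies the same passes with one shared-memory element per thread index.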
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
// This kernel implements the fused segsum exponential operator.
// For each input segment of length T, it computes an inclusive cumulative sum (scan)
// in parallel using shared memory, and then for each index pair (i, j) with i >= j,
// it computes: output[i, j] = exp(cumsum[i] - cumsum[j]); otherwise output[i, j] = 0.
// We assume T <= 128. Each block processes one segment using a 2D thread layout.
template <unsigned int TILE_DIM>
__global__ void fused_segsum_exp_kernel(const float* __restrict__ input,
                                        float* __restrict__ output,
                                        int T) {
    // Each block corresponds to one segment.
    int segment = blockIdx.x;
    const float* seg_in = input + segment * T;
    float* seg_out = output + segment * T * T;
    // Shared memory for the cumulative sum; T is assumed <= 128.
    __shared__ float cumsum[128];
    // Compute a flattened thread index.
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    // Load the segment into shared memory in parallel.
    if (tid < T) {
        cumsum[tid] = seg_in[tid];
    }
    __syncthreads();
    // Parallel inclusive scan (doubling offsets) over the first T elements.
    // All threads, including those with tid >= T, reach the __syncthreads
    // barriers, since every thread in the block must hit each barrier.
    for (int offset = 1; offset < T; offset *= 2) {
        float temp = 0.0f;
        if (tid < T && tid >= offset) {
            temp = cumsum[tid - offset];
        }
        __syncthreads();
        if (tid < T) {
            cumsum[tid] += temp;
        }
        __syncthreads();
    }
    // Now compute the output: for each (i, j) with i, j in [0, T),
    // out[i, j] = exp(cumsum[i] - cumsum[j]) if i >= j, else 0.
    int ty = threadIdx.y;  // row offset
    int tx = threadIdx.x;  // column offset
    #pragma unroll
    for (int i = ty; i < T; i += TILE_DIM) {
        #pragma unroll
        for (int j = tx; j < T; j += TILE_DIM) {
            seg_out[i * T + j] = (i >= j) ? expf(cumsum[i] - cumsum[j]) : 0.0f;
        }
    }
}
torch::Tensor fused_segsum_exp_forward(torch::Tensor input) {
    // Input shape: [R, T], where R is the number of segments.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 2, "input must have shape [R, T]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    c10::cuda::CUDAGuard device_guard(input.device());
    input = input.contiguous();
    int R = input.size(0);
    int T = input.size(1);
    TORCH_CHECK(T <= 128, "segment length T must be <= 128");
    auto output = torch::empty({R, T, T}, input.options());
    // Launch the kernel with TILE_DIM = 16 (one 16x16 thread block per segment).
    dim3 blockDim(16, 16);
    dim3 gridDim(R);
    fused_segsum_exp_kernel<16><<<gridDim, blockDim, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        T
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fused_segsum_exp_forward, "Fused segsum exponential operator");
}