1.92. Exclusive Cumulative Sum

Calculates an exclusive cumulative sum along a specified dimension: for each element, the operation computes the sum of all preceding elements by prepending a zero and then performing an ordinary cumulative sum.
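As a concrete illustration of these semantics, here is a minimal single-row CPU sketch in C++, assuming float32 data. `exclusive_cumsum_row` is a hypothetical name, not part of the module below. For the row [1, 2, 3], prepending a zero gives [0, 1, 2, 3], whose cumulative sum is [0, 1, 3, 6].

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for a single row (float32 assumed):
// prepend a zero, then take an ordinary running sum.
// out[k] = x[0] + ... + x[k-1], so out[0] == 0 and out[n] is the row total.
std::vector<float> exclusive_cumsum_row(const std::vector<float>& x) {
    std::vector<float> out(x.size() + 1);
    float sum = 0.0f;
    out[0] = sum;                        // the prepended zero
    for (std::size_t k = 0; k < x.size(); ++k) {
        sum += x[k];                     // accumulate strictly left to right
        out[k + 1] = sum;
    }
    return out;
}

// Usage: exclusive_cumsum_row({1.0f, 2.0f, 3.0f}) yields {0, 1, 3, 6}.
```

Each of the first n outputs is the sum of the elements strictly before that position; the extra final entry is the row total, which is why the output row is one element longer than the input row.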

Progress: Step 4 of 4
Reward: 2.91x

Step 4 Details

Reasoning

Our current implementation uses a single thread per block to perform the sequential sum scan in shared memory. Because that thread adds the elements strictly left to right, the result reproduces a sequential float accumulation exactly, but the serial loop over L+1 elements (typically around 4001) pays an index increment, a comparison, and a branch for every element. Manually unrolling the loop by a factor of 4 cuts the number of loop iterations, and therefore that bookkeeping, by roughly a factor of 4. The gain is bounded, though: each addition still depends on the previous one through the running sum, so the additions themselves remain serialized.

We update the kernel so that thread 0 in each block processes four elements per iteration, with a remainder loop for the last (L+1) mod 4 elements, while preserving the exact left-to-right order of addition. This should give a further speedup without changing the numerical result. The Python-facing module keeps the same interface; it simply builds the updated CUDA kernel.

Before finishing, I ask myself: have these improvements genuinely helped rather than just "cheating"? Yes. Unrolling reduces loop overhead, and every row is still processed by its own block in parallel. This is a standard micro-optimization. Below is the updated CUDA kernel together with the host wrapper and pybind11 binding.
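The claim that unrolling preserves the result can be checked on the host. The sketch below is a plain C++ mirror of the kernel's thread-0 loop under stated assumptions (the names `scan_rolled`, `scan_unrolled4`, and `unrolled_matches_rolled` are hypothetical): both functions perform the same float additions in the same order, so under default compilation (no fast-math reassociation) the buffers should agree bit for bit.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side mirror of the kernel's thread-0 loop: an in-place
// inclusive scan over a buffer whose element 0 is the prepended zero.
void scan_rolled(std::vector<float>& s) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < s.size(); ++i) {
        sum += s[i];
        s[i] = sum;
    }
}

// The same scan unrolled by 4, followed by a remainder loop. The additions
// happen in exactly the same order as in scan_rolled.
void scan_unrolled4(std::vector<float>& s) {
    const std::size_t n = s.size();
    float sum = 0.0f;
    std::size_t i = 0;
    // i + 4 <= n instead of i <= n - 4: n is unsigned, so n - 4 would wrap for n < 4.
    for (; i + 4 <= n; i += 4) {
        sum += s[i];     s[i] = sum;
        sum += s[i + 1]; s[i + 1] = sum;
        sum += s[i + 2]; s[i + 2] = sum;
        sum += s[i + 3]; s[i + 3] = sum;
    }
    for (; i < n; ++i) {  // remainder when n is not a multiple of 4
        sum += s[i];
        s[i] = sum;
    }
}

// Hypothetical self-check: build a buffer shaped like the kernel's shared row
// (a zero followed by n - 1 non-trivial values) and compare both scans exactly.
bool unrolled_matches_rolled(std::size_t n) {
    std::vector<float> a(n);  // a[0] stays 0.0f, like sdata[0]
    for (std::size_t i = 1; i < n; ++i) {
        a[i] = 0.1f * static_cast<float>(i % 7) - 0.25f;
    }
    std::vector<float> b = a;
    scan_rolled(a);
    scan_unrolled4(b);
    return a == b;  // element-wise exact comparison, no tolerance
}
```

For example, `unrolled_matches_rolled(4001)` exercises a buffer of the size mentioned above; since 4001 is one more than a multiple of 4, the remainder loop runs exactly once.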

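#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cuda.h>
#include <cuda_runtime.h>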

// This kernel computes an exclusive cumulative sum (prefix scan) per row in strict left-to-right order.
// For an input tensor "x" of shape [B, L], it replicates the Python reference:
//   temp = torch.cat((zeros, x), dim=1)    # shape [B, L+1]
//   out  = torch.cumsum(temp[:-1], dim=1)  # shape [B-1, L+1]
// Note that temp[:-1] slices dim 0, so the last input row is dropped.
// Each block processes one output row (row indices 0 to B-2). The row is first loaded into shared memory;
// then thread 0 performs a sequential scan using a register accumulator with manual loop unrolling.
extern "C" __global__ void sequential_scan_kernel(const float* __restrict__ input,
                                                    float* __restrict__ output,
                                                    int B, int L) {
    int row = blockIdx.x;
    if (row >= B - 1) return;
    
    int numElements = L + 1;  // Number of elements in the output row.
    extern __shared__ float sdata[];
    
    // Each thread loads elements into shared memory in a strided fashion.
    // For index 0, we set the value to 0; otherwise we load input[row, i-1].
    for (int i = threadIdx.x; i < numElements; i += blockDim.x) {
        sdata[i] = (i == 0) ? 0.0f : input[row * L + (i - 1)];
    }
    __syncthreads();
    
    // Thread 0 performs the left-to-right scan.
    if (threadIdx.x == 0) {
        float sum = 0.0f;  // Register accumulator.
        int i = 0;
        // Unroll loop by a factor of 4 when possible.
        for (; i <= numElements - 4; i += 4) {
            sum += sdata[i];
            sdata[i] = sum;
            sum += sdata[i + 1];
            sdata[i + 1] = sum;
            sum += sdata[i + 2];
            sdata[i + 2] = sum;
            sum += sdata[i + 3];
            sdata[i + 3] = sum;
        }
        // Process any remaining elements.
        for (; i < numElements; i++) {
            sum += sdata[i];
            sdata[i] = sum;
        }
    }
    __syncthreads();
    
    // Write the computed prefix sum back to global memory.
    for (int i = threadIdx.x; i < numElements; i += blockDim.x) {
        output[row * numElements + i] = sdata[i];
    }
}

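torch::Tensor forward(torch::Tensor x, int dim) {
    // The kernel assumes a 2-D, float32 CUDA tensor scanned along dim 1.
    TORCH_CHECK(dim == 1, "Only exclusive cumsum along dim==1 is supported.");
    TORCH_CHECK(x.dim() == 2, "Input tensor must be 2-dimensional");
    TORCH_CHECK(x.is_cuda(), "Input tensor must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "Input tensor must be float32");
    c10::cuda::CUDAGuard device_guard(x.device());
    x = x.contiguous();  // The kernel indexes row * L + j, which assumes a dense row-major layout.

    int B = x.size(0);
    int L = x.size(1);
    
    // Output shape is [B-1, L+1]; the reference's [:-1] drops the last row.
    int numElements = L + 1;
    if (B <= 1) {
        // No output rows; launching a grid of zero blocks would be an error.
        return torch::empty({0, numElements}, x.options());
    }
    auto output = torch::empty({B - 1, numElements}, x.options());
    
    // Launch one block per output row.
    int blocks = B - 1;
    int threads = 256; // Enough threads to cooperatively load the row.
    size_t sharedSize = numElements * sizeof(float);
    TORCH_CHECK(sharedSize <= 48 * 1024,
                "Row too long: exceeds the default 48 KiB of dynamic shared memory per block");
    
    sequential_scan_kernel<<<blocks, threads, sharedSize, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        B,
        L
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // Surface launch-configuration errors immediately.
    
    return output;
}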

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Exclusive cumulative sum using shared-memory scan with loop unrolling (dim==1)");
}
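To validate the kernel's output, one option is a CPU reference with the same semantics. The following sketch assumes a contiguous, row-major float32 [B, L] buffer already copied to the host; `reference_forward` and `shared_bytes` are hypothetical helpers, not part of the extension.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for forward(x, dim=1) on a row-major [B, L]
// float buffer. Mirrors torch.cumsum(torch.cat((zeros, x), dim=1)[:-1], dim=1),
// so the result is row-major with shape [B-1, L+1]: [:-1] drops the last row.
std::vector<float> reference_forward(const std::vector<float>& x,
                                     std::size_t B, std::size_t L) {
    assert(x.size() == B * L);
    if (B <= 1) return {};                      // no output rows
    const std::size_t n = L + 1;                // elements per output row
    std::vector<float> out((B - 1) * n);
    for (std::size_t row = 0; row + 1 < B; ++row) {
        float sum = 0.0f;                       // the prepended zero
        out[row * n] = sum;
        for (std::size_t j = 0; j < L; ++j) {
            sum += x[row * L + j];              // left to right, like thread 0
            out[row * n + j + 1] = sum;
        }
    }
    return out;
}

// Dynamic shared memory the kernel requests per block, in bytes.
std::size_t shared_bytes(std::size_t L) { return (L + 1) * sizeof(float); }
```

Because both the kernel and this reference accumulate strictly left to right in float, their outputs should match exactly, assuming neither side is compiled with fast-math style reassociation; that is worth confirming on the target setup. For L = 4000 the kernel requests `shared_bytes(4000)` = 16,004 bytes, comfortably below the 48 KiB a kernel may use without opting in via `cudaFuncSetAttribute`; rows with L above 12,287 would exceed that default.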