1.93. Masked Cumulative Sum

Calculates a cumulative sum along a specified dimension by including only the elements that satisfy a given boolean condition.
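Equivalently, elements excluded by the mask contribute zero to the running sum but still receive an output value. A one-line PyTorch reference that pins down these semantics (for a float x and a boolean mask of the same shape) might look like:

import torch

def masked_cumsum_reference(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    # Masked-out positions add nothing, but every position still gets an output entry.
    return torch.cumsum(x * mask, dim=dim)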

Progress: Step 10 of 10
Reward: 1.86x

Optimization Timeline

Step 10 Details

Reasoning

We can further tighten the kernel by replacing the magic number 32 (the warp size) with a named constant, which improves clarity and may give the compiler more room to optimize. We therefore introduce a constexpr WARP_SIZE (named this way rather than warpSize, which is already a CUDA built-in) and use it in place of every literal 32 that denotes the warp size. Otherwise the design is unchanged: each block processes one row, each thread computes a local cumulative sum over its segment, and a two-level exclusive scan combines the results (intra-warp via __shfl_up_sync, inter-warp via a serial scan over a small shared array). We keep __ldg for float loads but avoid it for bool. The Python wrapper remains unchanged.
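To make this local-scan-plus-global-offset scheme concrete, here is a small NumPy model of the same idea. It is an illustration only: the function name is made up, segment boundaries come from array_split rather than the kernel's ceil-division, and the two-level warp/block scan is collapsed into a single exclusive scan over segment totals.

import numpy as np

def two_level_masked_cumsum(row, mask_row, num_threads=8):
    # Zero out masked-off elements, mirroring the kernel's (m ? val : 0.0f).
    vals = np.where(mask_row, row, 0.0)
    # Each "thread" owns one contiguous segment of the row.
    segments = np.array_split(vals, num_threads)
    # Local inclusive scan and total per segment (local_scan / local_total in the kernel).
    local_scans = [np.cumsum(seg) for seg in segments]
    totals = np.array([seg.sum() for seg in segments])
    # Exclusive scan of the totals gives each segment's global offset.
    offsets = np.concatenate(([0.0], np.cumsum(totals)[:-1]))
    return np.concatenate([s + off for s, off in zip(local_scans, offsets)])

row = np.random.rand(1000)
mask_row = np.random.rand(1000) > 0.5
assert np.allclose(two_level_masked_cumsum(row, mask_row),
                   np.cumsum(np.where(mask_row, row, 0.0)))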

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream()
#include <cuda.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <vector>

// Warp size as a named constant (WARP_SIZE rather than warpSize, which is a CUDA built-in).
constexpr int WARP_SIZE = 32;

// Optimized masked cumulative sum kernel using a two-level (warp + block) exclusive scan.
// Assumes the cumulative dimension is the last dimension of the input (after permutation if needed).
// Each block processes one row and blockDim.x should be a power of two (e.g., 256).
__global__ void optimized_masked_cumsum_kernel(const float* __restrict__ x,
                                               const bool* __restrict__ mask,
                                               float* __restrict__ out,
                                               int row_size,
                                               int total_rows) {
    // Each block handles one row.
    int row = blockIdx.x;
    if (row >= total_rows) return;
    
    const float* x_row = x + row * row_size;
    const bool* mask_row = mask + row * row_size;
    float* out_row = out + row * row_size;
    
    int tid = threadIdx.x;
    int num_threads = blockDim.x;  // e.g., 256
    // Divide row elements among threads.
    int elems_per_thread = (row_size + num_threads - 1) / num_threads;
    int start = tid * elems_per_thread;
    int end = min(start + elems_per_thread, row_size);
    
    // Each thread computes a local cumulative sum over its segment.
    // The host wrapper checks that elems_per_thread never exceeds this bound before launching.
    const int max_elems_per_thread = 64;
    float local_scan[max_elems_per_thread];
    int count = end - start;
    float local_total = 0.0f;
    #pragma unroll
    for (int i = 0; i < count; i++) {
        // Use __ldg to load float data from global memory.
        float val = __ldg(x_row + start + i);
        // Do not use __ldg for booleans.
        bool m = mask_row[start + i];
        float masked_val = m ? val : 0.0f;
        local_total += masked_val;
        local_scan[i] = local_total;
    }
    
    // --- Warp-level exclusive scan within each warp:
    const unsigned int FULL_MASK = 0xffffffff;
    int lane = tid & (WARP_SIZE - 1);   // lane index within the warp
    int warpId = tid / WARP_SIZE;       // warp index within the block
    
    float warp_sum = local_total;
    // Intra-warp inclusive scan using shuffle.
    #pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
        float n = __shfl_up_sync(FULL_MASK, warp_sum, offset);
        if (lane >= offset) {
            warp_sum += n;
        }
    }
    // Convert to exclusive: thread's exclusive value = inclusive result minus its contribution.
    float thread_exclusive = (lane == 0) ? 0.0f : warp_sum - local_total;
    
    // --- Compute warp offsets across the block.
    // Each warp's last lane writes its warp total into shared memory.
    __shared__ float warpSums[32];  // supports up to 1024 threads (32 warps)
    if (lane == WARP_SIZE - 1) {
        warpSums[warpId] = warp_sum;
    }
    __syncthreads();
    
    // Perform an exclusive scan over warpSums serially in one thread.
    __shared__ float warpEx[32];
    if (tid == 0) {
        int num_warps = (num_threads + WARP_SIZE - 1) / WARP_SIZE;
        warpEx[0] = 0.0f;
        for (int w = 1; w < num_warps; w++) {
            warpEx[w] = warpEx[w - 1] + warpSums[w - 1];
        }
    }
    __syncthreads();
    
    // Global offset for this thread = its warp's offset + its intra-warp exclusive value.
    float global_offset = warpEx[warpId] + thread_exclusive;
    
    // Write back the results for this thread's segment.
    #pragma unroll
    for (int i = 0; i < count; i++) {
        out_row[start + i] = local_scan[i] + global_offset;
    }
}

// Host function: permutes tensors if needed and launches the kernel.
torch::Tensor masked_cumsum(torch::Tensor x, torch::Tensor mask, int64_t dim) {
    // Ensure x and mask have the same shape and are CUDA tensors.
    TORCH_CHECK(x.sizes() == mask.sizes(), "x and mask must have the same shape");
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(mask.is_cuda(), "mask must be a CUDA tensor");
    
    // Normalize a negative dim (e.g., -1 refers to the last dimension).
    if (dim < 0) dim += x.dim();
    TORCH_CHECK(dim >= 0 && dim < x.dim(), "dim is out of range");
    
    const c10::cuda::CUDAGuard device_guard(x.device());
    
    // If the cumulative dimension is not the last, permute x and mask.
    bool did_permute = false;
    std::vector<int64_t> orig_dims;
    for (int i = 0; i < x.dim(); i++) {
        orig_dims.push_back(i);
    }
    // The kernel indexes rows as if they were densely packed, so force contiguity here.
    torch::Tensor x_proc = x.contiguous(), mask_proc = mask.contiguous();
    std::vector<int64_t> inv_perm;
    if (dim != x.dim() - 1) {
        std::vector<int64_t> perm = orig_dims;
        std::swap(perm[dim], perm[x.dim() - 1]);
        x_proc = x.permute(perm).contiguous();
        mask_proc = mask.permute(perm).contiguous();
        inv_perm.resize(perm.size());
        for (size_t i = 0; i < perm.size(); i++) {
            inv_perm[perm[i]] = i;
        }
        did_permute = true;
    }
    
    // Now, the cumulative dimension is the last.
    int row_size = x_proc.size(x_proc.dim() - 1);
    int total_rows = x_proc.numel() / row_size;
    auto out = torch::empty_like(x_proc);
    
    // Launch one block per row with a fixed number of threads (e.g., 256).
    int threads = 256;
    // Each thread buffers at most 64 elements locally (see max_elems_per_thread in the kernel).
    TORCH_CHECK(row_size <= threads * 64,
                "row_size (", row_size, ") exceeds the supported maximum of ", threads * 64);
    dim3 grid(total_rows);
    dim3 block(threads);
    optimized_masked_cumsum_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x_proc.data_ptr<float>(),
        mask_proc.data_ptr<bool>(),
        out.data_ptr<float>(),
        row_size,
        total_rows
    );
    
    // If tensors were permuted, invert the output permutation.
    if (did_permute) {
        out = out.permute(inv_perm);
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("masked_cumsum", &masked_cumsum, "Optimized masked cumulative sum using two-level warp scan (CUDA)");
}
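
One way to build and sanity-check the extension end to end is JIT compilation via torch.utils.cpp_extension.load; the source filename and extension name below are placeholders for illustration.

import torch
from torch.utils.cpp_extension import load

# Compile the C++/CUDA source above; "masked_cumsum.cu" is an assumed filename.
ext = load(name="masked_cumsum_cuda", sources=["masked_cumsum.cu"], verbose=True)

x = torch.randn(128, 4000, device="cuda", dtype=torch.float32)
mask = torch.rand_like(x) > 0.5
out = ext.masked_cumsum(x, mask, 1)
ref = torch.cumsum(x * mask, dim=1)
print(torch.allclose(out, ref, atol=1e-5))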