1.53. Min Reduction Over a Dimension

Calculates the minimum values over a specified axis of the input tensor, effectively reducing the tensor by selecting the smallest element along that dimension.
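For reference, the operation this kernel replaces is equivalent to torch.min(..., dim=...).values in plain PyTorch. A minimal example (tensor shape and dim below are illustrative, not taken from the benchmark):

import torch

x = torch.randn(16, 256, 256, device="cuda")
out = torch.min(x, dim=1).values   # (16, 256, 256) -> (16, 256): smallest element along dim 1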

Progress: Step 5 of 5
Reward: 1.85x

Optimization Timeline

Step 5 Details

Reasoning

Our current implementation of the min reduction kernel conceptually reshapes the input into three parts (pre, red, post) and has each thread compute the minimum value for its assigned (pre, post) slice by iterating over the reduction dimension, with the loop manually unrolled by 4 and annotated with "#pragma unroll 4." This approach minimizes loop overhead and improves overall performance. A further refinement could use warp-level cooperation to combine partial results from threads that read consecutive elements, but since each output element is computed independently over the "red" dimension, our current design is both simple and effective. In summary, the kernel iterates over the reduction dimension, unrolling the loop to reduce the number of iterations, and the Python wrapper loads the extension via torch.utils.cpp_extension.load with extra CUDA flags. These improvements reduced runtime relative to the native PyTorch implementation, so we supply this as our final optimized solution.
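As a minimal sketch of this (pre, red, post) indexing scheme (the shapes, dim, and variable names below are illustrative assumptions, not part of the kernel), the kernel's view of the data corresponds to the following PyTorch reference:

import torch

# Illustrative shapes: reduce a (4, 8, 5, 6) tensor over dim=1.
x = torch.randn(4, 8, 5, 6)
dim = 1
pre = x.shape[:dim].numel()        # product of dims before 'dim'  -> 4
red = x.shape[dim]                 # size of the reduced dimension -> 8
post = x.shape[dim + 1:].numel()   # product of dims after 'dim'   -> 30

# Each output element (p, q) of the reshaped view is the min over the 'red' axis,
# which is exactly what one CUDA thread computes for its (pre, post) slice.
ref = x.reshape(pre, red, post).min(dim=1).values.reshape(4, 5, 6)
assert torch.equal(ref, x.min(dim=1).values)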

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
#include <float.h>

// Optimized CUDA kernel for min reduction over a specified dimension with loop unrolling.
// The input tensor is assumed to be contiguous and conceptually reshaped as [pre, red, post],
// where 'red' is the size along the reduction dimension.
__global__ void min_reduction_kernel(const float* __restrict__ x,
                                     float* __restrict__ output,
                                     int pre, int red, int post, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        int pre_idx = idx / post;
        int post_idx = idx % post;
        int base = pre_idx * red * post;
        float min_val = FLT_MAX;
        int i = 0;
        // Unroll the loop 4 times for better performance.
        #pragma unroll 4
        for (; i <= red - 4; i += 4) {
            int idx0 = base + (i + 0) * post + post_idx;
            int idx1 = base + (i + 1) * post + post_idx;
            int idx2 = base + (i + 2) * post + post_idx;
            int idx3 = base + (i + 3) * post + post_idx;
            float v0 = x[idx0];
            float v1 = x[idx1];
            float v2 = x[idx2];
            float v3 = x[idx3];
            float m0 = fminf(v0, v1);
            float m1 = fminf(v2, v3);
            float m2 = fminf(m0, m1);
            min_val = fminf(min_val, m2);
        }
        // Process any remaining elements.
        for (; i < red; i++) {
            int idx_cur = base + i * post + post_idx;
            float v = x[idx_cur];
            min_val = fminf(min_val, v);
        }
        output[idx] = min_val;
    }
}

torch::Tensor min_reduction(torch::Tensor x, int dim) {
    // Ensure we operate on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());
    
    // The kernel assumes a contiguous float32 CUDA tensor, so enforce that here.
    TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor.");
    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous.");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "Input must be float32.");
    
    auto sizes = x.sizes();
    int ndim = sizes.size();
    TORCH_CHECK(dim >= 0 && dim < ndim, "Reduction dimension out of range.");
    
    // Compute the product of dimensions before, at, and after the reduction dimension.
    int pre = 1;
    for (int i = 0; i < dim; i++) { pre *= sizes[i]; }
    int red = sizes[dim];
    int post = 1;
    for (int i = dim + 1; i < ndim; i++) { post *= sizes[i]; }
    
    // Construct the output shape by removing the reduced dimension.
    std::vector<int64_t> out_shape;
    for (int i = 0; i < ndim; i++) {
        if (i != dim)
            out_shape.push_back(sizes[i]);
    }
    
    auto output = torch::empty(out_shape, x.options());
    
    int total = pre * post;
    if (total > 0) {
        int threads = 256;
        int blocks = (total + threads - 1) / threads;
        
        min_reduction_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(), output.data_ptr<float>(), pre, red, post, total);
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("min_reduction", &min_reduction,
          "Min reduction over a specified dimension (optimized with loop unrolling) (CUDA)");
}
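
The Python wrapper mentioned in the reasoning above is not shown in this step. A minimal sketch, assuming the CUDA source is saved as min_reduction.cu and using -O3/--use_fast_math as stand-ins for the unspecified extra CUDA flags:

import torch
from torch.utils.cpp_extension import load

# Build the extension from the source above (file name and flags are assumptions).
ext = load(
    name="min_reduction_ext",
    sources=["min_reduction.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],  # placeholders for the "extra CUDA flags"
    verbose=False,
)

x = torch.randn(16, 256, 256, device="cuda")
out = ext.min_reduction(x.contiguous(), 1)   # the kernel assumes a contiguous float32 input
assert torch.allclose(out, torch.min(x, dim=1).values)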