1.51. Dimension Argmax Calculation

Calculates the indices of the maximum elements along a specified dimension of an input tensor; that dimension is removed from the output shape.

Progress

Step 3 of 3
Reward
1.70x

Optimization Timeline

Step 3 Details

Reasoning

We can further improve our argmax kernel by explicitly using the `__ldg` intrinsic to load input values from global memory through the read-only data cache. This small change can improve caching for read-only data. We already unroll the loop by a factor of 4, so we now load with `__ldg` both in the unrolled loop and in the tail loop. Since each thread scans its own subarray independently, this micro-optimization mainly targets memory latency. The Python module remains unchanged.

The expected gain is modest and should be confirmed by profiling. Because `input` is already declared `const float* __restrict__`, the compiler can often emit read-only-cache loads on its own. Also, `__ldg` affects only how loads are cached; it does not by itself reduce register pressure.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>  // For FLT_MAX
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Returns true when `candidate` should replace the running maximum `best`.
// This follows torch.argmax semantics:
//   - only a strictly greater value wins, so ties keep the earliest index;
//   - NaN ranks above every number, and the first NaN seen is kept.
// Usable from host and device so the rule can be unit-tested on the CPU.
__host__ __device__ __forceinline__ bool argmax_takes(float candidate, float best) {
    return candidate > best || (::isnan(candidate) && !::isnan(best));
}

// CUDA kernel for argmax reduction along a specified dimension, with 4-way
// loop unrolling and explicit __ldg (read-only cache) loads (__ldg needs SM35+).
//
// Layout:
//   - 'input' is contiguous float data viewed as [outer, reduce, inner].
//   - 'output' is contiguous int64_t data viewed as [outer, inner].
//
// Each output element is computed by one thread. That thread scans the
// `reduce` values of its column, which are spaced `inner` elements apart.
// Consecutive threads read consecutive addresses, except where they cross an
// outer-row boundary.
//
// Launch with a 1-D grid of 1-D blocks and no shared memory. The grid-stride
// loop keeps the kernel correct for any grid size.
//
// Precondition: reduce >= 1. If it is not, the kernel writes index 0.
// All offsets are 64-bit so tensors with more than 2^31 elements do not
// overflow.
__global__ void argmax_kernel(const float* __restrict__ input,
                              int64_t* __restrict__ output,
                              int64_t outer, int64_t reduce, int64_t inner) {
    const int64_t total = outer * inner;
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total; idx += stride) {
        if (reduce <= 0) {  // defensive: the host wrapper rejects empty reductions
            output[idx] = 0;
            continue;
        }
        const int64_t outer_idx = idx / inner;  // inner > 0 here because total > 0
        const int64_t inner_idx = idx % inner;
        // Element 0 of this thread's column; element r lives at col[r * inner].
        const float* col = input + outer_idx * reduce * inner + inner_idx;

        // Seed with the first element rather than a sentinel such as -FLT_MAX.
        // Otherwise a column like [-inf, -FLT_MAX] would wrongly report index 0.
        float best_val = __ldg(col);
        int64_t best_index = 0;

        int64_t r = 1;
        // Unroll by a factor of 4: issue four independent loads before comparing.
        for (; r + 3 < reduce; r += 4) {
            const float v0 = __ldg(col + (r + 0) * inner);
            const float v1 = __ldg(col + (r + 1) * inner);
            const float v2 = __ldg(col + (r + 2) * inner);
            const float v3 = __ldg(col + (r + 3) * inner);
            if (argmax_takes(v0, best_val)) { best_val = v0; best_index = r + 0; }
            if (argmax_takes(v1, best_val)) { best_val = v1; best_index = r + 1; }
            if (argmax_takes(v2, best_val)) { best_val = v2; best_index = r + 2; }
            if (argmax_takes(v3, best_val)) { best_val = v3; best_index = r + 3; }
        }
        // Tail: the remaining (reduce - 1) % 4 elements.
        for (; r < reduce; ++r) {
            const float v = __ldg(col + r * inner);
            if (argmax_takes(v, best_val)) { best_val = v; best_index = r; }
        }
        output[idx] = best_index;
    }
}

// Host entry point: argmax of a float32 CUDA tensor along dimension `dim`.
//
// Parameters:
//   input - float32 tensor on a CUDA device (made contiguous internally).
//   dim   - dimension to reduce. Negative values count from the end,
//           as in torch.argmax.
//
// Returns an int64 tensor with `input`'s shape minus `dim`. The result holds,
// for each position, the index of the first maximal value; NaN counts as
// maximal when the kernel in this file is used.
// NOTE(review): for a 1-D input the result has shape [1], not the 0-d tensor
// torch.argmax returns. This is kept for compatibility with existing callers.
//
// Throws (via TORCH_CHECK) if:
//   - input is not a float32 CUDA tensor;
//   - dim is out of range;
//   - the reduced dimension has size 0;
//   - the grid would exceed CUDA's grid.x limit;
//   - the kernel launch fails.
//
// The kernel is enqueued on the current CUDA stream; this function does not
// synchronize.
torch::Tensor argmax(torch::Tensor input, int dim) {
    TORCH_CHECK(input.is_cuda(), "argmax: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "argmax: input must be float32, got ", input.scalar_type());

    // Set correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes the data as a dense [outer, reduce, inner] block.
    input = input.contiguous();

    const int64_t ndim = input.dim();
    TORCH_CHECK(dim >= -ndim && dim < ndim,
                "dim must be in the range [", -ndim, ", ", ndim, ")");
    const int64_t d = dim < 0 ? dim + ndim : dim;

    // Collapse the shape into [outer, reduce, inner] using 64-bit products,
    // so large tensors do not overflow.
    const auto sizes = input.sizes();
    const int64_t reduce = sizes[d];
    TORCH_CHECK(reduce > 0, "argmax: cannot reduce over dimension ", dim,
                " because it has size 0");
    int64_t outer = 1;
    int64_t inner = 1;
    for (int64_t i = 0; i < d; ++i) {
        outer *= sizes[i];
    }
    for (int64_t i = d + 1; i < ndim; ++i) {
        inner *= sizes[i];
    }

    // Build output shape by excluding the reduction dimension.
    std::vector<int64_t> output_sizes;
    output_sizes.reserve(static_cast<size_t>(ndim));
    for (int64_t i = 0; i < ndim; ++i) {
        if (i != d) {
            output_sizes.push_back(sizes[i]);
        }
    }
    if (output_sizes.empty()) {
        output_sizes.push_back(1);
    }
    // Create output tensor with type int64 (torch.long).
    auto output = torch::empty(output_sizes, input.options().dtype(torch::kInt64));

    const int64_t total_elements = outer * inner;
    if (total_elements == 0) {
        // Nothing to compute, and launching with 0 blocks is an invalid configuration.
        return output;
    }

    constexpr int threads = 256;
    const int64_t blocks = (total_elements + threads - 1) / threads;
    TORCH_CHECK(blocks <= std::numeric_limits<int>::max(),
                "argmax: tensor too large for a single 1-D launch (", blocks, " blocks)");

    // Launch the kernel on the current CUDA stream.
    argmax_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                    c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<int64_t>(),
        outer,
        reduce,
        inner);
    // Surface launch-configuration errors here instead of at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: exposes the host wrapper `argmax(input, dim)` as
// `<extension>.argmax`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def(
        "argmax",
        &argmax,
        "Custom CUDA argmax operator with loop unrolling and __ldg optimization");
}