1.8. Irregular-Shape Matrix Multiplication

Calculates the product of two input matrices with shapes (M, K) and (K, N), resulting in an output matrix with shape (M, N).

Progress

Step 4 of 4
Reward
1.03x

Optimization Timeline

Step 4 Details

Reasoning

Our earlier cuBLAS-based GEMM call nearly matched the native implementation (15.8 ms vs. 15.6 ms) but paid the overhead of creating and destroying a cuBLAS handle on every call. To beat the native implementation, we instead create a global (static) cuBLAS handle once and reuse it across all subsequent GEMM calls, which removes the per-call handle-creation cost. Because the handle is shared between calls, the current CUDA stream must be set on it (via cublasSetStream) before every cublasSgemm call. This should reduce overhead enough to yield a speedup over the native call. The Python module remains largely unchanged; only the C++/CUDA extension is updated to use the persistent handle.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <limits>
#include <stdexcept>
#include <unordered_map>

// Global persistent cuBLAS handle.
static cublasHandle_t global_handle = nullptr;

// Helper function to initialize the global cuBLAS handle if needed.
void initCublasHandle() {
    if (global_handle == nullptr) {
        cublasStatus_t status = cublasCreate(&global_handle);
        if (status != CUBLAS_STATUS_SUCCESS)
            throw std::runtime_error("Failed to create cuBLAS handle");
    }
}

// Computes C = A * B with cuBLAS sgemm, where A is (m x k) and B is (k x n), both
// float32 CUDA tensors stored row-major.
//
// cuBLAS works in column-major order. A dense row-major (r x c) buffer is bit-identical
// to a column-major (c x r) buffer, i.e. its transpose. So instead of C = A * B we ask
// cuBLAS for C^T = B^T * A^T in column-major order. The (n x m) column-major result it
// writes is exactly C (m x n) in row-major order.
//
// The inputs are made contiguous first because this trick relies on a dense row-major
// layout (leading dimension == number of columns). A transposed or sliced view would
// otherwise be read with the wrong strides. contiguous() is a no-op for inputs that are
// already contiguous.
//
// @param A  2-D float32 CUDA tensor of shape (m, k).
// @param B  2-D float32 CUDA tensor of shape (k, n), on the same device as A.
// @return   New tensor of shape (m, n) with A's options.
// @throws c10::Error (via TORCH_CHECK) on invalid inputs; std::runtime_error on
//         cuBLAS failures.
torch::Tensor gemm(torch::Tensor A, torch::Tensor B) {
    // Validate inputs before touching the device: CUDAGuard requires a CUDA device.
    TORCH_CHECK(A.dim() == 2, "A must be a 2D tensor");
    TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");
    TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
    TORCH_CHECK(A.device() == B.device(), "A and B must be on the same CUDA device");
    TORCH_CHECK(A.scalar_type() == torch::kFloat32 && B.scalar_type() == torch::kFloat32,
                "A and B must be float32 tensors");
    TORCH_CHECK(B.size(0) == A.size(1), "Inner dimensions of A and B must match");

    // cuBLAS takes 32-bit int dimensions; refuse sizes that would silently wrap.
    constexpr int64_t kMaxDim = std::numeric_limits<int>::max();
    TORCH_CHECK(A.size(0) <= kMaxDim && A.size(1) <= kMaxDim && B.size(1) <= kMaxDim,
                "Matrix dimensions must fit in a 32-bit int for cuBLAS");
    const int m = static_cast<int>(A.size(0));
    const int k = static_cast<int>(A.size(1));
    const int n = static_cast<int>(B.size(1));

    // Make A's device current for the allocation, the handle and the stream below.
    c10::cuda::CUDAGuard device_guard(A.device());

    // Allocate output tensor C (desired shape: m x n).
    auto C = torch::empty({m, n}, A.options());

    // Degenerate shapes: cuBLAS rejects a leading dimension of 0 (it requires
    // ld >= max(1, rows)), so handle them here. An empty inner dimension means every
    // output element is an empty sum, i.e. zero.
    if (m == 0 || n == 0)
        return C;
    if (k == 0)
        return C.zero_();

    const torch::Tensor A_c = A.contiguous();
    const torch::Tensor B_c = B.contiguous();

    // Point global_handle at the persistent handle for the current device.
    initCublasHandle();

    // The handle is shared across calls, so bind it to the current stream every time.
    cublasStatus_t status = cublasSetStream(global_handle, c10::cuda::getCurrentCUDAStream());
    if (status != CUBLAS_STATUS_SUCCESS)
        throw std::runtime_error("Failed to set cuBLAS stream");

    const float alpha = 1.0f;
    const float beta  = 0.0f;

    // Column-major call computing C^T (n x m) = B^T (n x k) * A^T (k x m).
    // Leading dimensions equal the row-major column counts: B -> n, A -> k, C -> n.
    status = cublasSgemm(global_handle,
                         CUBLAS_OP_N, CUBLAS_OP_N,
                         n,   // rows of C^T
                         m,   // columns of C^T
                         k,   // shared dimension
                         &alpha,
                         B_c.data_ptr<float>(), n,   // B^T viewed column-major: (n x k)
                         A_c.data_ptr<float>(), k,   // A^T viewed column-major: (k x m)
                         &beta,
                         C.data_ptr<float>(), n);    // C^T column-major == C row-major
    if (status != CUBLAS_STATUS_SUCCESS)
        throw std::runtime_error("cublasSgemm failed");

    return C;
}

// Python bindings: registers gemm(A, B) as a module-level function of the extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "gemm",
        &gemm,
        "Matrix multiplication using cuBLAS sgemm with persistent handle (row-major workaround)");
}