1.8. Irregular-Shape Matrix Multiplication
Calculates the product of two input matrices with shapes (M, K) and (K, N), resulting in an output matrix with shape (M, N).
Step 4 Details
Reasoning
Our earlier cuBLAS-based GEMM call nearly matched the native implementation (15.8 ms vs 15.6 ms) but paid for creating and destroying the cuBLAS handle on every call. To pull ahead of the native call, we create a global (static) cuBLAS handle that is reused across GEMM calls, eliminating the per-call handle creation overhead. The handle is initialized lazily on the first call and reused for all subsequent calls; the only remaining per-call bookkeeping is binding the current CUDA stream to the handle before calling cublasSgemm. The Python module remains largely unchanged; only the CUDA kernel extension is updated to use the global handle.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <stdexcept>
// Global persistent cuBLAS handle.
static cublasHandle_t global_handle = nullptr;
// Helper function to initialize the global cuBLAS handle if needed.
void initCublasHandle() {
    if (global_handle == nullptr) {
        cublasStatus_t status = cublasCreate(&global_handle);
        if (status != CUBLAS_STATUS_SUCCESS) {
            throw std::runtime_error("Failed to create cuBLAS handle");
        }
    }
}
// This function uses cuBLAS to compute C = A * B, where A is (m x k) and B is (k x n),
// both stored in row-major order. Since cuBLAS assumes column-major storage, we use the
// standard row-major workaround: cublasSgemm is asked to compute C^T = B^T * A^T in
// column-major order. The result, C^T of dimensions (n x m) in column-major layout, is
// bit-for-bit identical to C (m x n) in row-major layout, so no explicit transpose is needed.
torch::Tensor gemm(torch::Tensor A, torch::Tensor B) {
    // Ensure operations run on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(A.device());

    // Check that the inputs match the layout this routine assumes.
    TORCH_CHECK(A.is_cuda() && B.is_cuda(), "A and B must be CUDA tensors");
    TORCH_CHECK(A.dim() == 2, "A must be a 2D tensor");
    TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous(), "A and B must be contiguous (row-major)");

    int m = A.size(0);
    int k = A.size(1);
    TORCH_CHECK(B.size(0) == k, "Inner dimensions of A and B must match");
    int n = B.size(1);

    // Allocate the output tensor C (desired shape: m x n).
    auto C = torch::empty({m, n}, A.options());

    // Initialize the global cuBLAS handle if necessary.
    initCublasHandle();

    // Bind the current CUDA stream to the global handle for this call.
    cublasStatus_t status = cublasSetStream(global_handle, c10::cuda::getCurrentCUDAStream());
    if (status != CUBLAS_STATUS_SUCCESS) {
        throw std::runtime_error("Failed to set cuBLAS stream");
    }

    float alpha = 1.0f;
    float beta = 0.0f;

    // Row-major to column-major trick:
    //   C = A * B (row-major)  is equivalent to  C^T = B^T * A^T (column-major).
    // We call cublasSgemm with CUBLAS_OP_N on both matrices.
    // Leading dimensions: for B, ld = n; for A, ld = k; for C, ld = n.
    status = cublasSgemm(global_handle,
                         CUBLAS_OP_N, CUBLAS_OP_N,
                         n,  // rows of C^T
                         m,  // columns of C^T
                         k,  // shared dimension
                         &alpha,
                         B.data_ptr<float>(), n,  // B: (k x n) row-major => B^T: (n x k), ld = n
                         A.data_ptr<float>(), k,  // A: (m x k) row-major => A^T: (k x m), ld = k
                         &beta,
                         C.data_ptr<float>(), n); // written as C^T (n x m) column-major = C (m x n) row-major
    if (status != CUBLAS_STATUS_SUCCESS) {
        throw std::runtime_error("cublasSgemm failed");
    }

    return C;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gemm", &gemm,
          "Matrix multiplication using cuBLAS sgemm with a persistent handle (row-major workaround)");
}
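For reference, here is a minimal sketch of how the otherwise unchanged Python side could build and exercise this extension with torch.utils.cpp_extension.load. The file name gemm_cublas.cu, the module name, the -lcublas link flag, and the example shapes are illustrative assumptions rather than part of the recorded step.

# Minimal usage sketch of the (unchanged) Python side. The file name
# "gemm_cublas.cu" and module name "gemm_cublas" are assumptions for
# illustration; they are not part of the original module.
import torch
from torch.utils.cpp_extension import load

# Build the extension from the source above; -lcublas ensures cuBLAS is linked.
gemm_ext = load(
    name="gemm_cublas",
    sources=["gemm_cublas.cu"],
    extra_ldflags=["-lcublas"],
)

# Example irregular shapes: (M, K) x (K, N).
M, K, N = 1024, 4096, 2048
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
B = torch.randn(K, N, device="cuda", dtype=torch.float32)

C = gemm_ext.gemm(A, B)
torch.testing.assert_close(C, A @ B, rtol=1e-3, atol=1e-3)

Because the handle is a static global inside the extension, repeated calls to gemm within the same process reuse it; only the stream binding changes per call, which is exactly the overhead reduction this step targets.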