1.5. Matrix Scalar Multiplication

Computes the element-wise product of a matrix and a scalar, producing a new matrix in which each element is scaled by the given factor.

Progress

Step 2 of 2
Reward
1.03x

Optimization Timeline

Step 2 Details

Reasoning

The original kernel performs element-wise multiplication of a matrix A by a scalar s, assigning one thread per element. Because the matrix is large (16384×4096) and stored contiguously, memory accesses can be vectorized. Processing four floats at a time via float4 reduces the number of load and store instructions and improves memory throughput. The host function checks whether the total number of elements (A.numel()) is divisible by 4 and, if so, dispatches the vectorized kernel; otherwise it falls back to the scalar kernel. Note that float4 accesses additionally require 16-byte-aligned data pointers, which divisibility by 4 alone does not guarantee (for example, a tensor view with a nonzero storage offset). The code still uses c10::cuda::CUDAGuard and c10::cuda::getCurrentCUDAStream() for correct device and stream handling. This optimization is expected to reduce runtime further relative to the native multiplication.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>

// Scalar (fallback) kernel: writes C[i] = A[i] * s for every i in [0, total).
//
// Launch layout: 1-D grid of 1-D blocks, one thread per element. The caller
// must launch at least `total` threads overall; threads whose flat index is
// >= total exit without touching memory. No shared memory is used.
__global__ void mat_scalar_kernel(const float* __restrict__ A,
                                  float s,
                                  float* __restrict__ C,
                                  int total) {
    const int i = threadIdx.x + blockDim.x * blockIdx.x;
    if (i >= total) {
        return;  // tail thread of the last block
    }
    const float a = A[i];
    C[i] = a * s;
}

// Vectorized kernel: views both buffers as arrays of float4 and scales four
// floats per thread, i.e. C[v].{x,y,z,w} = A[v].{x,y,z,w} * s for every v in
// [0, total_vec).
//
// Preconditions (not checked here):
//   - A and C must satisfy float4's 16-byte alignment.
//   - 1-D grid of 1-D blocks with at least `total_vec` threads overall;
//     threads whose flat index is >= total_vec exit immediately.
__global__ void mat_scalar_vec_kernel(const float4* __restrict__ A,
                                      float s,
                                      float4* __restrict__ C,
                                      int total_vec) {
    const int v = threadIdx.x + blockDim.x * blockIdx.x;
    if (v >= total_vec) {
        return;  // tail thread of the last block
    }
    const float4 in = A[v];  // one 16-byte load
    C[v] = make_float4(in.x * s, in.y * s, in.z * s, in.w * s);  // one 16-byte store
}

namespace {

// Threads per block for both kernels.
constexpr int kThreads = 256;

// Largest element count handed to one kernel launch. The kernels index with a
// 32-bit `int`, so larger ranges are split into several launches.
constexpr int64_t kMaxElemsPerLaunch = int64_t{1} << 30;

// Calls `launch(offset, count)` for consecutive sub-ranges of [0, n), each at
// most kMaxElemsPerLaunch long, and checks every launch for errors.
// Does nothing when n == 0, so no zero-block launch is ever issued.
template <typename LaunchFn>
void launch_in_chunks(int64_t n, LaunchFn&& launch) {
    for (int64_t offset = 0; offset < n; offset += kMaxElemsPerLaunch) {
        const int count = static_cast<int>(std::min(kMaxElemsPerLaunch, n - offset));
        launch(offset, count);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
}

// Number of kThreads-wide blocks needed to cover `count` elements.
int num_blocks(int count) {
    return (count + kThreads - 1) / kThreads;
}

// True if `p` meets float4's 16-byte alignment requirement.
bool is_float4_aligned(const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % alignof(float4) == 0;
}

}  // namespace

// Computes C = A * s element-wise on the GPU and returns C as a new tensor.
//
// Args:
//   A: float32 CUDA tensor of any shape. Non-contiguous inputs are first
//      copied to a contiguous layout, so the result is always contiguous.
//   s: scalar multiplier.
// Returns:
//   A new tensor with A's shape, dtype and device.
// Throws (c10::Error, a RuntimeError in Python):
//   if A is not a CUDA tensor, is not float32, or a kernel launch fails.
//
// Work is enqueued on the current CUDA stream of A's device; no synchronization.
// When both data pointers are 16-byte aligned, the largest multiple-of-4
// prefix is processed with float4 loads/stores and the remaining (< 4)
// elements with the scalar kernel; otherwise the scalar kernel does everything.
torch::Tensor forward(torch::Tensor A, float s) {
    TORCH_CHECK(A.is_cuda(), "forward: A must be a CUDA tensor");
    TORCH_CHECK(A.scalar_type() == torch::kFloat,
                "forward: A must have dtype float32, got ", A.scalar_type());

    // Ensure correct device usage.
    c10::cuda::CUDAGuard device_guard(A.device());

    // The kernels walk memory linearly, so they need a dense row-major buffer
    // (no-op when A is already contiguous).
    torch::Tensor A_c = A.contiguous();
    torch::Tensor C = torch::empty_like(A_c);

    const int64_t total = A_c.numel();
    if (total == 0) {
        return C;  // a zero-block launch would be an invalid configuration
    }

    const float* a_ptr = A_c.data_ptr<float>();
    float* c_ptr = C.data_ptr<float>();
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // float4 access is only legal on 16-byte-aligned addresses; a view with a
    // storage offset can have an unaligned data pointer even if numel % 4 == 0.
    const bool vectorizable = is_float4_aligned(a_ptr) && is_float4_aligned(c_ptr);
    const int64_t total_vec = vectorizable ? total / 4 : 0;

    if (total_vec > 0) {
        const float4* a_vec = reinterpret_cast<const float4*>(a_ptr);
        float4* c_vec = reinterpret_cast<float4*>(c_ptr);
        launch_in_chunks(total_vec, [&](int64_t offset, int count) {
            mat_scalar_vec_kernel<<<num_blocks(count), kThreads, 0, stream>>>(
                a_vec + offset, s, c_vec + offset, count);
        });
    }

    // Scalar kernel covers whatever the vectorized pass did not.
    const int64_t done = total_vec * 4;
    launch_in_chunks(total - done, [&](int64_t offset, int count) {
        mat_scalar_kernel<<<num_blocks(count), kThreads, 0, stream>>>(
            a_ptr + done + offset, s, c_ptr + done + offset, count);
    });

    return C;
}

// Python binding: exposes `forward(A, s)` as this extension's entry point.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def(
        "forward",
        &forward,
        "Matrix-scalar multiplication (C = A * s) using optimized CUDA kernels");
}