2.68. Matmul-Min-Subtract Operation

Calculates a linear transformation on the input via matrix multiplication, then applies an element-wise minimum operation with a constant value, and finally subtracts that constant from each element of the result.

Progress

Step 2 of 2
Reward
2.51x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize the fused kernel by reducing overhead in the inner loop that computes the dot product. Since in_features is typically small (e.g., 10 in our test case), we use "#pragma unroll" to ask the compiler to unroll the loop, reducing the loop overhead. In addition, loading the weight and bias values through the __ldg intrinsic signals that these arrays are read-only, so the loads may be served from the read-only data cache, possibly reducing global memory latency. These changes help to further reduce the kernel execution time. The Python module remains unchanged; only the underlying kernel is more optimized.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>

// Fused linear + min + subtract kernel.
//
// Each thread produces one element of Y:
//   Y[i, j] = min(dot(X[i, :], W[j, :]) + b[j], constant) - constant
//
// Layout / preconditions (as indexed by this kernel):
//   X : batch x in_features,        row-major, densely packed
//   W : out_features x in_features, row-major, densely packed
//   b : out_features
//   Y : batch x out_features,       row-major, densely packed
//
// Launch: 1D grid of 1D blocks covering at least batch * out_features threads;
// surplus threads exit immediately. No shared memory is used.
__global__ void fused_linear_min_sub_kernel(const float* __restrict__ X,
                                              const float* __restrict__ W,
                                              const float* __restrict__ b,
                                              float constant,
                                              float* __restrict__ Y,
                                              int batch,
                                              int in_features,
                                              int out_features) {
    // All index arithmetic is done in 64-bit: batch * out_features and
    // row * in_features can exceed the range of a 32-bit int even when each
    // individual dimension fits in one.
    const long long total = static_cast<long long>(batch) * out_features;
    const long long index =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (index >= total) {
        return;  // Tail threads of the last block have no element to compute.
    }

    const long long row = index / out_features;  // Index into batch dimension.
    const long long col = index % out_features;  // Index into output features.

    // Hoist the row base pointers out of the dot-product loop.
    const float* __restrict__ x_row = X + row * in_features;
    const float* __restrict__ w_row = W + col * in_features;

    float sum = 0.0f;
    // in_features is a runtime value, so a full unroll is not possible;
    // request a fixed unroll factor instead. The accumulation order is
    // unchanged (k = 0, 1, 2, ...).
    #pragma unroll 4
    for (int k = 0; k < in_features; ++k) {
        // __ldg routes the weight load through the read-only data path.
        sum += x_row[k] * __ldg(&w_row[k]);
    }
    sum += __ldg(&b[col]);

    // Clamp from above at `constant`, then shift by `constant`.
    Y[index] = fminf(sum, constant) - constant;
}

// Host entry point for the fused linear + min + subtract operation.
//
// Computes Y = min(X @ W^T + b, constant) - constant on the GPU.
//
// Parameters:
//   X        : float32 CUDA tensor of shape [batch, in_features]
//   W        : float32 CUDA tensor of shape [out_features, in_features]
//   b        : float32 CUDA tensor of shape [out_features]
//   constant : scalar used both as the upper clamp and as the subtrahend
//
// Returns:
//   A new float32 tensor of shape [batch, out_features] on X's device.
//
// Raises (via TORCH_CHECK): if the tensors are not CUDA float32 tensors on the
// same device, if their shapes are inconsistent, if any operand is too large
// for the kernel's 32-bit size parameters, or if the kernel launch fails.
torch::Tensor forward(torch::Tensor X, torch::Tensor W, torch::Tensor b, float constant) {
    TORCH_CHECK(X.is_cuda() && W.is_cuda() && b.is_cuda(),
                "X, W and b must be CUDA tensors");
    TORCH_CHECK(W.device() == X.device() && b.device() == X.device(),
                "X, W and b must be on the same device");
    TORCH_CHECK(X.scalar_type() == torch::kFloat32 &&
                W.scalar_type() == torch::kFloat32 &&
                b.scalar_type() == torch::kFloat32,
                "X, W and b must be float32 tensors");
    TORCH_CHECK(X.dim() == 2, "X must be 2-D [batch, in_features]");
    TORCH_CHECK(W.dim() == 2, "W must be 2-D [out_features, in_features]");
    TORCH_CHECK(b.dim() == 1, "b must be 1-D [out_features]");
    TORCH_CHECK(W.size(1) == X.size(1),
                "W.size(1) must equal X.size(1) (in_features)");
    TORCH_CHECK(b.size(0) == W.size(0),
                "b.size(0) must equal W.size(0) (out_features)");

    // Ensure we work on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(X.device());

    // The kernel indexes raw pointers as densely packed row-major arrays, so
    // strided views must be materialized first (no-op when already contiguous).
    X = X.contiguous();
    W = W.contiguous();
    b = b.contiguous();

    const int64_t batch = X.size(0);
    const int64_t in_features = X.size(1);
    const int64_t out_features = W.size(0);

    // The kernel takes its sizes as 32-bit ints; reject anything that would be
    // truncated by the narrowing casts below.
    const int64_t int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(batch <= int_max && in_features <= int_max && out_features <= int_max &&
                X.numel() <= int_max && W.numel() <= int_max &&
                batch * out_features <= int_max,
                "tensor sizes exceed the 32-bit range supported by this kernel");

    // Allocate output tensor Y of shape [batch, out_features].
    auto Y = torch::empty({batch, out_features}, X.options());

    const int total = static_cast<int>(batch * out_features);
    if (total == 0) {
        // Nothing to compute; a zero-block launch would be an invalid configuration.
        return Y;
    }

    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;  // ceil-div, one thread per output

    fused_linear_min_sub_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        X.data_ptr<float>(),
        W.data_ptr<float>(),
        b.data_ptr<float>(),
        constant,
        Y.data_ptr<float>(),
        static_cast<int>(batch),
        static_cast<int>(in_features),
        static_cast<int>(out_features)
    );
    // Surface launch-configuration errors immediately instead of at a later,
    // unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return Y;
}

// Python binding: exposes the host function `forward` as `forward` on the
// extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused linear matrix multiplication with min and subtraction (CUDA)");
}