2.18. Matrix Multiply and Reduction Sequence

Performs a chain of tensor operations: a matrix multiplication followed by a series of reductions (summation, maximum, average pooling, and two successive log-sum-exp operations) that reduce each input sample to a single value.
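
For reference, the unfused chain can be sketched on the CPU in plain C++. This is a minimal sketch, assuming the weight matrix W has shape [out_features, in_features] (the nn.Linear layout, stored row-major) and that every reduction runs over dimension 1 with the result kept as a [batch_size, 1] column. The names reference_chain and logsumexp_row are hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Log-sum-exp of a row, computed stably by subtracting the row maximum.
float logsumexp_row(const std::vector<float>& row) {
    float m = *std::max_element(row.begin(), row.end());
    float acc = 0.0f;
    for (float v : row) acc += std::exp(v - m);
    return m + std::log(acc);
}

// Hypothetical CPU reference: x is [batch, in], W is [out, in] (row-major),
// b is [out]. Returns one value per sample, i.e. the [batch, 1] result flattened.
std::vector<float> reference_chain(const std::vector<float>& x,
                                   const std::vector<float>& W,
                                   const std::vector<float>& b,
                                   std::size_t batch, std::size_t in, std::size_t out) {
    std::vector<float> result(batch);
    for (std::size_t i = 0; i < batch; ++i) {
        // 1. Matrix multiply plus bias: y = x_i W^T + b, shape [out].
        std::vector<float> y(out);
        for (std::size_t j = 0; j < out; ++j) {
            float acc = b[j];
            for (std::size_t k = 0; k < in; ++k) acc += x[i * in + k] * W[j * in + k];
            y[j] = acc;
        }
        // 2. Sum over the out dimension leaves a single-element row.
        float s = 0.0f;
        for (float v : y) s += v;
        std::vector<float> col{s};
        // 3. Max over a dimension of size 1.
        col = {*std::max_element(col.begin(), col.end())};
        // 4. Average over a dimension of size 1.
        float avg = 0.0f;
        for (float v : col) avg += v;
        col = {avg / static_cast<float>(col.size())};
        // 5-6. Two log-sum-exps over a dimension of size 1.
        col = {logsumexp_row(col)};
        col = {logsumexp_row(col)};
        result[i] = col[0];
    }
    return result;
}
```

Every step after the summation operates on a single-element row, which is what makes the fusion described below possible.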

Progress: Step 2 of 2
Reward: 3.34x

Step 2 Details

Reasoning

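We can further optimize our fused kernel by partially unrolling the inner loop over the input features. The reduction runs over the in_features dimension, and typical values (like 10 in the test) are small, so unrolling mainly cuts per-iteration loop overhead. Note that a bare "#pragma unroll" fully unrolls a loop only when its trip count is a compile-time constant; otherwise nvcc leaves the loop rolled. Because in_features is a runtime argument, the directive needs an explicit factor, such as "#pragma unroll 4", for the compiler to unroll this loop.

The overall approach remains:
• For each sample (row in the input tensor), compute the dot product with the vector V (the sum over the rows of the weight matrix) and add the bias sum “c”.
• Return the resulting scalar for each sample.

This fusion is algebraically exact. The sum over the output features leaves one value per sample, and max, average pooling, and LogSumExp over a dimension of size 1 all return that value unchanged (for a single element a, log(exp(a)) = a). Writing W_j for row j of the weight matrix and b_j for the matching bias entry, the whole chain therefore equals Σ_j (x · W_j + b_j) = x · V + c, with V = Σ_j W_j and c = Σ_j b_j. The matrix multiplication, summation, max, average pooling, and two LogSumExp operations thus collapse into one dot product per sample. Most of the speedup comes from this collapse; the additional gain from unrolling is likely modest and limited to small in_features.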

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

// Kernel to compute, for each sample (row in x), the dot product with vector V plus scalar c.
// x: input tensor of shape [batch_size, in_features]
// V: vector of shape [in_features]
// c: scalar value
// out: output tensor of shape [batch_size]
__global__ void fused_linear_sum_kernel(const float* __restrict__ x,
                                          const float* __restrict__ V,
                                          float c,
                                          float* __restrict__ out,
                                          int in_features,
                                          int batch_size) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < batch_size) {
        float sum = 0.0f;
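        // Use a 64-bit row offset so that i * in_features cannot overflow int.
        const float* row = x + static_cast<int64_t>(i) * in_features;
        // in_features is only known at run time, so a bare "#pragma unroll" would
        // leave the loop rolled; request a partial unroll by a fixed factor instead.
        #pragma unroll 4
        for (int k = 0; k < in_features; k++) {
            sum += row[k] * V[k];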
        }
        out[i] = sum + c;
    }
}

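torch::Tensor forward(torch::Tensor x, torch::Tensor V, float c) {
    // The kernel assumes float32 CUDA tensors on the same device with matching shapes.
    TORCH_CHECK(x.is_cuda() && V.is_cuda(), "x and V must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && V.scalar_type() == torch::kFloat32,
                "x and V must be float32");
    TORCH_CHECK(x.dim() == 2 && V.dim() == 1 && V.size(0) == x.size(1),
                "expected x of shape [batch_size, in_features] and V of shape [in_features]");
    TORCH_CHECK(x.device() == V.device(), "x and V must be on the same device");
    // Raw pointer indexing in the kernel requires contiguous memory.
    x = x.contiguous();
    V = V.contiguous();

    // Ensure that operations run on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());

    const int batch_size = static_cast<int>(x.size(0));
    const int in_features = static_cast<int>(x.size(1));

    // Allocate output tensor of shape [batch_size].
    auto out = torch::empty({batch_size}, x.options());
    if (batch_size == 0) {
        return out;  // A grid of zero blocks is an invalid launch configuration.
    }

    // One thread per sample.
    const int threads = 256;
    const int blocks = (batch_size + threads - 1) / threads;

    fused_linear_sum_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        V.data_ptr<float>(),
        c,
        out.data_ptr<float>(),
        in_features,
        batch_size
    );

    // Surface launch errors (e.g. an invalid configuration) immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_linear_sum_kernel launch failed: ",
                cudaGetErrorString(err));

    return out;
}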

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused linear transformation sum (dot(x, sum(W)) + sum(b))");
}
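
To make the partial-unroll request concrete, the host-side C++ sketch below performs by hand roughly the transformation that a directive such as "#pragma unroll 4" asks nvcc to apply to the inner loop: a main loop that handles four elements per iteration, plus a remainder loop for in_features values that are not multiples of four. The function names dot_rolled and dot_unrolled4 are hypothetical, and the code nvcc actually generates may differ. Because the sketch keeps a single accumulator and adds the products in the same order as the rolled loop, it computes the same sum.

```cpp
// Rolled loop: one multiply-add per iteration, as written in the kernel source.
float dot_rolled(const float* row, const float* V, int n) {
    float sum = 0.0f;
    for (int k = 0; k < n; ++k) sum += row[k] * V[k];
    return sum;
}

// Manually unrolled by 4: a main loop over full groups of four elements,
// then a remainder loop for the last n % 4 elements.
float dot_unrolled4(const float* row, const float* V, int n) {
    float sum = 0.0f;
    int k = 0;
    for (; k + 3 < n; k += 4) {
        sum += row[k] * V[k];
        sum += row[k + 1] * V[k + 1];
        sum += row[k + 2] * V[k + 2];
        sum += row[k + 3] * V[k + 3];
    }
    for (; k < n; ++k) sum += row[k] * V[k];  // remainder: 0 to 3 elements
    return sum;
}
```

With in_features = 10, as in the test, the unrolled version runs two full groups of four and a two-element remainder, so it executes three loop-condition checks in the main loop instead of eleven.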