2.55. Matrix Multiplication, Pooling, Summation, and Scaling

Computes a fused sequence of four operations: a matrix multiplication (linear transformation) on the input, a 1D max pooling over the reshaped feature dimension, a sum of the pooled values along one axis, and a final scaling of the result by a given factor.
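To make the four steps concrete, here is a minimal CPU reference sketch. It is not part of the original page: the name `reference_forward` and the flat row-major `std::vector` layout are illustrative assumptions, and it assumes the pooling stride equals `kernel_size` with no padding, which is what the kernel further down implements.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU reference: linear -> max pool (stride == kernel_size) -> sum -> scale.
// x: [batch, in_features], W: [out_features, in_features], b: [out_features], row-major.
std::vector<float> reference_forward(const std::vector<float>& x,
                                     const std::vector<float>& W,
                                     const std::vector<float>& b,
                                     int batch, int in_features, int out_features,
                                     int kernel_size, float scale_factor) {
    assert(kernel_size > 0 && out_features >= kernel_size);
    assert(x.size() == static_cast<std::size_t>(batch) * in_features);
    assert(W.size() == static_cast<std::size_t>(out_features) * in_features);
    assert(b.size() == static_cast<std::size_t>(out_features));

    // Trailing features that do not fill a whole window are dropped.
    const int windows = out_features / kernel_size;
    std::vector<float> out(batch);
    for (int s = 0; s < batch; ++s) {
        float sum_of_max = 0.0f;
        for (int w = 0; w < windows; ++w) {
            float window_max = std::numeric_limits<float>::lowest();
            for (int i = w * kernel_size; i < (w + 1) * kernel_size; ++i) {
                // Step 1: one output feature of the linear layer.
                float y = 0.0f;
                for (int j = 0; j < in_features; ++j) {
                    y += x[s * in_features + j] * W[i * in_features + j];
                }
                y += b[i];
                // Step 2: max over the current pooling window.
                window_max = std::max(window_max, y);
            }
            // Step 3: sum of the per-window maxima.
            sum_of_max += window_max;
        }
        // Step 4: scale the summed value.
        out[s] = sum_of_max * scale_factor;
    }
    return out;
}
```

A reference like this is mainly useful for checking a GPU result on small inputs; small floating-point differences are to be expected when the summation order differs.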

Progress: Step 3 of 3
Reward: 6.78x

Step 3 Details

Reasoning

We can further optimize our fused kernel by improving its memory access pattern. The previous version already used __ldg to read the input vector x. The weight matrix and bias vector are also read-only for the duration of the kernel, so we now load them with __ldg as well, which routes those reads through the GPU's read-only data cache. This can reduce global memory latency, although with const __restrict__ pointers the compiler is often able to generate such loads on its own, so the gain is not guaranteed. Additionally, we mark the inner dot-product loop with #pragma unroll to reduce loop overhead; because in_features is only known at run time, the compiler can unroll this loop partially at best. Before finalizing, I ask: have my improvements genuinely helped? Loading the weight and bias through __ldg should improve throughput and lower the kernel execution time, but that expectation has to be confirmed by benchmarking against the previous step.
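As an illustration of what unrolling a dot product with a run-time length involves, the following host-side C++ sketch unrolls by four and handles the leftover elements in a remainder loop. The function name `dot_unrolled4` and the factor of four are assumptions chosen for the example. It is also not a model of what nvcc emits for `#pragma unroll`: the four separate accumulators change the floating-point summation order, which a compiler would not do on its own without relaxed math settings.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical manual 4-way unrolled dot product with a remainder loop.
float dot_unrolled4(const float* a, const float* b, std::size_t n) {
    assert(a != nullptr && b != nullptr);
    float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
    std::size_t j = 0;
    // Main loop: four multiply-adds per iteration, one loop test per four elements.
    for (; j + 4 <= n; j += 4) {
        acc0 += a[j]     * b[j];
        acc1 += a[j + 1] * b[j + 1];
        acc2 += a[j + 2] * b[j + 2];
        acc3 += a[j + 3] * b[j + 3];
    }
    float acc = (acc0 + acc1) + (acc2 + acc3);
    // Remainder loop: the 0 to 3 elements left when n is not a multiple of 4.
    for (; j < n; ++j) {
        acc += a[j] * b[j];
    }
    return acc;
}
```

The remainder loop is the reason a run-time trip count prevents full unrolling: the number of iterations is unknown at compile time, so some loop control always remains.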

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cfloat>  // for FLT_MAX

// Fused kernel: each thread handles one sample.
__global__ void fused_forward_kernel(const float* __restrict__ x,
                                     const float* __restrict__ W,
                                     const float* __restrict__ b,
                                     float* __restrict__ output,
                                     int in_features,
                                     int out_features,
                                     int kernel_size,
                                     float scale_factor,
                                     int batch_size) {
    int sample = blockIdx.x * blockDim.x + threadIdx.x;
    if (sample < batch_size) {
        const float* x_row = x + sample * in_features;
        float sum_max = 0.0f;
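        // Number of pooling windows (stride == kernel_size, no padding).
        // Integer division yields 0 when out_features < kernel_size, so no
        // window reads past the end of W or b.
        int window_count = out_features / kernel_size;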
        for (int w = 0; w < window_count; w++) {
            float window_max = -FLT_MAX;
            int start = w * kernel_size;
            int end = start + kernel_size;
            for (int i = start; i < end; i++) {
                float dot_val = 0.0f;
                const float* W_row = W + i * in_features;
                // Unrolling hint to reduce loop overhead. in_features is a
                // run-time value, so the compiler can only unroll partially.
                #pragma unroll
                for (int j = 0; j < in_features; j++) {
                    dot_val += __ldg(&x_row[j]) * __ldg(&W_row[j]);
                }
                dot_val += __ldg(&b[i]);
                window_max = fmaxf(window_max, dot_val);
            }
            sum_max += window_max;
        }
        output[sample] = sum_max * scale_factor;
    }
}

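torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int kernel_size, float scale_factor) {
    // The kernel reads raw float pointers, so validate the inputs first.
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && bias.is_cuda(), "all inputs must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32, "all inputs must be float32");
    TORCH_CHECK(x.dim() == 2 && weight.dim() == 2 && bias.dim() == 1,
                "expected x: [batch, in_features], weight: [out_features, in_features], bias: [out_features]");
    TORCH_CHECK(x.size(1) == weight.size(1) && bias.size(0) == weight.size(0), "input shapes do not match");
    TORCH_CHECK(kernel_size > 0 && weight.size(0) >= kernel_size, "kernel_size must be in [1, out_features]");
    x = x.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    c10::cuda::CUDAGuard device_guard(x.device());

    int batch_size = static_cast<int>(x.size(0));
    int in_features = static_cast<int>(x.size(1));
    int out_features = static_cast<int>(weight.size(0));  // weight: [out_features, in_features]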

    // Allocate output tensor of shape [batch]
    auto output = torch::empty({batch_size}, x.options());

    int threads = 256;
    int blocks = (batch_size + threads - 1) / threads;
    
    fused_forward_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        in_features,
        out_features,
        kernel_size,
        scale_factor,
        batch_size
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused linear transformation, max pooling, sum, and scaling operator");
}