2.18. Matrix Multiply and Reduction Sequence
This operator performs a chain of tensor operations: it starts with a matrix multiplication, then applies a series of reductions (summation, maximum, average pooling, and two successive log-sum-exp operations) that compress each sample of the input down to a single scalar.
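For orientation, the sketch below shows the assumed unfused reference chain in the ATen C++ API. The tensor names (x, W, b) and the exact reduction arguments are assumptions made for illustration; the original model definition is not reproduced in this step.

#include <torch/torch.h>

// Assumed unfused reference chain (illustrative only).
// x: [batch_size, in_features], W: [out_features, in_features], b: [out_features]
torch::Tensor reference_chain(torch::Tensor x, torch::Tensor W, torch::Tensor b) {
    auto y = torch::addmm(b, x, W.t());            // matmul + bias: [batch, out_features]
    y = y.sum({1}, /*keepdim=*/true);              // summation over out_features: [batch, 1]
    y = std::get<0>(y.max(1, /*keepdim=*/true));   // max over a size-1 dim (identity)
    y = y.mean({1}, /*keepdim=*/true);             // average pooling over a size-1 dim (identity)
    y = y.logsumexp({1}, /*keepdim=*/true);        // first log-sum-exp (identity on a size-1 dim)
    y = y.logsumexp({1}, /*keepdim=*/true);        // second log-sum-exp (identity on a size-1 dim)
    return y;                                      // [batch, 1]
}

The key observation is that the summation collapses the out_features dimension to size 1, so every reduction after it operates on a single element and returns it unchanged; that is what allows the whole chain to be folded into one dot product per sample.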
Step 2 Details
Reasoning
We can further optimize the fused kernel by unrolling the inner loop over the input features. The reduction runs over the in_features dimension, and typical values (such as 10 in the test) are small, so a "#pragma unroll" directive lets the compiler unroll the loop, reducing loop overhead and potentially speeding up execution. The overall approach remains the same:
• For each sample (row of the input tensor), compute the dot product with the vector V (the sum over the rows of the weight matrix) and add the bias sum c.
• Return the resulting scalar for each sample.
This fusion collapses the matrix multiplication, summation, max, average pooling, and two LogSumExp operations into a single dot-product computation per sample. The unrolling genuinely helps by further reducing branch and loop overhead for small in_features.
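As a hedged sketch of the host-side preparation (the names W and b and their shapes are assumptions; this step does not show how V and c are produced), the kernel's inputs would be derived roughly as follows, assuming float32 tensors:

// Illustrative host-side preparation (assumed shapes: W [out_features, in_features], b [out_features]).
auto V = W.sum({0}).contiguous();   // sum over the rows of W -> [in_features]
float c = b.sum().item<float>();    // scalar sum of the bias
auto out = forward(x, V, c);        // one value per sample, matching the full reference chain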
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>   // for c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
// Kernel to compute, for each sample (row in x), the dot product with vector V plus scalar c.
// x: input tensor of shape [batch_size, in_features]
// V: vector of shape [in_features]
// c: scalar value
// out: output tensor of shape [batch_size]
__global__ void fused_linear_sum_kernel(const float* __restrict__ x,
                                        const float* __restrict__ V,
                                        float c,
                                        float* __restrict__ out,
                                        int in_features,
                                        int batch_size) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < batch_size) {
        float sum = 0.0f;
        // Unroll the loop for potential performance gains on small in_features.
        #pragma unroll
        for (int k = 0; k < in_features; k++) {
            sum += x[i * in_features + k] * V[k];
        }
        out[i] = sum + c;
    }
}
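One caveat, offered as an assumption rather than something measured in this step: with a runtime trip count, #pragma unroll can only hint at partial unrolling. If in_features is known at compile time (for example 10 in the test), a templated variant such as the sketch below would let nvcc unroll the loop fully; the kernel name and template parameter are illustrative.

// Hypothetical compile-time variant (not part of the original step): IN_FEATURES is a
// template parameter, so the loop bound is a compile-time constant and unrolls fully.
template <int IN_FEATURES>
__global__ void fused_linear_sum_kernel_ct(const float* __restrict__ x,
                                           const float* __restrict__ V,
                                           float c,
                                           float* __restrict__ out,
                                           int batch_size) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < batch_size) {
        float sum = 0.0f;
        #pragma unroll
        for (int k = 0; k < IN_FEATURES; k++) {
            sum += x[i * IN_FEATURES + k] * V[k];
        }
        out[i] = sum + c;
    }
}
// Launched e.g. as fused_linear_sum_kernel_ct<10><<<blocks, threads, 0, stream>>>(...).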
torch::Tensor forward(torch::Tensor x, torch::Tensor V, float c) {
    // Basic input validation: the kernel assumes contiguous float32 CUDA tensors.
    TORCH_CHECK(x.is_cuda() && V.is_cuda(), "x and V must be CUDA tensors");
    TORCH_CHECK(x.dim() == 2 && V.dim() == 1, "x must be 2-D and V must be 1-D");
    TORCH_CHECK(x.size(1) == V.size(0), "x and V must agree on in_features");
    TORCH_CHECK(x.is_contiguous() && V.is_contiguous(), "x and V must be contiguous");
    // Ensure that operations run on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());
    int batch_size = x.size(0);
    int in_features = x.size(1);
    // Allocate the output tensor of shape [batch_size].
    auto out = torch::empty({batch_size}, x.options());
    // Launch one thread per sample.
    int threads = 256;
    int blocks = (batch_size + threads - 1) / threads;
    fused_linear_sum_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        V.data_ptr<float>(),
        c,
        out.data_ptr<float>(),
        in_features,
        batch_size
    );
    return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused linear transformation sum (dot(x, sum(W)) + sum(b))");
}