2.14. Gemm-Divide-Sum-Scaling Operation
Computes a transformation of the input by first multiplying the input tensor with a transposed weight matrix (a GEMM), then dividing the result by 2, summing each row over the hidden dimension to aggregate the values, and finally scaling the aggregated result by a constant factor.
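For reference, the unfused computation can be written directly in PyTorch. The sketch below is illustrative only; the module name and the random weight initialization are assumptions, not taken from the original model.

import torch
import torch.nn as nn

class ReferenceModel(nn.Module):
    """Unfused reference: GEMM with a transposed weight, divide by 2, row-sum, scale."""
    def __init__(self, input_size, hidden_size, scaling_factor):
        super().__init__()
        # Illustrative parameter; the original model's initialization may differ.
        self.weight = nn.Parameter(torch.randn(hidden_size, input_size))
        self.scaling_factor = scaling_factor

    def forward(self, x):
        y = x @ self.weight.t()            # (batch_size, hidden_size)
        y = y / 2                          # divide by 2
        y = y.sum(dim=1, keepdim=True)     # (batch_size, 1)
        return y * self.scaling_factor     # scale by a constant factor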
Optimization Timeline
Step 1 Details
Reasoning
We can fuse the operations (matrix multiply, division, summation, and scaling) into a single custom CUDA kernel. The original model does the following:

1. Computes Y = x * weight^T by GEMM: for each sample i and each row j of weight, Y[i,j] = ∑ₖ x[i,k]*weight[j,k].
2. Divides Y by 2.
3. Reduces Y by summing over the hidden dimension (i.e. over j), producing a tensor of shape (batch_size, 1).
4. Finally multiplies by scaling_factor.

We can rearrange this fused computation mathematically:

output[i] = scaling_factor * (1/2) * ∑ⱼ ∑ₖ x[i,k]*weight[j,k]
          = scaling_factor * (1/2) * ∑ₖ x[i,k] * (∑ⱼ weight[j,k])

The inner sum ∑ⱼ weight[j,k] could be pre-computed for each k, which removes the nested loop over j. However, since the weight matrix is relatively small (e.g., hidden_size=20 in our example) and because we want a fully fused operator, we compute this weighted sum for each input sample within one kernel invocation. We launch one thread per sample, which loops over input_size and the hidden dimension; the final output tensor has shape (batch_size, 1). We use c10::cuda::CUDAGuard to ensure device correctness and c10::cuda::getCurrentCUDAStream() to launch the kernel on the current stream. The Python module loads the CUDA extension with optimized flags (-O3, --use_fast_math) and passes x, weight, scaling_factor, input_size, and hidden_size to the fused operator. This fused approach reduces memory traffic and kernel launch overhead compared to running several separate operations, and should give a significant speedup over the unfused native implementation.
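The rearrangement can be checked numerically before writing the kernel. The snippet below is only a sanity check of the algebra; the shapes and values are arbitrary and not taken from the original benchmark.

import torch

batch_size, input_size, hidden_size, scaling_factor = 4, 10, 20, 1.5
x = torch.randn(batch_size, input_size)
weight = torch.randn(hidden_size, input_size)

# Unfused form: GEMM, divide by 2, row-sum, scale.
unfused = ((x @ weight.t()) / 2).sum(dim=1, keepdim=True) * scaling_factor

# Rearranged form: dot each sample with the column sums of the weight matrix.
col_sums = weight.sum(dim=0)                                 # shape (input_size,)
fused = (x @ col_sums.unsqueeze(1)) * 0.5 * scaling_factor   # shape (batch_size, 1)

assert torch.allclose(unfused, fused, atol=1e-5)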
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
// Fused kernel for matrix multiplication, division, summation, and scaling.
// Each thread processes one sample.
__global__ void fused_gemm_div_sum_scale_kernel(const float* __restrict__ x,
                                                const float* __restrict__ weight,
                                                float scaling_factor,
                                                int input_size,
                                                int hidden_size,
                                                float* __restrict__ output,
                                                int batch_size) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < batch_size) {
        float acc = 0.0f;
        // For each element in the input vector of sample i:
        for (int k = 0; k < input_size; k++) {
            float x_val = x[i * input_size + k];
            float sum_w = 0.0f;
            // Sum over the hidden dimension for the weight matrix at column k.
            for (int j = 0; j < hidden_size; j++) {
                sum_w += weight[j * input_size + k];
            }
            acc += x_val * sum_w;
        }
        // Division by 2 and scaling.
        output[i] = acc * 0.5f * scaling_factor;
    }
}
torch::Tensor forward(torch::Tensor x, torch::Tensor weight, float scaling_factor, int input_size, int hidden_size) {
    // Ensure correct device usage.
    c10::cuda::CUDAGuard device_guard(x.device());
    // The kernel reads raw float pointers, so require contiguous float32 CUDA tensors.
    TORCH_CHECK(x.is_cuda() && weight.is_cuda(), "x and weight must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && weight.is_contiguous(), "x and weight must be contiguous");
    // x is of shape (batch_size, input_size).
    int batch_size = x.size(0);
    // Allocate output tensor with shape (batch_size, 1).
    auto output = torch::empty({batch_size, 1}, x.options());
    // Launch kernel: one thread per sample.
    int threads = 256;
    int blocks = (batch_size + threads - 1) / threads;
    fused_gemm_div_sum_scale_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        scaling_factor,
        input_size,
        hidden_size,
        output.data_ptr<float>(),
        batch_size
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused GEMM, division by 2, summation and scaling");
}
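On the Python side, the extension mentioned in the reasoning can be JIT-compiled with torch.utils.cpp_extension.load. A minimal sketch follows, assuming the kernel above is saved as fused_gemm_div_sum_scale.cu; the file name, module name, and sizes are placeholders.

import torch
from torch.utils.cpp_extension import load

# Build the extension with the optimization flags mentioned in the reasoning.
fused_op = load(
    name="fused_gemm_div_sum_scale",
    sources=["fused_gemm_div_sum_scale.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=False,
)

batch_size, input_size, hidden_size, scaling_factor = 128, 10, 20, 1.5
x = torch.randn(batch_size, input_size, device="cuda", dtype=torch.float32)
weight = torch.randn(hidden_size, input_size, device="cuda", dtype=torch.float32)

out = fused_op.forward(x, weight, scaling_factor, input_size, hidden_size)
print(out.shape)  # torch.Size([128, 1])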