2.68. Matmul-Min-Subtract Operation
Computes a linear transformation of the input (matrix multiplication plus bias), then takes the element-wise minimum with a constant value, and finally subtracts that constant from each element of the result.
Progress
Optimization Timeline
Step 2 Details
Reasoning
We can further optimize the fused kernel by reducing overhead in the inner dot-product loop. Because in_features is typically small (e.g., 10 in our test case), we use "#pragma unroll" to ask the compiler to unroll that loop. Note, however, that in_features is a runtime argument, so the loop's trip count is not a compile-time constant. The CUDA Programming Guide specifies that a bare "#pragma unroll" on such a loop does not unroll it at all. To actually get unrolling, use a partial unroll factor (e.g., "#pragma unroll 4") or make in_features a compile-time constant. In addition, loading the weight and bias values with the __ldg intrinsic marks them as read-only, so they can be served from the read-only data cache, which may reduce global-memory latency. Together these changes aim to further reduce kernel execution time. The Python module is unchanged; only the kernel beneath it is more optimized.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Fused kernel: Y = min(X * W^T + b, constant) - constant.
//
// Memory layout (as indexed below):
//   X: row-major [batch, in_features]
//   W: row-major [out_features, in_features]
//   b: [out_features]
//   Y: row-major [batch, out_features]
// All pointers must be device pointers to contiguous float32 buffers.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Threads walk the
// batch * out_features outputs in a grid-stride loop, so any grid size
// (even a single block) covers every output element.
//
// Index math is done in 64-bit so that batch * out_features and the row
// offsets into X and W cannot overflow a 32-bit int on large tensors.
__global__ void fused_linear_min_sub_kernel(const float* __restrict__ X,
                                            const float* __restrict__ W,
                                            const float* __restrict__ b,
                                            float constant,
                                            float* __restrict__ Y,
                                            int batch,
                                            int in_features,
                                            int out_features) {
    const int64_t total = static_cast<int64_t>(batch) * out_features;
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < total;
         index += stride) {
        const int64_t i = index / out_features;  // Row of X / row of Y (batch index).
        const int64_t j = index % out_features;  // Row of W / column of Y (output feature).

        const float* __restrict__ x_row = X + i * in_features;
        const float* __restrict__ w_row = W + j * in_features;

        float sum = 0.0f;
        // in_features is a runtime value, so a bare `#pragma unroll` would not
        // unroll this loop at all; request a partial unroll instead.
        #pragma unroll 4
        for (int k = 0; k < in_features; ++k) {
            // __ldg routes the read-only weights through the read-only data cache.
            sum += x_row[k] * __ldg(&w_row[k]);
        }
        sum += __ldg(&b[j]);

        Y[index] = fminf(sum, constant) - constant;
    }
}
// Host entry point for the fused op: Y = min(X @ W^T + b, constant) - constant.
//
// Arguments:
//   X        - CUDA float32 tensor, shape [batch, in_features].
//   W        - CUDA float32 tensor, shape [out_features, in_features], same device as X.
//   b        - CUDA float32 tensor, shape [out_features], same device as X.
//   constant - ceiling applied after the bias add, then subtracted from every element.
// Returns:
//   A new float32 tensor of shape [batch, out_features] on X's device. The kernel is
//   enqueued on the current CUDA stream; this function does not synchronize.
// Errors:
//   Throws c10::Error (via TORCH_CHECK) on a device/dtype/shape mismatch, on sizes the
//   kernel's 32-bit size arguments cannot represent, or if the kernel launch fails.
torch::Tensor forward(torch::Tensor X, torch::Tensor W, torch::Tensor b, float constant) {
    TORCH_CHECK(X.is_cuda() && W.is_cuda() && b.is_cuda(),
                "forward: X, W and b must all be CUDA tensors");
    TORCH_CHECK(W.device() == X.device() && b.device() == X.device(),
                "forward: X, W and b must be on the same device");
    TORCH_CHECK(X.scalar_type() == torch::kFloat32 &&
                W.scalar_type() == torch::kFloat32 &&
                b.scalar_type() == torch::kFloat32,
                "forward: X, W and b must be float32");
    TORCH_CHECK(X.dim() == 2, "forward: X must be 2-D [batch, in_features], got ", X.dim(), "-D");
    TORCH_CHECK(W.dim() == 2, "forward: W must be 2-D [out_features, in_features], got ", W.dim(), "-D");
    TORCH_CHECK(b.dim() == 1, "forward: b must be 1-D [out_features], got ", b.dim(), "-D");
    TORCH_CHECK(W.size(1) == X.size(1),
                "forward: in_features mismatch: X has ", X.size(1), ", W has ", W.size(1));
    TORCH_CHECK(b.size(0) == W.size(0),
                "forward: out_features mismatch: W has ", W.size(0), ", b has ", b.size(0));

    // Ensure we work on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(X.device());

    // The kernel indexes raw row-major memory, so strided/transposed views must be
    // materialized first. contiguous() is a no-op for already-contiguous tensors.
    const torch::Tensor Xc = X.contiguous();
    const torch::Tensor Wc = W.contiguous();
    const torch::Tensor bc = b.contiguous();

    const int64_t batch = Xc.size(0);
    const int64_t in_features = Xc.size(1);
    const int64_t out_features = Wc.size(0);

    // The kernel takes its sizes as int; reject anything that would be truncated.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(batch <= kIntMax && in_features <= kIntMax && out_features <= kIntMax,
                "forward: tensor dimensions must fit in a 32-bit int");

    // Allocate output tensor Y of shape [batch, out_features].
    auto Y = torch::empty({batch, out_features}, Xc.options());

    // Nothing to compute; also avoids an invalid zero-block launch.
    if (Y.numel() == 0) {
        return Y;
    }

    const int threads = 256;
    const int64_t total = batch * out_features;  // 64-bit: cannot overflow here.
    const int64_t blocks = (total + threads - 1) / threads;
    // gridDim.x is limited to 2^31 - 1 blocks.
    TORCH_CHECK(blocks <= kIntMax, "forward: output too large for a 1-D grid launch");

    fused_linear_min_sub_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                                  c10::cuda::getCurrentCUDAStream()>>>(
        Xc.data_ptr<float>(),
        Wc.data_ptr<float>(),
        bc.data_ptr<float>(),
        constant,
        Y.data_ptr<float>(),
        static_cast<int>(batch),
        static_cast<int>(in_features),
        static_cast<int>(out_features));

    // Kernel launches do not return errors directly; surface bad launches here
    // instead of letting them be attributed to some later CUDA call.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "fused_linear_min_sub_kernel launch failed: ", cudaGetErrorString(err));
    return Y;
}
// Python bindings: registers the host entry point under the name "forward"
// in the extension module built as TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("forward",
            &forward,
            "Fused linear matrix multiplication with min and subtraction (CUDA)");
}