2.9. Matrix Multiply, Subtract, Multiply, and ReLU Computation
Calculates an output by first applying a linear layer (matrix multiplication plus bias), then subtracting a constant value, scaling the result by a multiplier, and finally applying a ReLU activation.
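Concretely, for an input X of shape (batch_size, in_features), a weight matrix W of shape (out_features, in_features) and a bias b of shape (out_features,), each output element is

output[i, j] = max((dot(X[i, :], W[j, :]) + b[j] - subtract_val) * multiply_val, 0).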
Optimization Timeline
Step 2 Details
Reasoning
We can further optimize the fused kernel by reducing the number of memory transactions needed to load the input X and the weight matrix. In the original design, each thread loops over in_features one element at a time. On modern GPUs, vectorized loads (e.g. using float4) reduce the number of memory transactions, especially when in_features is divisible by 4. The improved kernel therefore checks at runtime whether in_features % 4 == 0; if so, it reinterprets the row pointers as float4 pointers and accumulates the dot product in "chunks" of 4 floats; if not, it falls back to the scalar loop. A "#pragma unroll" directive on the vectorized loop encourages the compiler to unroll it where possible. Overall this reduces the number of memory accesses and should improve throughput for many common sizes while preserving correctness.

Before finishing, we ask: have these improvements genuinely helped? Yes: using vectorized memory accesses on supported sizes reduces load overhead and improves throughput on the GPU hardware, so this is a genuine optimization rather than a simple workaround.
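One caveat the step's code does not check explicitly: reinterpreting a float* as a float4* also assumes the row pointers are 16-byte aligned. Freshly allocated contiguous PyTorch CUDA tensors satisfy this, and rows stay aligned when in_features is a multiple of 4, but sliced tensors may not. A defensive guard could look like the following sketch (the helper name is illustrative and not part of the extension below):

// Hypothetical helper: float4 loads require 16-byte alignment in addition to
// in_features % 4 == 0. One could AND this check into the vectorization test.
__device__ __forceinline__ bool is_aligned_16(const void* p) {
  return (reinterpret_cast<size_t>(p) & 0xF) == 0;
}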
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
// Fused kernel computes for each output element at (i, j):
//   sum = dot(X[i, :], weight[j, :]) + bias[j]
//   result = (sum - subtract_val) * multiply_val
//   result = max(result, 0)
__global__ void fused_linear_activation_kernel(const float* __restrict__ X,
                                               const float* __restrict__ weight,
                                               const float* __restrict__ bias,
                                               float subtract_val,
                                               float multiply_val,
                                               float* __restrict__ output,
                                               int batch_size,
                                               int in_features,
                                               int out_features) {
  // Calculate output coordinates.
  int j = blockIdx.x * blockDim.x + threadIdx.x;  // column: out_features
  int i = blockIdx.y * blockDim.y + threadIdx.y;  // row: batch_size
  if (i < batch_size && j < out_features) {
    float sum = 0.0f;
    const float* X_row = X + i * in_features;
    const float* W_row = weight + j * in_features;
    // Use vectorized loads if in_features is divisible by 4.
    if ((in_features & 3) == 0) {
      int iter4 = in_features >> 2;
      const float4* X4 = reinterpret_cast<const float4*>(X_row);
      const float4* W4 = reinterpret_cast<const float4*>(W_row);
      #pragma unroll
      for (int k = 0; k < iter4; ++k) {
        float4 x_val = X4[k];
        float4 w_val = W4[k];
        sum += x_val.x * w_val.x + x_val.y * w_val.y +
               x_val.z * w_val.z + x_val.w * w_val.w;
      }
    } else {
      // Fallback to scalar computation.
      for (int k = 0; k < in_features; ++k) {
        sum += X_row[k] * W_row[k];
      }
    }
    // Add bias then perform subtraction, multiplication and ReLU.
    sum += bias[j];
    sum = (sum - subtract_val) * multiply_val;
    output[i * out_features + j] = fmaxf(sum, 0.0f);
  }
}
torch::Tensor forward(torch::Tensor X,
                      torch::Tensor weight,
                      torch::Tensor bias,
                      float subtract_val,
                      float multiply_val) {
  // The kernel assumes contiguous float32 CUDA tensors.
  TORCH_CHECK(X.is_cuda() && weight.is_cuda() && bias.is_cuda(),
              "X, weight and bias must be CUDA tensors");
  TORCH_CHECK(X.scalar_type() == torch::kFloat32 &&
              weight.scalar_type() == torch::kFloat32 &&
              bias.scalar_type() == torch::kFloat32,
              "X, weight and bias must be float32");
  TORCH_CHECK(X.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(),
              "X, weight and bias must be contiguous");
  // Ensure correct device using c10::cuda::CUDAGuard.
  c10::cuda::CUDAGuard device_guard(X.device());
  int batch_size = X.size(0);
  int in_features = X.size(1);
  int out_features = weight.size(0);  // weight shape: (out_features, in_features)
  auto output = torch::empty({batch_size, out_features}, X.options());
  // Configure 2D blocks.
  const int BLOCK_DIM_X = 16;
  const int BLOCK_DIM_Y = 16;
  dim3 block(BLOCK_DIM_X, BLOCK_DIM_Y);
  dim3 grid((out_features + BLOCK_DIM_X - 1) / BLOCK_DIM_X,
            (batch_size + BLOCK_DIM_Y - 1) / BLOCK_DIM_Y);
  fused_linear_activation_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
      X.data_ptr<float>(),
      weight.data_ptr<float>(),
      bias.data_ptr<float>(),
      subtract_val,
      multiply_val,
      output.data_ptr<float>(),
      batch_size,
      in_features,
      out_features);
  return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forward, "Fused linear, subtraction, multiplication and ReLU activation");
}
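As a sanity check, the fused result can be compared against the same computation built from stock ATen ops. The helper below is a sketch rather than part of the extension; it assumes it is compiled in the same translation unit so it can call forward directly, and the tolerances are illustrative:

// Sketch of a host-side correctness check: compare the fused kernel against
// relu((X @ W^T + b - subtract_val) * multiply_val) composed from ATen ops.
torch::Tensor check_against_reference(torch::Tensor X, torch::Tensor weight,
                                      torch::Tensor bias,
                                      float subtract_val, float multiply_val) {
  auto fused = forward(X, weight, bias, subtract_val, multiply_val);
  auto reference =
      torch::relu((torch::linear(X, weight, bias) - subtract_val) * multiply_val);
  TORCH_CHECK(torch::allclose(fused, reference, /*rtol=*/1e-4, /*atol=*/1e-5),
              "fused kernel output deviates from the ATen reference");
  return fused;
}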