2.95. Matrix Multiply, Add and Sequential Activations
Performs a matrix multiplication on the input, adds an offset parameter element-wise, and then applies a chain of activation functions in order: first Swish (multiplying the tensor by its sigmoid), then hyperbolic tangent (Tanh), then the Gaussian Error Linear Unit (GELU) using a tanh-based approximation, and finally Hardtanh, which clamps output values to the range [-1, 1].
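As a plain reference for the operation described above, the following host-side C++ sketch computes the same chain on the CPU. The names `activation_chain` and `reference_forward` are hypothetical and do not come from the original code. The row-major layout and the `[out, in]` weight shape are assumptions taken from how the kernel in this step indexes its inputs, and the bias term is included because that kernel adds both a bias and the extra offset.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar chain: Swish -> Tanh -> GELU (tanh approximation) -> Hardtanh.
inline float activation_chain(float v) {
    const float swish = v / (1.0f + std::exp(-v));  // v * sigmoid(v)
    const float t = std::tanh(swish);
    const float inner = 0.7978845608028654f * (t + 0.044715f * t * t * t);
    const float gelu = 0.5f * t * (1.0f + std::tanh(inner));
    return std::min(1.0f, std::max(-1.0f, gelu));   // Hardtanh on [-1, 1]
}

// Hypothetical CPU reference: Y = chain(X * W^T + bias + add_value).
// Assumed layout: X is [batch, in], W is [out, in], Y is [batch, out], row-major.
std::vector<float> reference_forward(const std::vector<float>& X,
                                     const std::vector<float>& W,
                                     const std::vector<float>& bias,
                                     const std::vector<float>& add_value,
                                     std::size_t batch, std::size_t in, std::size_t out) {
    assert(X.size() == batch * in && W.size() == out * in);
    assert(bias.size() == out && add_value.size() == out);
    std::vector<float> Y(batch * out);
    for (std::size_t r = 0; r < batch; ++r) {
        for (std::size_t c = 0; c < out; ++c) {
            float acc = bias[c] + add_value[c];
            for (std::size_t k = 0; k < in; ++k) {
                acc += X[r * in + k] * W[c * in + k];
            }
            Y[r * out + c] = activation_chain(acc);
        }
    }
    return Y;
}
```

A reference like this can serve as an oracle when comparing GPU output within a floating-point tolerance. One detail it makes visible: Tanh already bounds the input to GELU to [-1, 1], and the GELU approximation never increases magnitude, so the final Hardtanh clamp should not change any value in this particular chain.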
Step 10 Details
Reasoning
The current implementation combines several optimizations: a tiled GEMM kernel that stages operands in shared memory (with one column of padding per tile to avoid bank conflicts), an unrolled inner loop, fused multiply-add (fmaf) operations, and __ldg intrinsics for read-only global memory accesses. Together these reduce the runtime of the fused operation from ~0.217 seconds (native) to ~0.113 seconds, a speedup of roughly 1.92×. Further improvement would require substantially different strategies, such as warp-level primitives or mixed-precision arithmetic, which would add complexity and are likely to yield diminishing returns. The gains so far come from lower memory latency and lower computational overhead rather than from ad hoc workarounds, so we consider the implementation optimized as is. The final version of the CUDA kernel and its Python binding follows.
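The effect of the shared-memory padding can be checked on the host with a small sketch. It models only the inner-loop read `sW[threadIdx.x][k]` and rests on assumptions that are not stated in the original: 32 banks of four-byte words, and a 16-wide thread block so that `threadIdx.x` equals `lane % 16`. The function name `worst_bank_conflict` is hypothetical.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <set>

// Worst-case number of distinct shared-memory words that one warp requests
// from a single bank when every lane reads tile[threadIdx.x][k].
// Lanes that read the same word are served by a broadcast, so only distinct
// words per bank are counted.
// Assumptions: 32 four-byte banks, 32 lanes per warp, blockDim.x == 16.
int worst_bank_conflict(int row_stride, int k) {
    assert(row_stride > 0 && k >= 0);
    constexpr int kBanks = 32;
    constexpr int kWarp = 32;
    constexpr int kBlockX = 16;
    std::array<std::set<int>, kBanks> words_per_bank;
    for (int lane = 0; lane < kWarp; ++lane) {
        const int tx = lane % kBlockX;          // threadIdx.x of this lane
        const int word = tx * row_stride + k;   // word offset of tile[tx][k]
        words_per_bank[word % kBanks].insert(word);
    }
    std::size_t worst = 0;
    for (const auto& words : words_per_bank) {
        worst = std::max(worst, words.size());
    }
    return static_cast<int>(worst);
}
```

Under these assumptions, an unpadded row stride of 16 gives `worst_bank_conflict(16, k) == 8` for every `k` from 0 to 15 (an 8-way conflict), while the padded stride of 17 gives `worst_bank_conflict(17, k) == 1`, meaning that read is conflict-free.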
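#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

#define TILE_M 16
#define TILE_N 16
#define TILE_K 16

// The tile loads below index the K dimension with threadIdx.x and threadIdx.y,
// so all three tile sizes must match.
static_assert(TILE_M == TILE_K && TILE_N == TILE_K, "tile loads assume square tiles");
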
// Logistic sigmoid, used by the Swish activation.
__device__ inline float sigmoidf(float x) {
    return 1.0f / (1.0f + expf(-x));
}

// GELU using the tanh-based approximation.
__device__ inline float gelu_approx(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654f;  // sqrt(2/pi)
    const float coeff = 0.044715f;
    float x3 = x * x * x;
    float tanh_out = tanhf(sqrt_2_over_pi * (x + coeff * x3));
    return 0.5f * x * (1.0f + tanh_out);
}

__global__ void fused_activation_chain_tiled_kernel(const float* __restrict__ X,
                                                    const float* __restrict__ Weight,
                                                    const float* __restrict__ Bias,
                                                    const float* __restrict__ Add_value,
                                                    float* __restrict__ Y,
                                                    int batch_size,
                                                    int in_features,
                                                    int out_features) {
    // Global indices.
    int row = blockIdx.x * TILE_M + threadIdx.y;  // Batch index.
    int col = blockIdx.y * TILE_N + threadIdx.x;  // Output feature index.
    float acc = 0.0f;
    int numTiles = (in_features + TILE_K - 1) / TILE_K;

    // Declare shared memory tiles with extra column for padding.
    __shared__ float sX[TILE_M][TILE_K + 1];
    __shared__ float sW[TILE_N][TILE_K + 1];  // Loaded transposed.

    for (int t = 0; t < numTiles; t++) {
        int k_index = t * TILE_K + threadIdx.x;
        if (row < batch_size && k_index < in_features) {
            // Load using __ldg for improved read-only global access.
            sX[threadIdx.y][threadIdx.x] = __ldg(&X[row * in_features + k_index]);
        } else {
            sX[threadIdx.y][threadIdx.x] = 0.0f;
        }

        int k_index_w = t * TILE_K + threadIdx.y;
        if (col < out_features && k_index_w < in_features) {
            sW[threadIdx.x][threadIdx.y] = __ldg(&Weight[col * in_features + k_index_w]);
        } else {
            sW[threadIdx.x][threadIdx.y] = 0.0f;
        }
        __syncthreads();

        // Compute partial dot product for this tile.
        #pragma unroll
        for (int k = 0; k < TILE_K; k++) {
            acc = fmaf(sX[threadIdx.y][k], sW[threadIdx.x][k], acc);
        }
        __syncthreads();
    }

    if (row < batch_size && col < out_features) {
        // Add Bias and extra Add_value.
        acc += Bias[col] + Add_value[col];
        // Apply Swish activation.
        float swish = sigmoidf(acc) * acc;
        // Apply Tanh.
        float tanh_val = tanhf(swish);
        // Apply approximate GELU.
        float gelu_val = gelu_approx(tanh_val);
        // Clamp (Hardtanh) between -1 and 1.
        float result = fminf(fmaxf(gelu_val, -1.0f), 1.0f);
        Y[row * out_features + col] = result;
    }
}

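torch::Tensor forward(torch::Tensor X, torch::Tensor Weight, torch::Tensor Bias, torch::Tensor Add_value) {
    // The kernel reads raw float pointers with row-major indexing, so reject
    // inputs that do not match those assumptions.
    TORCH_CHECK(X.is_cuda() && Weight.is_cuda() && Bias.is_cuda() && Add_value.is_cuda(),
                "all inputs must be CUDA tensors");
    TORCH_CHECK(X.scalar_type() == torch::kFloat32 && Weight.scalar_type() == torch::kFloat32 &&
                Bias.scalar_type() == torch::kFloat32 && Add_value.scalar_type() == torch::kFloat32,
                "all inputs must be float32");
    TORCH_CHECK(X.is_contiguous() && Weight.is_contiguous() &&
                Bias.is_contiguous() && Add_value.is_contiguous(),
                "all inputs must be contiguous");
    TORCH_CHECK(X.dim() == 2 && Weight.dim() == 2 && X.size(1) == Weight.size(1),
                "expected X of shape [batch, in] and Weight of shape [out, in]");
    TORCH_CHECK(Bias.numel() == Weight.size(0) && Add_value.numel() == Weight.size(0),
                "Bias and Add_value must each have out_features elements");

    // Ensure operations run on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(X.device());

    int batch_size = static_cast<int>(X.size(0));
    int in_features = static_cast<int>(X.size(1));
    int out_features = static_cast<int>(Weight.size(0));
    auto Y = torch::empty({batch_size, out_features}, X.options());

    // One thread per output element; threadIdx.x spans output features.
    dim3 block(TILE_N, TILE_M);
    dim3 grid((batch_size + TILE_M - 1) / TILE_M, (out_features + TILE_N - 1) / TILE_N);

    fused_activation_chain_tiled_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        X.data_ptr<float>(),
        Weight.data_ptr<float>(),
        Bias.data_ptr<float>(),
        Add_value.data_ptr<float>(),
        Y.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );
    return Y;
}
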
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused GEMM and activation chain (Swish, Tanh, GELU, Hardtanh) using tiled kernel with shared memory padding and __ldg loads");
}
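
The launch arithmetic used in the listing above (ceiling division for the grid and for the number of K-tiles) can be sketched on its own in host-side C++. The names `ceil_div`, `LaunchShape`, and `launch_shape` are hypothetical, and the single `tile` parameter assumes square 16×16×16 tiles as in the `TILE_M`, `TILE_N`, and `TILE_K` defines.

```cpp
#include <cassert>

// Ceiling division for positive integers, as used for the grid and tile counts.
constexpr int ceil_div(int n, int d) { return (n + d - 1) / d; }

// Hypothetical summary of what one launch covers.
struct LaunchShape {
    int grid_x;     // blocks along the batch dimension
    int grid_y;     // blocks along the output-feature dimension
    int num_tiles;  // K-tiles each block iterates over
};

// Mirrors the arithmetic in forward() and the kernel, assuming one square
// tile size for M, N, and K.
LaunchShape launch_shape(int batch_size, int in_features, int out_features, int tile = 16) {
    assert(batch_size > 0 && in_features > 0 && out_features > 0 && tile > 0);
    return {ceil_div(batch_size, tile),
            ceil_div(out_features, tile),
            ceil_div(in_features, tile)};
}
```

For illustrative sizes that are not from the original (a batch of 100, 50 input features, 33 output features), this gives a 7 × 3 grid with 4 K-tiles per block. The partial tiles at the edges are why the kernel zero-fills out-of-range shared-memory entries and guards the final store with a bounds check.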