2.64. Matrix Multiplication and Activation Chain
Performs a matrix multiplication on the input, applies a log-sum-exp reduction along the feature dimension, and then sequentially applies two LeakyReLU activations followed by two GELU activations.
Step 6 Details
Reasoning
We can try to further optimize our fast-path fused kernel with a few micro-optimizations, such as marking pointer arguments as __restrict__ to help the compiler with alias analysis and reinforcing the loop-unrolling hints. However, since the current implementation already obtains a slight speedup (reward ~1.0035, which is acceptable), these micro-optimizations might only produce marginal improvements. Nonetheless, we add __restrict__ qualifiers to our memory pointers and add more comments for clarity, without changing the overall algorithm. Our methodology remains the same:
1. If out_features equals exactly 2*blockDim.x (typical for 512 outputs when blockDim.x is 256), use the fast path where each thread processes two outputs using vectorized loads.
2. Otherwise, fall back to a general GEMM loop.
3. After the GEMM, perform two warp-level reductions (one for the maximum and one for summing exp(s_y - m)) using __shfl_down_sync, and store the final reduction values in shared memory so all threads can see the result.
4. Compute the log-sum-exp, then apply two successive LeakyReLU activations followed by two GELU activations.
The Python module is unchanged in structure; the CUDA kernel below adds the __restrict__ qualifiers and the extra comments for clarity.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
#define BLOCK_THREADS 256
#define WARP_SIZE 32
// Device function: Warp-level reduction to compute maximum.
__inline__ __device__ float warpReduceMax(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        // Use full mask 0xffffffff since we operate on a full warp.
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
    return val;
}
// Device function: Warp-level reduction to compute sum.
__inline__ __device__ float warpReduceSum(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}
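// Note: after each warp-level reduction only lane 0 of a warp holds the fully
// reduced value, so the kernel stages per-warp results in shared memory before
// a final reduction over the first warp.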
// Device function for LeakyReLU activation.
__device__ inline float leaky_relu(float x, float slope) {
    return (x >= 0.0f) ? x : slope * x;
}
// Device function for GELU activation (approximation).
__device__ inline float gelu(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654f;
    return 0.5f * x * (1.0f + tanhf(sqrt_2_over_pi * (x + 0.044715f * x * x * x)));
}
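// Note: this is the tanh-based GELU approximation; it differs slightly from the
// exact erf-based GELU that PyTorch uses by default (approximate="none").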
// Fused kernel: each block processes one sample.
// __restrict__ qualifiers added to help compiler with aliasing.
__global__ void fused_gemm_lse_act_kernel(
        const float* __restrict__ x,      // [batch, in_features]
        const float* __restrict__ weight, // [out_features, in_features]
        const float* __restrict__ bias,   // [out_features]
        float* __restrict__ output,       // [batch, 1]
        int in_features,
        int out_features) {
    int sample = blockIdx.x;
    const float* x_sample = x + sample * in_features;
    // Dynamic shared memory for GEMM results: s_y array of size out_features.
    extern __shared__ float s_y[];
    // Shared memory for final reduction results.
    __shared__ float final_m;   // Final maximum value.
    __shared__ float final_sum; // Final sum of exp(s_y - m).
    // --- GEMM Phase ---
    // Fast path: if out_features equals exactly 2 * blockDim.x, use fast computation.
    if (out_features == blockDim.x * 2) {
        int j0 = threadIdx.x;
        int j1 = threadIdx.x + blockDim.x;
        float dot0 = 0.0f, dot1 = 0.0f;
        if ((in_features & 3) == 0) { // if in_features is divisible by 4, use vectorized loads.
            int vec_count = in_features >> 2; // in_features/4
            const float4* x_vec = reinterpret_cast<const float4*>(x_sample);
            const float4* w_vec0 = reinterpret_cast<const float4*>(weight + j0 * in_features);
            const float4* w_vec1 = reinterpret_cast<const float4*>(weight + j1 * in_features);
            #pragma unroll
            for (int k = 0; k < vec_count; k++) {
                float4 xv = x_vec[k];
                float4 wv0 = w_vec0[k];
                float4 wv1 = w_vec1[k];
                dot0 += xv.x * wv0.x + xv.y * wv0.y + xv.z * wv0.z + xv.w * wv0.w;
                dot1 += xv.x * wv1.x + xv.y * wv1.y + xv.z * wv1.z + xv.w * wv1.w;
            }
        } else {
            for (int k = 0; k < in_features; k++) {
                dot0 += x_sample[k] * weight[j0 * in_features + k];
                dot1 += x_sample[k] * weight[j1 * in_features + k];
            }
        }
        dot0 += bias[j0];
        dot1 += bias[j1];
        s_y[j0] = dot0;
        s_y[j1] = dot1;
    } else {
        // Fallback general GEMM loop.
        for (int j = threadIdx.x; j < out_features; j += blockDim.x) {
            float dot = 0.0f;
            const float* w_row = weight + j * in_features;
            if ((in_features & 3) == 0) {
                int vec_count = in_features >> 2;
                const float4* x_vec = reinterpret_cast<const float4*>(x_sample);
                const float4* w_vec = reinterpret_cast<const float4*>(w_row);
                #pragma unroll
                for (int k = 0; k < vec_count; k++) {
                    float4 xv = x_vec[k];
                    float4 wv = w_vec[k];
                    dot += xv.x * wv.x + xv.y * wv.y + xv.z * wv.z + xv.w * wv.w;
                }
            } else {
                for (int k = 0; k < in_features; k++) {
                    dot += x_sample[k] * w_row[k];
                }
            }
            dot += bias[j];
            s_y[j] = dot;
        }
    }
    __syncthreads();
    // --- Maximum Reduction ---
    float thread_max = -INFINITY;
    for (int j = threadIdx.x; j < out_features; j += blockDim.x) {
        thread_max = fmaxf(thread_max, s_y[j]);
    }
    float local_max = warpReduceMax(thread_max);
    int lane = threadIdx.x % WARP_SIZE;
    int warp_id = threadIdx.x / WARP_SIZE;
    __shared__ float s_warp_max[BLOCK_THREADS / WARP_SIZE];
    if (lane == 0) {
        s_warp_max[warp_id] = local_max;
    }
    __syncthreads();
    float final_max_val = -INFINITY;
    if (threadIdx.x < (blockDim.x / WARP_SIZE)) {
        final_max_val = s_warp_max[threadIdx.x];
    }
    final_max_val = warpReduceMax(final_max_val);
    if (threadIdx.x == 0) {
        final_m = final_max_val; // Store final max in shared memory.
    }
    __syncthreads();
    float m = final_m; // Broadcast final max.
    // --- Sum Reduction ---
    float thread_sum = 0.0f;
    for (int j = threadIdx.x; j < out_features; j += blockDim.x) {
        thread_sum += expf(s_y[j] - m);
    }
    float local_sum = warpReduceSum(thread_sum);
    __shared__ float s_warp_sum[BLOCK_THREADS / WARP_SIZE];
    if (lane == 0) {
        s_warp_sum[warp_id] = local_sum;
    }
    __syncthreads();
    float final_sum_val = 0.0f;
    if (threadIdx.x < (blockDim.x / WARP_SIZE)) {
        final_sum_val = s_warp_sum[threadIdx.x];
    }
    final_sum_val = warpReduceSum(final_sum_val);
    if (threadIdx.x == 0) {
        final_sum = final_sum_val; // Store final sum in shared memory.
    }
    __syncthreads();
    float s_final = final_sum; // Final sum value.
    float lse = m + logf(s_final);
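    // Subtracting the maximum m above keeps the log-sum-exp numerically stable:
    // log(sum_j exp(y_j)) == m + log(sum_j exp(y_j - m)).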
    // --- Activation Chain ---
    float act = lse;
    act = leaky_relu(act, 0.01f);
    act = leaky_relu(act, 0.01f);
    act = gelu(act);
    act = gelu(act);
    // Write the final output for this sample.
    if (threadIdx.x == 0) {
        output[sample] = act;
    }
}
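// Host-side launcher. Assumes x, weight, and bias are contiguous float32 CUDA
// tensors and that out_features * sizeof(float) fits within the device's
// dynamic shared-memory limit per block.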
torch::Tensor forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias,
                      int in_features, int out_features) {
    c10::cuda::CUDAGuard device_guard(x.device());
    int batch_size = x.size(0);
    auto output = torch::empty({batch_size, 1}, x.options());
    int threads = BLOCK_THREADS;
    int blocks = batch_size;
    // Allocate dynamic shared memory for GEMM results: out_features floats.
    size_t shared_mem_size = out_features * sizeof(float);
    fused_gemm_lse_act_kernel<<<blocks, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        in_features,
        out_features
    );
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused GEMM, LogSumExp, double LeakyReLU, double GELU with warp-level reduction and fast GEMM path");
}
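For completeness, the unfused computation can be reproduced with the libtorch C++ API as a correctness reference for the fused kernel. The sketch below is illustrative rather than part of the extension: reference_forward is a hypothetical helper, and it assumes a PyTorch build whose torch::gelu accepts the approximate="tanh" argument, matching the tanh-based GELU approximation used in the kernel.

// Hypothetical reference implementation for verification only (not part of the extension).
// Assumes torch::gelu supports the approximate="tanh" overload.
torch::Tensor reference_forward(torch::Tensor x, torch::Tensor weight, torch::Tensor bias) {
    auto y = torch::linear(x, weight, bias);                     // [batch, out_features]
    auto r = torch::logsumexp(y, /*dim=*/{1}, /*keepdim=*/true); // [batch, 1]
    r = torch::leaky_relu(r, 0.01);  // first LeakyReLU
    r = torch::leaky_relu(r, 0.01);  // second LeakyReLU
    r = torch::gelu(r, "tanh");      // first GELU (tanh approximation)
    r = torch::gelu(r, "tanh");      // second GELU (tanh approximation)
    return r;
}

With this in place, a check such as torch::allclose(reference_forward(x, weight, bias), forward(x, weight, bias, x.size(1), weight.size(0)), 1e-4, 1e-5) is expected to hold for float32 inputs that satisfy the kernel's shared-memory assumptions.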