2.12. Gemm Multiply LeakyReLU Operation
Performs a matrix multiplication on the input, scales the result by a constant multiplier, and then applies the LeakyReLU activation with a configurable negative slope.
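Written out, for x of shape [batch_size, in_features], W of shape [out_features, in_features], and bias of shape [out_features], the fused operation is

    y[b, j] = LeakyReLU( (∑ₖ x[b, k] * W[j, k] + bias[j]) * multiplier, negative_slope )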
Progress: Step 4 of 4
Reward: 1.46x

Optimization Timeline: Step 4 Details

Reasoning
We can further optimize the tiled kernel by using the read-only __ldg() intrinsic when loading values from global memory into shared memory. This routes those loads through the read-only data cache, reducing load latency. The change is confined to the shared-memory loading code; the rest of the tiled kernel is unchanged, and the Python module stays the same, since it only loads the rebuilt CUDA extension. The result is lower memory latency and higher throughput without any change to the algorithm.
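Concretely, the change amounts to wrapping the existing global-memory reads in __ldg(). A minimal before/after sketch of the x-tile load (the previous step's version is inferred from the reasoning above; the W-tile load changes the same way):

    // Previous step (assumed): plain global load into shared memory
    shared_x[threadIdx.y][threadIdx.x] = x[row * in_features + col_index];
    // This step: route the load through the read-only data cache
    shared_x[threadIdx.y][threadIdx.x] = __ldg(&x[row * in_features + col_index]);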
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#define TILE_SIZE 16
// Fused kernel using tiling, shared memory, and __ldg for efficient global memory access.
// Computes output: y[b, j] = LeakyReLU( (∑ₖ x[b,k]*W[j,k] + bias[j]) * multiplier )
// x: [batch_size, in_features]
// W: [out_features, in_features] (row-major, as in typical Linear layer)
// bias: [out_features]
// y: [batch_size, out_features]
__global__ void fused_forward_tiled_kernel(const float* __restrict__ x,
                                           const float* __restrict__ W,
                                           const float* __restrict__ bias,
                                           float multiplier,
                                           float negative_slope,
                                           float* __restrict__ y,
                                           int batch_size,
                                           int in_features,
                                           int out_features) {
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;  // batch index
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;  // output feature index

    __shared__ float shared_x[TILE_SIZE][TILE_SIZE];
    __shared__ float shared_W[TILE_SIZE][TILE_SIZE];

    float sum = 0.0f;
    int numTiles = (in_features + TILE_SIZE - 1) / TILE_SIZE;

    for (int t = 0; t < numTiles; t++) {
        // Load one tile of x into shared memory through the read-only cache.
        int col_index = t * TILE_SIZE + threadIdx.x;
        if (row < batch_size && col_index < in_features)
            shared_x[threadIdx.y][threadIdx.x] = __ldg(&x[row * in_features + col_index]);
        else
            shared_x[threadIdx.y][threadIdx.x] = 0.0f;

        // Load one tile of W into shared memory through the read-only cache.
        int row_index = t * TILE_SIZE + threadIdx.y;
        if (col < out_features && row_index < in_features)
            shared_W[threadIdx.y][threadIdx.x] = __ldg(&W[col * in_features + row_index]);
        else
            shared_W[threadIdx.y][threadIdx.x] = 0.0f;

        __syncthreads();

        // Accumulate the partial dot product for this tile.
        #pragma unroll
        for (int i = 0; i < TILE_SIZE; i++) {
            sum += shared_x[threadIdx.y][i] * shared_W[i][threadIdx.x];
        }

        __syncthreads();
    }

    if (row < batch_size && col < out_features) {
        sum += bias[col];
        sum *= multiplier;
        // LeakyReLU with configurable negative slope.
        float result = (sum >= 0.0f) ? sum : sum * negative_slope;
        y[row * out_features + col] = result;
    }
}
torch::Tensor fused_forward_tiled(torch::Tensor x,
                                  torch::Tensor W,
                                  torch::Tensor bias,
                                  float multiplier,
                                  float negative_slope) {
    c10::cuda::CUDAGuard device_guard(x.device());

    int batch_size = x.size(0);
    int in_features = x.size(1);
    int out_features = W.size(0);

    auto y = torch::empty({batch_size, out_features}, x.options());

    dim3 block(TILE_SIZE, TILE_SIZE);
    dim3 grid((out_features + TILE_SIZE - 1) / TILE_SIZE,
              (batch_size + TILE_SIZE - 1) / TILE_SIZE);

    fused_forward_tiled_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        W.data_ptr<float>(),
        bias.data_ptr<float>(),
        multiplier,
        negative_slope,
        y.data_ptr<float>(),
        batch_size,
        in_features,
        out_features
    );

    return y;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_forward_tiled", &fused_forward_tiled,
          "Fused GEMM, multiplier, and LeakyReLU using tiled shared memory and __ldg");
}
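The wrapper assumes float32, contiguous CUDA tensors, since it calls data_ptr<float>() directly on its arguments. If defensive validation were wanted, checks along these lines could be added at the top of fused_forward_tiled (a sketch, not part of the extension above):

    // Hypothetical guard checks; the original wrapper relies on these properties implicitly.
    TORCH_CHECK(x.is_cuda() && W.is_cuda() && bias.is_cuda(), "expected CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && W.scalar_type() == torch::kFloat32,
                "kernel is written for float32");
    TORCH_CHECK(x.is_contiguous() && W.is_contiguous() && bias.is_contiguous(),
                "expected contiguous tensors");
    TORCH_CHECK(x.size(1) == W.size(1), "in_features of x and W must match");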