2.88. GEMM, GroupNorm, and Dual Swish Multiplication
Computes a sequence of operations that begins with a general matrix multiplication (GEMM) and then normalizes the intermediate output with GroupNorm over feature groups. It then applies a Swish activation, scales the result element-wise by a learnable weight, and finally applies a second Swish activation to produce the final output.
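For reference, the same operator sequence can be written directly against the LibTorch API. The following is a minimal sketch, assuming illustrative parameter names (gemm_weight, gemm_bias, gn_weight, gn_bias, scale); the fused kernel developed below replaces everything after the GroupNorm:

// Reference composition (a sketch; all parameter names are hypothetical).
torch::Tensor reference_forward(const torch::Tensor& x,
                                const torch::Tensor& gemm_weight,
                                const torch::Tensor& gemm_bias,
                                int64_t num_groups,
                                const torch::Tensor& gn_weight,
                                const torch::Tensor& gn_bias,
                                const torch::Tensor& scale) {
    auto h = torch::addmm(gemm_bias, x, gemm_weight.t());     // GEMM
    h = torch::group_norm(h, num_groups, gn_weight, gn_bias); // GroupNorm
    h = torch::silu(h);                                       // first Swish: h * sigmoid(h)
    h = h * scale;                                            // learnable element-wise weight
    return torch::silu(h);                                    // second Swish
}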
Step 6 Details
Reasoning
We can further optimize our fused kernel by using the __ldg intrinsic in the non-vectorized kernel when loading the weight value. Although the vectorized kernel already uses __ldg to fetch the weight vector, the non-vectorized kernel currently loads the weight through a plain global read. Using __ldg here routes these repeated, read-only weight accesses through the read-only (texture) cache, which can improve cache efficiency on devices where that path is beneficial. This is a genuine micro-optimization rather than a hack: the weight tensor really is read-only for the lifetime of the kernel, which is exactly the contract __ldg requires. The rest of the code is already optimized: it uses a force-inlined sigmoid, vectorized float4 processing when possible, and the c10::cuda utilities for stream and device management. The updated code with this extra improvement in the non-vectorized kernel follows.
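Concretely, the change to the non-vectorized kernel is a single line, sketched here as a before/after:

// before: float weight_val = weight[col];
// after:  float weight_val = __ldg(&weight[col]);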
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>    // c10::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h> // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint> // uintptr_t
// Force-inlined fast sigmoid using the __expf fast-math intrinsic.
__device__ __forceinline__ float sigmoidf(float z) {
return 1.0f / (1.0f + __expf(-z));
}
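// If the reduced accuracy of the __expf fast-math intrinsic ever becomes a
// concern, a higher-precision variant is a drop-in replacement. This is a
// hypothetical alternative (sketch only); the kernels below do not use it.
__device__ __forceinline__ float sigmoidf_accurate(float z) {
    return 1.0f / (1.0f + expf(-z)); // expf: slower but more accurate than __expf
}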
// Non-vectorized fused elementwise kernel: each thread processes one element.
__global__ void fused_elementwise_kernel(const float* __restrict__ x,
const float* __restrict__ weight,
float* __restrict__ y,
int total,
int out_features) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = idx; i < total; i += stride) {
float a = x[i];
float t = a * sigmoidf(a);
int col = i % out_features;
// Use __ldg to load the weight from read-only cache.
float weight_val = __ldg(&weight[col]);
float t2 = t * weight_val;
y[i] = t2 * sigmoidf(t2);
}
}
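// Worked example of the index math above (illustrative numbers): with
// out_features = 1024, flattened index i = 2051 lies in row 2, column
// 2051 % 1024 = 3, so that thread scales its Swish output by weight[3].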
// Vectorized fused elementwise kernel using float4.
// Assumes out_features is divisible by 4 and all pointers are 16-byte aligned.
__global__ void fused_elementwise_vec_kernel(const float4* __restrict__ x,
const float4* __restrict__ weight,
float4* __restrict__ y,
int total_vec,
int out_features) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int vec_out = out_features / 4; // Number of float4 elements per row.
for (int i = idx; i < total_vec; i += stride) {
int vec_col = i % vec_out; // covers columns vec_col*4 .. vec_col*4 + 3
// Load vectorized data using read-only cache.
float4 a_vec = __ldg(&x[i]);
// Compute first swish activation: t = a * sigmoid(a)
float t0 = a_vec.x * sigmoidf(a_vec.x);
float t1 = a_vec.y * sigmoidf(a_vec.y);
float t2 = a_vec.z * sigmoidf(a_vec.z);
float t3 = a_vec.w * sigmoidf(a_vec.w);
// Load the corresponding weight vector through the read-only cache.
float4 w_vec = __ldg(&weight[vec_col]);
// Multiply elementwise.
float r0 = t0 * w_vec.x;
float r1 = t1 * w_vec.y;
float r2 = t2 * w_vec.z;
float r3 = t3 * w_vec.w;
// Second swish activation: y = r * sigmoid(r)
float4 res;
res.x = r0 * sigmoidf(r0);
res.y = r1 * sigmoidf(r1);
res.z = r2 * sigmoidf(r2);
res.w = r3 * sigmoidf(r3);
y[i] = res;
}
}
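// Note: reinterpreting float* as float4* is only valid when the pointer is
// 16-byte aligned. Tensors fresh from the CUDA caching allocator satisfy this
// at their base address, but views with a storage offset may not, so the
// dispatcher below checks alignment before taking the vectorized path.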
torch::Tensor fused(torch::Tensor x, torch::Tensor weight) {
// x: [batch, out_features]
// weight: [out_features]
c10::cuda::CUDAGuard device_guard(x.device());
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(x.dim() == 2, "x must be a 2D tensor");
TORCH_CHECK(weight.dim() == 1, "weight must be a 1D tensor");
TORCH_CHECK(x.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32,
            "x and weight must be float32");
TORCH_CHECK(x.is_contiguous() && weight.is_contiguous(),
            "x and weight must be contiguous");
int batch = x.size(0);
int out_features = x.size(1);
TORCH_CHECK(weight.size(0) == out_features, "weight size must equal x's second dimension");
int total = batch * out_features;
auto y = torch::empty({batch, out_features}, x.options());
// Use the vectorized kernel when out_features is divisible by 4 and all
// pointers are 16-byte aligned (required for float4 loads/stores).
bool aligned =
    (reinterpret_cast<uintptr_t>(x.data_ptr<float>()) % 16 == 0) &&
    (reinterpret_cast<uintptr_t>(weight.data_ptr<float>()) % 16 == 0) &&
    (reinterpret_cast<uintptr_t>(y.data_ptr<float>()) % 16 == 0);
if (out_features % 4 == 0 && aligned) {
int total_vec = total / 4;
const float4* x_vec = reinterpret_cast<const float4*>(x.data_ptr<float>());
float4* y_vec = reinterpret_cast<float4*>(y.data_ptr<float>());
const float4* weight_vec = reinterpret_cast<const float4*>(weight.data_ptr<float>());
int threads = 256;
int blocks = (total_vec + threads - 1) / threads;
fused_elementwise_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
    x_vec, weight_vec, y_vec, total_vec, out_features);
C10_CUDA_KERNEL_LAUNCH_CHECK(); // surface launch errors promptly
} else {
int threads = 256;
int blocks = (total + threads - 1) / threads;
fused_elementwise_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
    x.data_ptr<float>(), weight.data_ptr<float>(), y.data_ptr<float>(), total, out_features);
C10_CUDA_KERNEL_LAUNCH_CHECK(); // surface launch errors promptly
}
return y;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused", &fused, "Fused elementwise ops: Swish, element-wise weight multiply, Swish (with vectorized float4 path)");
}