2.19. Transposed Convolution with GELU and Group Normalization
Performs a 2D transposed convolution to upsample the input tensor, calculates the Gaussian Error Linear Unit (GELU) activation using a tanh-based approximation to introduce non-linearity, and applies group normalization on the output channels.
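As a quick illustration of the tanh-based approximation mentioned above, the following host-side C++ sketch compares it with the exact erf-based GELU. The function names `gelu_tanh`, `gelu_erf`, and `max_abs_gap` are illustrative and not part of the original module; the grid over [-6, 6] is an arbitrary choice.

```cpp
#include <cmath>

// Tanh-based GELU approximation:
// gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
float gelu_tanh(float x) {
    const float kSqrt2OverPi = 0.7978845608028654f;
    const float kCoeff = 0.044715f;
    return 0.5f * x * (1.0f + std::tanh(kSqrt2OverPi * (x + kCoeff * x * x * x)));
}

// Exact GELU for comparison: x * Phi(x), where Phi is the standard normal CDF.
float gelu_erf(float x) {
    const float kInvSqrt2 = 0.7071067811865476f;
    return 0.5f * x * (1.0f + std::erf(x * kInvSqrt2));
}

// Largest absolute difference between the two forms on a grid over [-6, 6]
// (illustrative helper; the grid spacing of 0.01 is an arbitrary choice).
float max_abs_gap() {
    float worst = 0.0f;
    for (int i = -600; i <= 600; ++i) {
        const float x = 0.01f * static_cast<float>(i);
        worst = std::fmax(worst, std::fabs(gelu_tanh(x) - gelu_erf(x)));
    }
    return worst;
}
```

The two forms agree closely over this range, which is why the tanh form is commonly used when a cheaper activation is wanted.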
Step 2 Details
Reasoning
We can further optimize the fused operator by merging the two kernel launches (one for GELU, one for GroupNorm) into a single kernel. Instead of writing the GELU output to a temporary buffer and normalizing it in a separate kernel, one kernel does both. Each block handles one (sample, group) pair. In a first pass, it computes the GELU activation for every element of the group while accumulating partial sums and sums of squares, from which the group mean and variance are obtained after a shared-memory reduction. In a second pass, it recomputes GELU for each element and writes the normalized value. The recomputation costs a few extra arithmetic operations and a second read of the input, which is small compared to writing and re-reading an intermediate buffer in global memory. This fusion reduces intermediate global memory traffic and kernel launch overhead, thus improving runtime. Before finishing, we verify that the change really targets memory bandwidth and launch overhead without “cheating” by skipping necessary computations: every GELU and normalization step is still performed, and computing GELU twice is acceptable given the overall savings in memory access. Below are the final updated Python and CUDA modules.
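#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
#include <math.h>
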
// Fused kernel: applies tanh-approximated GELU, then GroupNorm without affine parameters.
// Each block is assigned to one (sample, group) pair.
__global__ void fused_gelu_groupnorm_kernel(const float* __restrict__ input,
                                            float* __restrict__ output,
                                            int N, int C, int H, int W,
                                            int num_groups,
                                            int group_size) {
    // Determine which sample and which group this block is handling.
    int blockId = blockIdx.x;      // Range: 0 ... N*num_groups - 1
    int n = blockId / num_groups;  // Sample index.
    int g = blockId % num_groups;  // Group index.
    int group_channels = C / num_groups;

    // Base offset of this group in the contiguous NCHW tensor.
    // 64-bit arithmetic avoids integer overflow for large tensors.
    size_t base = (static_cast<size_t>(n) * C + static_cast<size_t>(g) * group_channels) * H * W;

    // Each thread accumulates partial statistics over a strided slice of the group.
    float partial_sum = 0.0f;
    float partial_sum_sq = 0.0f;

    // First pass: Compute GELU activation and accumulate sums for mean and variance.
    for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
        float x_val = input[base + i];
        // Approximate GELU: gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3))).
        float c = 0.044715f;
        float sqrt_2_over_pi = 0.7978845608028654f;
        float x_cube = x_val * x_val * x_val;
        float tanh_arg = sqrt_2_over_pi * (x_val + c * x_cube);
        float gelu_val = 0.5f * x_val * (1.0f + tanhf(tanh_arg));
        partial_sum += gelu_val;
        partial_sum_sq += gelu_val * gelu_val;
    }

    // Use shared memory for reduction.
    extern __shared__ float sdata[];  // Size: 2 * blockDim.x floats.
    float* shared_sum = sdata;
    float* shared_sum_sq = sdata + blockDim.x;
    shared_sum[threadIdx.x] = partial_sum;
    shared_sum_sq[threadIdx.x] = partial_sum_sq;
    __syncthreads();

    // Parallel tree reduction in shared memory for sum and sum of squares.
    // Halving the stride each step assumes blockDim.x is a power of two.
    for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
        if (threadIdx.x < stride) {
            shared_sum[threadIdx.x] += shared_sum[threadIdx.x + stride];
            shared_sum_sq[threadIdx.x] += shared_sum_sq[threadIdx.x + stride];
        }
        __syncthreads();
    }

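    // Now compute mean and inverse standard deviation.
    float mean = shared_sum[0] / group_size;
    // E[x^2] - mean^2 can come out slightly negative from floating-point cancellation,
    // so clamp at zero before taking the reciprocal square root.
    float variance = fmaxf(shared_sum_sq[0] / group_size - mean * mean, 0.0f);
    float inv_std = rsqrtf(variance + 1e-5f);
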
    // Second pass: Recompute GELU activation and normalize.
    for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
        float x_val = input[base + i];
        float c = 0.044715f;
        float sqrt_2_over_pi = 0.7978845608028654f;
        float x_cube = x_val * x_val * x_val;
        float tanh_arg = sqrt_2_over_pi * (x_val + c * x_cube);
        float gelu_val = 0.5f * x_val * (1.0f + tanhf(tanh_arg));
        output[base + i] = (gelu_val - mean) * inv_std;
    }
}

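// Fused forward function: validates the input and launches the fused kernel.
torch::Tensor forward(torch::Tensor x, int num_groups) {
    TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat, "Input must be float32");
    TORCH_CHECK(x.dim() == 4, "Input must be a 4D NCHW tensor");
    TORCH_CHECK(num_groups > 0, "num_groups must be positive");

    // Ensure correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes the tensor as contiguous NCHW.
    x = x.contiguous();

    // Get tensor dimensions.
    int N = static_cast<int>(x.size(0));
    int C = static_cast<int>(x.size(1));
    int H = static_cast<int>(x.size(2));
    int W = static_cast<int>(x.size(3));

    // Validate that C is divisible by num_groups.
    TORCH_CHECK(C % num_groups == 0, "Channel dimension must be divisible by num_groups");

    // Compute group-specific parameters.
    int group_channels = C / num_groups;
    int group_size = group_channels * H * W;  // Number of elements per group.

    // Allocate output tensor; nothing to launch for an empty input.
    auto output = torch::empty_like(x);
    if (x.numel() == 0) {
        return output;
    }

    // Launch one block per (n, group) pair. The thread count must be a power of two
    // because the kernel's shared-memory reduction halves the stride each step.
    int blocks = N * num_groups;
    int threads = 256;
    size_t shared_mem_size = 2 * threads * sizeof(float);
    fused_gelu_groupnorm_kernel<<<blocks, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, H, W,
        num_groups,
        group_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
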
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused GELU and GroupNorm (combined in a single kernel) for ConvTranspose2d outputs");
}
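
To check the fused kernel's output, a CPU reference that follows the same two-pass scheme can be useful. The sketch below is an assumption-laden illustration, not part of the original module: the name `gelu_groupnorm_reference` is hypothetical, the input is taken to be a contiguous NCHW buffer, no affine scale or shift is applied (matching the kernel above), and statistics are accumulated in double precision, so small differences from the float32 kernel are expected.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Same tanh-based GELU approximation as the kernel.
static float gelu_tanh_ref(float x) {
    const float kSqrt2OverPi = 0.7978845608028654f;
    const float kCoeff = 0.044715f;
    return 0.5f * x * (1.0f + std::tanh(kSqrt2OverPi * (x + kCoeff * x * x * x)));
}

// Hypothetical CPU reference: GELU followed by GroupNorm without affine
// parameters, over a contiguous NCHW buffer. Each (sample, group) pair
// occupies one contiguous run of group_size elements.
std::vector<float> gelu_groupnorm_reference(const std::vector<float>& input,
                                            int N, int C, int H, int W,
                                            int num_groups, float eps = 1e-5f) {
    assert(num_groups > 0 && C % num_groups == 0);
    const std::size_t group_size =
        static_cast<std::size_t>(C / num_groups) * H * W;
    const std::size_t num_blocks = static_cast<std::size_t>(N) * num_groups;
    assert(input.size() == num_blocks * group_size);

    std::vector<float> output(input.size());
    for (std::size_t block = 0; block < num_blocks; ++block) {
        const std::size_t base = block * group_size;

        // Pass 1: accumulate the sum and sum of squares of the GELU outputs.
        double sum = 0.0, sum_sq = 0.0;
        for (std::size_t i = 0; i < group_size; ++i) {
            const double g = gelu_tanh_ref(input[base + i]);
            sum += g;
            sum_sq += g * g;
        }
        const double mean = sum / static_cast<double>(group_size);
        // Clamp guards against a slightly negative result from cancellation.
        const double var =
            std::max(sum_sq / static_cast<double>(group_size) - mean * mean, 0.0);
        const double inv_std = 1.0 / std::sqrt(var + eps);

        // Pass 2: recompute GELU and write the normalized value.
        for (std::size_t i = 0; i < group_size; ++i) {
            const double g = gelu_tanh_ref(input[base + i]);
            output[base + i] = static_cast<float>((g - mean) * inv_std);
        }
    }
    return output;
}
```

One way to use this is to copy the kernel's output back to the host and compare it element-wise against the reference with a small tolerance; each group of the result should have a mean near 0 and a variance near 1.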