2.61. 3D Transposed Convolution with ReLU and GroupNorm
Performs a 3D transposed convolution to upsample input data, applies a ReLU activation to introduce non-linearity, and then normalizes the output by splitting channels into groups for stable feature scaling.
Progress
Optimization Timeline
Step 3 Details
Reasoning
To further optimize our fused operator, we try to improve occupancy and memory throughput by increasing the per-kernel block size and using a grid-stride loop in the normalization kernel. In the previous version, the reduction kernel was launched with one block per (N, group) pair and a fixed block size of 256. We increase that to 512 to give each block's parallel reduction more threads, and size the shared memory for the per-warp partial sums accordingly. In the normalization kernel, we replace the simple index check with a grid-stride loop, so that each thread processes multiple elements whenever the grid is smaller than the tensor. With the current launch of ceil(total / threads) blocks, each thread still handles at most one element. These changes are intended to improve overall memory throughput and move our fused operator to a speedup over the native implementation. Before finalizing, we ask: have these improvements genuinely helped? That must be confirmed by profiling. A larger block does not by itself guarantee higher occupancy, and a grid-stride loop does not reduce the number of kernel launches; the claimed benefit is better latency hiding and more work per thread.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>
#include <algorithm>
#include <cstdint>
// Kernel 1: compute per-(sample, group) statistics of ReLU(input).
//
// Launch layout: 1-D grid with gridDim.x == N * groups (one block per (n, g)
// pair); 1-D blocks whose blockDim.x must be a multiple of 32, because every
// warp runs a full-mask __shfl_down_sync reduction.
// Dynamic shared memory: 2 * (blockDim.x / 32) floats.
// Preconditions: input is a contiguous NCDHW float buffer and C % groups == 0.
// Outputs: mean[n * groups + g] and inv_std[n * groups + g] = 1 / sqrt(var + eps),
// where var is the biased (population) variance of ReLU(input) over the group.
__global__ void compute_group_stats_kernel(const float* __restrict__ input,
                                           float* __restrict__ mean,
                                           float* __restrict__ inv_std,
                                           int N, int C, int D, int H, int W,
                                           int groups,
                                           float eps) {
    const int block_id = blockIdx.x;  // in [0, N * groups)
    const int n = block_id / groups;
    const int g = block_id % groups;
    const int group_channels = C / groups;
    // 64-bit sizes/offsets: products of tensor dimensions can exceed INT_MAX.
    const int64_t spatial_size = static_cast<int64_t>(D) * H * W;
    const int64_t group_count = static_cast<int64_t>(group_channels) * spatial_size;
    // In NCDHW layout the channels of one group are adjacent, so the whole
    // group is a single contiguous span starting at channel g * group_channels.
    const float* group_input =
        input + (static_cast<int64_t>(n) * C + static_cast<int64_t>(g) * group_channels) * spatial_size;

    // Per-thread partial sums over a block-strided walk of the group.
    float sum = 0.0f;
    float sum_sq = 0.0f;
    for (int64_t idx = threadIdx.x; idx < group_count; idx += blockDim.x) {
        const float val = group_input[idx];
        const float relu_val = (val > 0.0f ? val : 0.0f);
        sum += relu_val;
        sum_sq += relu_val * relu_val;
    }

    // Warp-level reduction: afterwards lane 0 of each warp holds the warp totals.
    const unsigned int mask = 0xffffffffu;
    for (int offset = 16; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(mask, sum, offset);
        sum_sq += __shfl_down_sync(mask, sum_sq, offset);
    }

    // s[0, num_warps) holds warp sums, s[num_warps, 2 * num_warps) warp sums of squares.
    extern __shared__ float s[];
    const int lane = threadIdx.x & 31;
    const int warp_id = threadIdx.x / 32;
    const int num_warps = blockDim.x / 32;
    if (lane == 0) {
        s[warp_id] = sum;
        s[warp_id + num_warps] = sum_sq;
    }
    __syncthreads();

    // Thread 0 folds the per-warp partials and writes the group statistics.
    if (threadIdx.x == 0) {
        float final_sum = 0.0f;
        float final_sum_sq = 0.0f;
        for (int i = 0; i < num_warps; ++i) {
            final_sum += s[i];
            final_sum_sq += s[i + num_warps];
        }
        const float count = static_cast<float>(group_count);
        const float grp_mean = final_sum / count;
        float grp_var = final_sum_sq / count - grp_mean * grp_mean;
        // E[x^2] - E[x]^2 can round to a small negative value when the variance is
        // tiny relative to mean^2; sqrtf of a negative argument would give NaN.
        // The comparison is false for NaN, so genuine NaN inputs still propagate.
        if (grp_var < 0.0f) {
            grp_var = 0.0f;
        }
        const int out_index = n * groups + g;
        mean[out_index] = grp_mean;
        inv_std[out_index] = 1.0f / sqrtf(grp_var + eps);
    }
}
// Kernel 2: apply ReLU followed by the per-group affine normalization
//   out = (relu(x) - mean[n, g]) * inv_std[n, g] * weight[c] + bias[c]
// using the statistics produced by the group-statistics kernel.
//
// Launch layout: any 1-D grid of 1-D blocks; a grid-stride loop covers all
// N * C * D * H * W elements regardless of grid size. No shared memory.
// Preconditions: input/output are contiguous NCDHW float buffers, weight and
// bias hold C floats, mean/inv_std hold N * groups floats, C % groups == 0.
__global__ void apply_groupnorm_relu_kernel(const float* __restrict__ input,
                                            const float* __restrict__ mean,
                                            const float* __restrict__ inv_std,
                                            const float* __restrict__ weight,
                                            const float* __restrict__ bias,
                                            float* __restrict__ output,
                                            int N, int C, int D, int H, int W,
                                            int groups) {
    // 64-bit sizes and indices: the element count can exceed INT_MAX.
    const int64_t spatial_size = static_cast<int64_t>(D) * H * W;
    const int64_t total = static_cast<int64_t>(N) * C * spatial_size;
    const int group_channels = C / groups;  // loop-invariant
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total; idx += stride) {
        // idx = (n * C + c) * spatial_size + s, so idx / spatial_size = n * C + c.
        const int64_t plane = idx / spatial_size;
        const int c = static_cast<int>(plane % C);
        const int n = static_cast<int>(plane / C);
        const int stat_idx = n * groups + c / group_channels;
        const float val = input[idx];
        const float relu_val = (val > 0.0f ? val : 0.0f);
        const float norm = (relu_val - mean[stat_idx]) * inv_std[stat_idx];
        output[idx] = norm * weight[c] + bias[c];
    }
}
// Fused forward: ReLU followed by GroupNorm with a per-channel affine transform.
//
// input:  5-D float32 CUDA tensor (N, C, D, H, W); made contiguous here.
// weight: float32 CUDA tensor with C elements (per-channel scale).
// bias:   float32 CUDA tensor with C elements (per-channel shift).
// groups: number of channel groups; must be > 0 and divide C.
// eps:    added to the group variance before the inverse square root.
// Returns a new contiguous tensor of input's shape holding
//   (relu(x) - mean[n, g]) / sqrt(var[n, g] + eps) * weight[c] + bias[c].
// Both kernels are enqueued on the current CUDA stream; the call does not
// synchronize the device. Throws (via TORCH_CHECK) on invalid arguments or
// kernel launch failure.
torch::Tensor fused_forward(torch::Tensor input,
                            torch::Tensor weight,
                            torch::Tensor bias,
                            int groups,
                            float eps) {
    TORCH_CHECK(input.dim() == 5, "Expected input to be a 5D tensor");
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat, "input must be float32");
    TORCH_CHECK(groups > 0, "groups must be positive");
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernels index raw NCDHW memory, so force a contiguous layout
    // (e.g. a channels_last_3d tensor would otherwise be read in the wrong order).
    input = input.contiguous();
    const int N = input.size(0);
    const int C = input.size(1);
    const int D = input.size(2);
    const int H = input.size(3);
    const int W = input.size(4);
    TORCH_CHECK(C % groups == 0, "C must be divisible by groups");

    // weight/bias are indexed by channel inside the kernel: validate before launch
    // to avoid out-of-bounds device reads.
    TORCH_CHECK(weight.device() == input.device() && bias.device() == input.device(),
                "weight and bias must be on the same device as input");
    TORCH_CHECK(weight.scalar_type() == torch::kFloat && bias.scalar_type() == torch::kFloat,
                "weight and bias must be float32");
    TORCH_CHECK(weight.numel() == C && bias.numel() == C,
                "weight and bias must each have C elements");
    weight = weight.contiguous();
    bias = bias.contiguous();

    torch::Tensor output = torch::empty_like(input);
    if (output.numel() == 0) {
        // Nothing to normalize; launching a 0-block grid would be an invalid configuration.
        return output;
    }

    const auto options = input.options();
    // Per-(n, group) statistics, shape [N, groups].
    torch::Tensor mean = torch::empty({N, groups}, options);
    torch::Tensor inv_std = torch::empty({N, groups}, options);
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // Kernel 1: one block per (n, group); block size must be a multiple of 32.
    const int num_groups = N * groups;
    constexpr int kStatsBlockSize = 512;
    const size_t stats_smem = 2 * (kStatsBlockSize / 32) * sizeof(float);
    compute_group_stats_kernel<<<num_groups, kStatsBlockSize, stats_smem, stream>>>(
        input.data_ptr<float>(),
        mean.data_ptr<float>(),
        inv_std.data_ptr<float>(),
        N, C, D, H, W, groups, eps);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Kernel 2: element-wise normalization. The element count is computed in
    // 64 bits, and the grid is capped because the kernel's grid-stride loop
    // covers any remainder; 65535 blocks of 512 threads already far exceed the
    // number of threads a GPU can keep resident.
    constexpr int kApplyBlockSize = 512;
    constexpr int64_t kMaxApplyBlocks = 65535;
    const int64_t total_elements = output.numel();
    const int apply_blocks = static_cast<int>(std::min<int64_t>(
        (total_elements + kApplyBlockSize - 1) / kApplyBlockSize, kMaxApplyBlocks));
    apply_groupnorm_relu_kernel<<<apply_blocks, kApplyBlockSize, 0, stream>>>(
        input.data_ptr<float>(),
        mean.data_ptr<float>(),
        inv_std.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, D, H, W, groups);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Python bindings: registers the host launcher above as `fused_forward`
// in the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "fused_forward",
        &fused_forward,
        "Fused ReLU and GroupNorm with optimized reduction and grid-stride loop");
}