2.10. Transposed Convolution with Pooling and Activation Pipeline
Performs a transposed convolution on 2D input data to upsample it, applies max pooling to reduce the spatial dimensions, clamps the values with a Hardtanh activation, computes the mean over the remaining spatial dimensions, and finally applies a Tanh activation to normalize the result.
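To make the pipeline concrete, here is a minimal unfused PyTorch sketch of the same sequence of operations; the module name and all hyperparameters (channel counts, kernel size, strides, pooling parameters, Hardtanh bounds) are illustrative assumptions rather than values taken from this benchmark.

import torch
import torch.nn as nn

class TransposedConvPoolActPipeline(nn.Module):
    # Unfused reference: ConvTranspose2d -> MaxPool2d -> Hardtanh -> spatial mean -> Tanh.
    # All hyperparameters below are assumed for illustration.
    def __init__(self, in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1,
                 pool_kernel=2, pool_stride=2, hardtanh_min=-1.0, hardtanh_max=1.0):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                                 stride=stride, padding=padding)
        self.maxpool = nn.MaxPool2d(pool_kernel, stride=pool_stride)
        self.hardtanh = nn.Hardtanh(min_val=hardtanh_min, max_val=hardtanh_max)

    def forward(self, x):
        x = self.conv_transpose(x)            # upsample: [B, C_in, H, W] -> [B, C_out, H', W']
        x = self.maxpool(x)                   # reduce spatial dimensions
        x = self.hardtanh(x)                  # bound values to [hardtanh_min, hardtanh_max]
        x = x.mean(dim=(2, 3), keepdim=True)  # spatial mean -> [B, C_out, 1, 1]
        return torch.tanh(x)                  # final normalization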
Optimization Timeline
Step 4 Details
Reasoning
We can further optimize the fused kernel by specializing the inner loop for the common case where pool_kernel equals 2. In this branch we precompute the two row offsets to avoid redundant multiplications and unroll the four value loads manually, using the __ldg intrinsic for read-only global-memory loads. This minimizes the per-window overhead of the pooling step. The rest of the kernel (the warp-level reduction with __shfl_down_sync followed by the shared-memory reduction across warps) remains unchanged, so accumulation across threads stays efficient. The Python module remains the same.
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <math.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
// Define constant for warp size.
constexpr int WARP_SIZE = 32;
// Fused kernel to perform max pooling, apply Hardtanh activation,
// compute the spatial mean, and then apply tanh.
// Input x: shape [B, C, H, W]
// Output: shape [B, C, 1, 1]
__global__ void fused_pool_act_kernel(const float* __restrict__ x,
float* __restrict__ output,
int B, int C, int H, int W,
int pool_kernel, int pool_stride,
float hardtanh_min, float hardtanh_max) {
// Each block processes one (batch, channel) pair.
int b = blockIdx.x;
int c = blockIdx.y;
// Pointer for the (b, c) channel.
const float* x_channel = x + (b * C + c) * H * W;
// Compute output dimensions for pooling.
int H_pool = (H - pool_kernel) / pool_stride + 1;
int W_pool = (W - pool_kernel) / pool_stride + 1;
int total_windows = H_pool * W_pool;
int tid = threadIdx.x;
float local_sum = 0.0f;
// Loop over pooling windows assigned to this thread.
for (int idx = tid; idx < total_windows; idx += blockDim.x) {
int i = idx / W_pool;
int j = idx % W_pool;
int start_i = i * pool_stride;
int start_j = j * pool_stride;
float window_max = -FLT_MAX;
// Specialized branch for pool_kernel == 2.
if (pool_kernel == 2) {
int row0 = start_i;
int row1 = start_i + 1;
// Precompute row offsets.
int offset0 = row0 * W;
int offset1 = row1 * W;
float v0 = (row0 < H && start_j < W) ? __ldg(&x_channel[offset0 + start_j]) : -FLT_MAX;
float v1 = (row0 < H && (start_j + 1) < W) ? __ldg(&x_channel[offset0 + start_j + 1]) : -FLT_MAX;
float v2 = (row1 < H && start_j < W) ? __ldg(&x_channel[offset1 + start_j]) : -FLT_MAX;
float v3 = (row1 < H && (start_j + 1) < W) ? __ldg(&x_channel[offset1 + start_j + 1]) : -FLT_MAX;
window_max = fmaxf(fmaxf(v0, v1), fmaxf(v2, v3));
} else {
// Generic loop for other pool_kernel sizes.
for (int di = 0; di < pool_kernel; di++) {
for (int dj = 0; dj < pool_kernel; dj++) {
int cur_i = start_i + di;
int cur_j = start_j + dj;
if (cur_i < H && cur_j < W) {
float val = __ldg(&x_channel[cur_i * W + cur_j]);
if (val > window_max) {
window_max = val;
}
}
}
}
}
// Apply Hardtanh clamping.
if (window_max < hardtanh_min)
window_max = hardtanh_min;
if (window_max > hardtanh_max)
window_max = hardtanh_max;
local_sum += window_max;
}
// Warp-level reduction using __shfl_down_sync.
unsigned int mask = 0xffffffff;
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
local_sum += __shfl_down_sync(mask, local_sum, offset);
}
// Shared memory for warp sums.
extern __shared__ float warpSums[];
int lane = tid % WARP_SIZE;
int warpId = tid / WARP_SIZE;
if (lane == 0) {
warpSums[warpId] = local_sum;
}
__syncthreads();
// Final reduction by first warp.
float block_sum = 0.0f;
int numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
if (tid < numWarps) {
block_sum = warpSums[tid];
}
__syncthreads();
if (tid < WARP_SIZE) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
block_sum += __shfl_down_sync(mask, block_sum, offset);
}
if (tid == 0) {
float mean_val = block_sum / total_windows;
// Apply final tanh activation.
output[b * C + c] = tanhf(mean_val);
}
}
}
// Wrapper function to launch the fused kernel.
torch::Tensor fused_pool_activation(torch::Tensor x,
int pool_kernel,
int pool_stride,
float hardtanh_min,
float hardtanh_max) {
// Ensure correct CUDA device context.
c10::cuda::CUDAGuard device_guard(x.device());
// The kernel assumes a contiguous float32 CUDA tensor of shape [B, C, H, W].
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
TORCH_CHECK(x.dim() == 4, "x must have shape [B, C, H, W]");
x = x.contiguous();
int B = x.size(0);
int C = x.size(1);
int H = x.size(2);
int W = x.size(3);
// Allocate output tensor of shape [B, C, 1, 1].
auto output = torch::empty({B, C, 1, 1}, x.options());
// Launch one block per (B, C) pair.
dim3 grid(B, C);
int threads = 256;
int numWarps = (threads + WARP_SIZE - 1) / WARP_SIZE;
size_t shared_mem = numWarps * sizeof(float);
fused_pool_act_kernel<<<grid, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
x.data_ptr<float>(), output.data_ptr<float>(), B, C, H, W,
pool_kernel, pool_stride, hardtanh_min, hardtanh_max);
// Surface any launch-time errors promptly.
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_pool_activation", &fused_pool_activation, "Fused maxpool + Hardtanh + mean + tanh");
}
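As a sketch of how the compiled extension could be wired into the Python side (the extension name fused_pool_act, the source filename fused_pool_act.cu, and the shapes and hyperparameters below are assumptions for illustration), the transposed convolution stays in PyTorch and everything after it is replaced by the single fused call:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Hypothetical build step: compile the CUDA source above as a loadable extension.
fused_ext = load(name="fused_pool_act", sources=["fused_pool_act.cu"], verbose=True)

class FusedPipeline(nn.Module):
    # Keeps ConvTranspose2d in PyTorch; replaces maxpool + Hardtanh + mean + tanh
    # with the fused CUDA kernel. Hyperparameters are assumed for illustration.
    def __init__(self, in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1,
                 pool_kernel=2, pool_stride=2, hardtanh_min=-1.0, hardtanh_max=1.0):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                                 stride=stride, padding=padding)
        self.pool_kernel = pool_kernel
        self.pool_stride = pool_stride
        self.hardtanh_min = hardtanh_min
        self.hardtanh_max = hardtanh_max

    def forward(self, x):
        x = self.conv_transpose(x)
        return fused_ext.fused_pool_activation(x, self.pool_kernel, self.pool_stride,
                                               self.hardtanh_min, self.hardtanh_max)

# Quick numerical check against the eager composition of the same ops.
if __name__ == "__main__":
    model = FusedPipeline().cuda()
    x = torch.randn(8, 32, 16, 16, device="cuda")
    with torch.no_grad():
        fused = model(x)
        ref = model.conv_transpose(x)
        ref = F.max_pool2d(ref, kernel_size=2, stride=2)
        ref = torch.clamp(ref, -1.0, 1.0)
        ref = torch.tanh(ref.mean(dim=(2, 3), keepdim=True))
    print(torch.allclose(fused, ref, atol=1e-5))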