2.10. Transposed Convolution with Pooling and Activation Pipeline

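Performs a transposed convolution on 2D input data to upsample it, then applies max pooling to reduce the spatial dimensions, applies a Hardtanh activation to clamp the values to a fixed range, computes the spatial mean, and finally applies a tanh activation, which squashes each result into (-1, 1). The CUDA extension below fuses every step after the transposed convolution (max pooling, Hardtanh, mean, tanh) into a single kernel.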

Progress

Step 4 of 4
Reward
1.22x

Optimization Timeline

Step 4 Details

Reasoning

We can further optimize the fused kernel by specializing the inner loop for the common case pool_kernel == 2. In that branch the two row offsets are precomputed to avoid redundant multiplications, and the four loads of each 2x2 window are unrolled by hand using the __ldg intrinsic for read-only global memory loads. This minimizes the per-window overhead. The rest of the kernel is unchanged: a warp-level reduction with __shfl_down_sync, followed by a shared-memory reduction across warps. This keeps the accumulation across threads efficient. The Python module remains unchanged.

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <math.h>
#include <c10/cuda/CUDAGuard.h>

// Threads per warp assumed by the shuffle-based reductions and the
// shared-memory sizing in this file.
constexpr int WARP_SIZE{32};

// Fused kernel: max pooling -> Hardtanh clamp -> spatial mean -> tanh.
//
// Launch layout (as set up by the host wrapper in this file):
//   grid  = (B, C): blockIdx.x selects the batch item, blockIdx.y the channel.
//   block = 1-D. blockDim.x must be a multiple of WARP_SIZE and <= 1024:
//           every thread takes part in full-mask warp shuffles, and warp 0
//           reduces at most WARP_SIZE per-warp partial sums.
//   dynamic shared memory = ceil(blockDim.x / WARP_SIZE) * sizeof(float).
//
// Input x: dense (contiguous) float tensor of shape [B, C, H, W].
// Output:  one float per (b, c), written at output[b * C + c]
//          (the wrapper allocates it as shape [B, C, 1, 1]).
//
// Pooling uses no padding and a floor-mode output size. The caller must ensure
// pool_stride > 0, pool_kernel <= H and pool_kernel <= W, so that at least one
// window exists; otherwise the final mean divides by a non-positive count.
// Indices inside one (b, c) plane are 32-bit, so H * W must fit in an int.
__global__ void fused_pool_act_kernel(const float* __restrict__ x,
                                        float* __restrict__ output,
                                        int B, int C, int H, int W,
                                        int pool_kernel, int pool_stride,
                                        float hardtanh_min, float hardtanh_max) {
  // Each block processes one (batch, channel) pair.
  const int b = blockIdx.x;
  const int c = blockIdx.y;
  // Base of the (b, c) plane. Computed in 64-bit arithmetic because
  // B * C * H * W can exceed INT_MAX even when a single plane fits in an int.
  const float* x_channel =
      x + (static_cast<long long>(b) * C + c) * static_cast<long long>(H) * W;

  // Compute output dimensions for pooling.
  const int H_pool = (H - pool_kernel) / pool_stride + 1;
  const int W_pool = (W - pool_kernel) / pool_stride + 1;
  const int total_windows = H_pool * W_pool;

  const int tid = threadIdx.x;
  float local_sum = 0.0f;

  // Loop over pooling windows assigned to this thread (block-stride loop).
  for (int idx = tid; idx < total_windows; idx += blockDim.x) {
    const int i = idx / W_pool;
    const int j = idx % W_pool;
    const int start_i = i * pool_stride;
    const int start_j = j * pool_stride;
    float window_max = -FLT_MAX;
    // Specialized branch for pool_kernel == 2 (uniform across the block).
    if (pool_kernel == 2) {
      const int row0 = start_i;
      const int row1 = start_i + 1;
      // Precompute row offsets.
      const int offset0 = row0 * W;
      const int offset1 = row1 * W;
      const float v0 = (row0 < H && start_j < W) ? __ldg(&x_channel[offset0 + start_j]) : -FLT_MAX;
      const float v1 = (row0 < H && (start_j + 1) < W) ? __ldg(&x_channel[offset0 + start_j + 1]) : -FLT_MAX;
      const float v2 = (row1 < H && start_j < W) ? __ldg(&x_channel[offset1 + start_j]) : -FLT_MAX;
      const float v3 = (row1 < H && (start_j + 1) < W) ? __ldg(&x_channel[offset1 + start_j + 1]) : -FLT_MAX;
      window_max = fmaxf(fmaxf(v0, v1), fmaxf(v2, v3));
    } else {
      // Generic loop for other pool_kernel sizes.
      for (int di = 0; di < pool_kernel; di++) {
        for (int dj = 0; dj < pool_kernel; dj++) {
          const int cur_i = start_i + di;
          const int cur_j = start_j + dj;
          if (cur_i < H && cur_j < W) {
            const float val = __ldg(&x_channel[cur_i * W + cur_j]);
            if (val > window_max) {
              window_max = val;
            }
          }
        }
      }
    }
    // Apply Hardtanh clamping.
    if (window_max < hardtanh_min)
      window_max = hardtanh_min;
    if (window_max > hardtanh_max)
      window_max = hardtanh_max;
    local_sum += window_max;
  }

  // Warp-level reduction using __shfl_down_sync. The full mask is valid
  // because blockDim.x is required to be a multiple of WARP_SIZE.
  const unsigned int mask = 0xffffffffu;
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    local_sum += __shfl_down_sync(mask, local_sum, offset);
  }

  // Shared memory for per-warp partial sums (sized by the launcher).
  extern __shared__ float warpSums[];
  const int lane = tid % WARP_SIZE;
  const int warpId = tid / WARP_SIZE;
  if (lane == 0) {
    warpSums[warpId] = local_sum;
  }
  __syncthreads();

  // Final reduction by the first warp. No second barrier is needed:
  // warpSums is not written again after this read.
  if (tid < WARP_SIZE) {
    const int numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    float block_sum = (tid < numWarps) ? warpSums[tid] : 0.0f;
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      block_sum += __shfl_down_sync(mask, block_sum, offset);
    }
    if (tid == 0) {
      const float mean_val = block_sum / total_windows;
      // Apply final tanh activation.
      output[b * C + c] = tanhf(mean_val);
    }
  }
}

// Host wrapper: validates the input, then launches fused_pool_act_kernel with
// one 256-thread block per (batch, channel) pair on the current CUDA stream.
//
// Args:
//   x            - CUDA float32 tensor of shape [B, C, H, W]. It is made
//                  contiguous before launch because the kernel indexes it as a
//                  dense NCHW array.
//   pool_kernel  - square max-pool window size; must satisfy
//                  0 < pool_kernel <= H and pool_kernel <= W.
//   pool_stride  - max-pool stride; must be > 0.
//   hardtanh_min, hardtanh_max - clamp range applied to every pooled value.
// Returns:
//   float32 tensor of shape [B, C, 1, 1] holding, for each (b, c),
//   tanh(mean(hardtanh(maxpool(x[b, c])))).
// Errors (via TORCH_CHECK): invalid input, unsupported sizes, or a failed launch.
torch::Tensor fused_pool_activation(torch::Tensor x,
                                      int pool_kernel,
                                      int pool_stride,
                                      float hardtanh_min,
                                      float hardtanh_max) {
  TORCH_CHECK(x.is_cuda(), "fused_pool_activation: x must be a CUDA tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat,
              "fused_pool_activation: x must be float32, got ", x.scalar_type());
  TORCH_CHECK(x.dim() == 4,
              "fused_pool_activation: x must have shape [B, C, H, W], got ",
              x.dim(), " dimensions");
  TORCH_CHECK(pool_kernel > 0 && pool_stride > 0,
              "fused_pool_activation: pool_kernel and pool_stride must be positive, got ",
              pool_kernel, " and ", pool_stride);

  // Ensure correct CUDA device context.
  c10::cuda::CUDAGuard device_guard(x.device());

  // The kernel assumes a dense NCHW layout; this is a no-op if x already is.
  x = x.contiguous();

  const int64_t B = x.size(0);
  const int64_t C = x.size(1);
  const int64_t H = x.size(2);
  const int64_t W = x.size(3);

  // The kernel's output-size formula (no padding, floor mode) needs at least
  // one full window; a smaller input would give a bogus window count.
  TORCH_CHECK(pool_kernel <= H && pool_kernel <= W,
              "fused_pool_activation: pool_kernel (", pool_kernel,
              ") exceeds the spatial size (", H, ", ", W, ")");
  // The kernel uses 32-bit indices inside one (b, c) plane.
  TORCH_CHECK(H * W <= 2147483647LL,
              "fused_pool_activation: H * W is too large for 32-bit indexing");
  // Channels are mapped to gridDim.y, whose hardware limit is 65535.
  TORCH_CHECK(C <= 65535,
              "fused_pool_activation: C must be <= 65535 (gridDim.y limit), got ", C);

  // Allocate output tensor of shape [B, C, 1, 1].
  auto output = torch::empty({B, C, 1, 1}, x.options());
  if (output.numel() == 0) {
    // B == 0 or C == 0: nothing to compute, and a zero-sized grid is an
    // invalid launch configuration.
    return output;
  }

  // Launch one block per (B, C) pair. The block size must be a multiple of
  // WARP_SIZE for the kernel's full-mask warp shuffles.
  const dim3 grid(static_cast<unsigned int>(B), static_cast<unsigned int>(C));
  const int threads = 256;
  const int numWarps = (threads + WARP_SIZE - 1) / WARP_SIZE;
  const size_t shared_mem = numWarps * sizeof(float);

  fused_pool_act_kernel<<<grid, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
      x.data_ptr<float>(), output.data_ptr<float>(),
      static_cast<int>(B), static_cast<int>(C),
      static_cast<int>(H), static_cast<int>(W),
      pool_kernel, pool_stride, hardtanh_min, hardtanh_max);

  // Kernel launches do not return errors directly; surface launch failures here.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "fused_pool_act_kernel launch failed: ", cudaGetErrorString(launch_err));

  return output;
}

// Python bindings: exposes the host wrapper under the name
// `fused_pool_activation` in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  const char* const kFusedDoc = "Fused maxpool + Hardtanh + mean + tanh";
  mod.def("fused_pool_activation",
          &fused_pool_activation,
          kFusedDoc);
}