2.96. 3D Transposed Convolution with Pooling and Clamping
Performs a 3D transposed convolution on an input tensor, scales the resulting feature maps by a constant factor, applies a 3D max pooling operation, consolidates spatial information via global average pooling, and clamps the output values to a specified range.
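As a point of reference, the post-convolution steps described above can be written as a plain host-side loop. The sketch below is a minimal C++ version that rests on several assumptions. The input is a contiguous NCDHW float buffer. Max pooling uses non-overlapping k×k×k windows with floor-mode output sizes. The clamp bounds are explicit arguments, whereas the CUDA kernel in this step hard-codes [0, 1]. The function name `reference_post_ops` is hypothetical, and the transposed convolution itself is not modelled.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference for scale -> 3D max pool -> global average -> clamp.
// Assumptions: `x` is a contiguous NCDHW float buffer; pooling windows are k x k x k with
// stride k and no padding (floor-mode output sizes); clamp bounds are explicit arguments.
std::vector<float> reference_post_ops(const std::vector<float>& x,
                                      int N, int C, int D, int H, int W,
                                      float scale, int k,
                                      float clamp_min, float clamp_max) {
    assert(k > 0 && D >= k && H >= k && W >= k);
    assert(x.size() == static_cast<std::size_t>(N) * C * D * H * W);
    const int Dp = D / k, Hp = H / k, Wp = W / k;  // trailing elements are ignored
    std::vector<float> out(static_cast<std::size_t>(N) * C);
    for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
            // Offset of the D x H x W slice that belongs to this (n, c) pair.
            const std::size_t base =
                (static_cast<std::size_t>(n) * C + c) * D * H * W;
            double sum = 0.0;  // sum of window maxima
            for (int dp = 0; dp < Dp; ++dp)
                for (int hp = 0; hp < Hp; ++hp)
                    for (int wp = 0; wp < Wp; ++wp) {
                        float m = -FLT_MAX;
                        for (int kd = 0; kd < k; ++kd)
                            for (int kh = 0; kh < k; ++kh)
                                for (int kw = 0; kw < k; ++kw) {
                                    const int d = dp * k + kd, h = hp * k + kh, w = wp * k + kw;
                                    const std::size_t off =
                                        base + (static_cast<std::size_t>(d) * H + h) * W + w;
                                    m = std::max(m, x[off] * scale);  // scale before max
                                }
                        sum += m;
                    }
            const float avg = static_cast<float>(sum / (Dp * Hp * Wp));
            out[static_cast<std::size_t>(n) * C + c] =
                std::min(std::max(avg, clamp_min), clamp_max);
        }
    }
    return out;
}
```

One way to catch indexing mistakes before benchmarking is to compare a reference like this against the kernel's output on small random inputs.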
Step 3 Details
Reasoning
The operations fused after the conv_transpose are scaling, 3D max pooling over non-overlapping pool_k×pool_k×pool_k windows, global average pooling, and clamping. In our earlier version, each thread (assigned to one (n,c) pair) looped serially over all pooling windows and used triple nested loops to find each window's maximum. Because pool_k is a runtime argument, the compiler cannot fully unroll those loops. The most common value is pool_k = 2 (a 2×2×2 window), so we add a specialized path for that case. It writes out the eight element loads explicitly and reduces them with inline fmaxf calls. Every other pool_k value falls back to the generic loop. All threads in a block see the same pool_k, so the branch causes no warp divergence.

The per-block reduction in shared memory is unchanged; only the per-window computation is faster. The Python module stays essentially the same and simply loads the fused operator. We expect the new kernel to reduce runtime noticeably when pool_k == 2, although the actual speedup has to be confirmed by benchmarking.

Have these improvements genuinely helped? For pool_k == 2, unrolling turns the three nested loops (8 iterations, each with counter updates, comparisons, and branches) into a straight-line sequence of 8 loads and 7 fmaxf operations. This removes the loop-control overhead and lets the compiler issue the eight independent loads back to back, which should increase throughput. Now I provide the final revised code.
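The difference between the generic loop and the pool_k == 2 specialization can be illustrated on the host. The helper names `window_max_generic`, `window_max_unrolled2` and `window_max` are hypothetical, and `vol` is assumed to be one contiguous D×H×W slice. The generic version runs a triple loop with a run-time trip count. The unrolled version issues eight loads and reduces them with a tree of `fmax` calls. The dispatcher takes the fast path only when k == 2.

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>
#include <vector>

// Generic window maximum: triple loop whose trip count `k` is only known at run time.
// `vol` is a hypothetical contiguous D x H x W slice for one (n, c) pair.
float window_max_generic(const std::vector<float>& vol, int H, int W,
                         int d0, int h0, int w0, int k, float scale) {
    float m = -FLT_MAX;
    for (int kd = 0; kd < k; ++kd)
        for (int kh = 0; kh < k; ++kh)
            for (int kw = 0; kw < k; ++kw) {
                const float v = vol[((d0 + kd) * H + (h0 + kh)) * W + (w0 + kw)] * scale;
                if (v > m) m = v;
            }
    return m;
}

// Hand-unrolled 2x2x2 case: eight independent loads followed by a tree of fmax calls,
// with no loop counters or loop branches.
float window_max_unrolled2(const std::vector<float>& vol, int H, int W,
                           int d0, int h0, int w0, float scale) {
    const int r00 = (d0 * H + h0) * W + w0;            // row start at (d0,   h0)
    const int r01 = (d0 * H + h0 + 1) * W + w0;        // row start at (d0,   h0+1)
    const int r10 = ((d0 + 1) * H + h0) * W + w0;      // row start at (d0+1, h0)
    const int r11 = ((d0 + 1) * H + h0 + 1) * W + w0;  // row start at (d0+1, h0+1)
    const float v0 = vol[r00] * scale, v1 = vol[r00 + 1] * scale;
    const float v2 = vol[r01] * scale, v3 = vol[r01 + 1] * scale;
    const float v4 = vol[r10] * scale, v5 = vol[r10 + 1] * scale;
    const float v6 = vol[r11] * scale, v7 = vol[r11 + 1] * scale;
    return std::fmax(std::fmax(std::fmax(v0, v1), std::fmax(v2, v3)),
                     std::fmax(std::fmax(v4, v5), std::fmax(v6, v7)));
}

// Dispatch: specialized path for k == 2, generic loop otherwise.
float window_max(const std::vector<float>& vol, int H, int W,
                 int d0, int h0, int w0, int k, float scale) {
    assert(k > 0);
    return k == 2 ? window_max_unrolled2(vol, H, W, d0, h0, w0, scale)
                  : window_max_generic(vol, H, W, d0, h0, w0, k, scale);
}
```

Both versions multiply by `scale` before taking the maximum, so they agree for negative scales as well. Scaling once after the max would only be valid for scale >= 0.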
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <float.h>
#include <math.h>
// Optimized fused kernel to apply the following operations on input tensor x of shape [N, C, D, H, W]:
// 1. Multiply each element by 'scale'.
// 2. Apply 3D max pooling with non-overlapping windows of size (pool_k x pool_k x pool_k) (stride = pool_k, no padding).
// 3. Compute the global average of the maximum values from all pooling windows.
// 4. Clamp the result to the range [0, 1] (the bounds are hard-coded, not kernel arguments).
// Each block processes one (n, c) pair; the output is a single float per (n, c).
// The output tensor shape is [N, C, 1, 1, 1].
//
// When pool_k == 2, the three nested loops (over kd, kh, kw) are unrolled.
__global__ void fused_post_ops_optimized_kernel(const float* __restrict__ input,
float* __restrict__ output,
int N, int C, int D, int H, int W,
float scale, int pool_k) {
// Each block processes one (n, c) pair.
int idx = blockIdx.x; // idx ranges from 0 to N*C - 1.
int n = idx / C;
int c = idx % C;
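// Compute pooled dimensions with floor division, matching MaxPool3d's default (ceil_mode=False): D, H, W need not be divisible by pool_k; trailing elements that do not fill a complete window are ignored.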
int D_pool = D / pool_k;
int H_pool = H / pool_k;
int W_pool = W / pool_k;
int num_windows = D_pool * H_pool * W_pool;
// Shared memory for per-block reduction.
extern __shared__ float shmem[];
int tid = threadIdx.x;
float local_sum = 0.0f;
// Each thread processes windows in a strided manner.
for (int p = tid; p < num_windows; p += blockDim.x) {
// Convert linear index p into 3D pooling window indices.
int w_idx = p % W_pool;
int temp = p / W_pool;
int h_idx = temp % H_pool;
int d_idx = temp / H_pool;
int d_start = d_idx * pool_k;
int h_start = h_idx * pool_k;
int w_start = w_idx * pool_k;
float window_max = -FLT_MAX;
if (pool_k == 2) {
// Unroll the loop for pool_k == 2 (8 iterations).
int d0 = d_start;
int d1 = d_start + 1;
int h0 = h_start;
int h1 = h_start + 1;
int w0 = w_start;
int w1 = w_start + 1;
int base = ((n * C + c) * D);
// Compute linear indices for each element in the 2x2x2 window.
int idx0 = (((base + d0) * H) + h0) * W + w0;
int idx1 = (((base + d0) * H) + h0) * W + w1;
int idx2 = (((base + d0) * H) + h1) * W + w0;
int idx3 = (((base + d0) * H) + h1) * W + w1;
int idx4 = (((base + d1) * H) + h0) * W + w0;
int idx5 = (((base + d1) * H) + h0) * W + w1;
int idx6 = (((base + d1) * H) + h1) * W + w0;
int idx7 = (((base + d1) * H) + h1) * W + w1;
float v0 = input[idx0] * scale;
float v1 = input[idx1] * scale;
float v2 = input[idx2] * scale;
float v3 = input[idx3] * scale;
float v4 = input[idx4] * scale;
float v5 = input[idx5] * scale;
float v6 = input[idx6] * scale;
float v7 = input[idx7] * scale;
window_max = fmaxf(fmaxf(v0, v1), fmaxf(v2, v3));
window_max = fmaxf(window_max, fmaxf(v4, fmaxf(v5, fmaxf(v6, v7))));
} else {
// Generic loop for arbitrary pool_k.
for (int kd = 0; kd < pool_k; kd++) {
for (int kh = 0; kh < pool_k; kh++) {
for (int kw = 0; kw < pool_k; kw++) {
int d_cur = d_start + kd;
int h_cur = h_start + kh;
int w_cur = w_start + kw;
int index = (((n * C + c) * D + d_cur) * H + h_cur) * W + w_cur;
float val = input[index] * scale;
if (val > window_max)
window_max = val;
}
}
}
}
local_sum += window_max;
}
shmem[tid] = local_sum;
__syncthreads();
// Perform reduction in shared memory.
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
shmem[tid] += shmem[tid + s];
}
__syncthreads();
}
// Thread 0 computes the average and clamps the result.
if (tid == 0) {
float avg = shmem[0] / num_windows;
avg = fminf(fmaxf(avg, 0.0f), 1.0f);
output[idx] = avg;
}
}
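torch::Tensor forward(torch::Tensor x, float scale, int pool_k) {
// x has shape [N, C, D, H, W]; the kernel indexes it as a contiguous float32 CUDA tensor with 32-bit offsets.
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(x.dim() == 5, "x must have shape [N, C, D, H, W]");
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
TORCH_CHECK(x.numel() <= std::numeric_limits<int32_t>::max(), "x is too large for 32-bit indexing");
TORCH_CHECK(pool_k > 0, "pool_k must be positive");
x = x.contiguous();
c10::cuda::CUDAGuard device_guard(x.device());
int N = x.size(0);
int C = x.size(1);
int D = x.size(2);
int H = x.size(3);
int W = x.size(4);
// At least one pooling window per dimension; otherwise the average would divide by zero.
TORCH_CHECK(D >= pool_k && H >= pool_k && W >= pool_k, "D, H and W must each be at least pool_k");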
// Allocate output tensor with shape [N, C, 1, 1, 1].
auto output = torch::empty({N, C, 1, 1, 1}, x.options());
int total_blocks = N * C; // each block computes one (n,c) pair result
int threads = 128; // enough parallelism over pooling windows; must be a power of two for the kernel's tree reduction.
size_t shmem_size = threads * sizeof(float);
fused_post_ops_optimized_kernel<<<total_blocks, threads, shmem_size, c10::cuda::getCurrentCUDAStream()>>>(
x.data_ptr<float>(), output.data_ptr<float>(),
N, C, D, H, W, scale, pool_k);
C10_CUDA_KERNEL_LAUNCH_CHECK(); // surface launch-configuration errors immediately
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused post-transposed conv ops: scale, 3D max pool, global avg pool, clamp (optimized with unrolling)");
}
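The kernel's per-block work split can be replayed sequentially on the host to check the indexing logic. This covers the linear-to-3D window index conversion, the block-strided assignment of windows to threads, and the halving tree reduction in shared memory. The function names below are hypothetical, and `window_values` stands in for the per-window maxima. The sequential loops only mimic the order of operations and say nothing about GPU performance.

```cpp
#include <cassert>
#include <vector>

// Mirrors the kernel's conversion of a linear window index p into (d_idx, h_idx, w_idx).
void decompose_window(int p, int H_pool, int W_pool, int& d_idx, int& h_idx, int& w_idx) {
    w_idx = p % W_pool;
    const int temp = p / W_pool;
    h_idx = temp % H_pool;
    d_idx = temp / H_pool;
}

// Sequential replay of one block: thread `tid` handles windows tid, tid + threads, ...,
// then the per-thread partial sums are combined by a halving tree reduction.
float simulate_block_average(const std::vector<float>& window_values, int threads) {
    assert(threads > 0 && (threads & (threads - 1)) == 0);  // tree reduction needs a power of two
    assert(!window_values.empty());
    const int num_windows = static_cast<int>(window_values.size());
    std::vector<float> shmem(threads, 0.0f);
    // Phase 1: block-strided accumulation (mirrors the p += blockDim.x loop).
    for (int tid = 0; tid < threads; ++tid)
        for (int p = tid; p < num_windows; p += threads)
            shmem[tid] += window_values[p];
    // Phase 2: tree reduction; on the GPU each step is separated by __syncthreads().
    for (int s = threads / 2; s > 0; s >>= 1)
        for (int tid = 0; tid < s; ++tid)
            shmem[tid] += shmem[tid + s];
    // Phase 3: thread 0 turns the sum into an average and clamps to [0, 1].
    const float avg = shmem[0] / num_windows;
    return avg < 0.0f ? 0.0f : (avg > 1.0f ? 1.0f : avg);
}
```

The power-of-two assertion reflects a real constraint of the kernel's reduction loop. With a non-power-of-two blockDim.x, the halving sequence starting at `s = blockDim.x / 2` would skip some partial sums.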