2.38. 3D Transposed Convolution and Post-Processing
Performs a 3D transposed convolution with configurable stride, padding, and output padding on volumetric data; applies 3D average pooling to downsample the features; clamps the resultant values within a specified range; normalizes the clamped output with softmax along the channel dimension; and finally scales the output by multiplying by a constant.
Progress
Optimization Timeline
Step 2 Details
Reasoning
We can further optimize the fused kernel by processing multiple pooled spatial locations concurrently in a 2D thread block. In the previous version, each CUDA block processed exactly one pooled location with C threads. This is inefficient when the number of pooled locations is large: every block carries a fixed scheduling cost, and a block of only C threads (e.g. 16) fills just half of a 32-thread warp. In our improved design we use a 2D block with dimensions (blockDim.x = C, blockDim.y = T), so that each block processes T pooled positions simultaneously. For each pooled position, each thread (indexed by threadIdx.x) computes the average-pooling result over a pool_kernel³ window for one channel. We then use shared memory arranged as a 2D array of size [T][C] (row = pooled location within the tile, column = channel) to compute the softmax for each pooled location. For each row we perform a serial reduction (C is modest, e.g. 16) to compute the maximum and the sum needed for the softmax. Finally, each element is multiplied by scaling_factor and written to global memory. Processing several locations per block avoids launching one block per spatial location and uses each warp more fully. The gain comes from fewer, fuller blocks (less per-block scheduling overhead and fewer idle warp lanes). The whole computation was already a single kernel launch, so kernel-launch overhead itself is unchanged.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <math.h>
// Fused post-processing kernel for the ConvTranspose3d output:
//   3D average pooling -> clamp -> softmax over channels -> multiply by scaling_factor.
//
// Input : contiguous float tensor [B, C, D_conv, H_conv, W_conv] (the dense NCDHW
//         indexing below assumes this layout).
// Output: contiguous float tensor [B, C, D_pool, H_pool, W_pool].
// Pooling uses a non-overlapping pool_kernel^3 window (stride == pool_kernel), so
// D_pool = D_conv / pool_kernel, etc.; trailing elements that do not fill a whole
// window are ignored (integer division).
//
// Launch contract:
//   blockDim.x >= C (the host launches exactly C): threadIdx.x is the channel index.
//   blockDim.y  = T: each block handles T consecutive pooled positions ("tile rows").
//   gridDim.x  >= ceil(B * D_pool * H_pool * W_pool / T).
//   Dynamic shared memory: at least blockDim.y * C * sizeof(float) bytes.
//
// Every thread of the block must reach each __syncthreads(). Threads whose pooled
// position lies past the end of the tensor (tail block) therefore stay alive and
// skip the work instead of returning early.
extern "C" __global__
void fused_kernel_optimized(const float* __restrict__ input,
                            float* __restrict__ output,
                            int B, int C,
                            int D_conv, int H_conv, int W_conv,
                            int pool_kernel,
                            float clamp_min, float clamp_max,
                            float scaling_factor) {
    const int D_pool = D_conv / pool_kernel;
    const int H_pool = H_conv / pool_kernel;
    const int W_pool = W_conv / pool_kernel;

    // 64-bit index math: B * C * D * H * W can exceed INT_MAX for large volumes.
    const int64_t plane_pool   = static_cast<int64_t>(H_pool) * W_pool;
    const int64_t spatial_size = static_cast<int64_t>(D_pool) * plane_pool;
    const int64_t total_pooled = static_cast<int64_t>(B) * spatial_size;

    // Global pooled position handled by this thread's tile row, and its channel.
    const int64_t tile_idx = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
    const int c = static_cast<int>(threadIdx.x);
    const bool active = (tile_idx < total_pooled) && (c < C);

    // One shared-memory row of C floats per pooled position in the tile.
    extern __shared__ float sdata[];
    float* const row = sdata + static_cast<size_t>(threadIdx.y) * C;

    float clamped = 0.0f;
    int64_t out_index = 0;
    if (active) {
        // Decode tile_idx into (b, d_pool, h_pool, w_pool).
        const int64_t b      = tile_idx / spatial_size;
        const int64_t s      = tile_idx - b * spatial_size;  // offset inside one [D,H,W]_pool volume
        const int64_t d_pool = s / plane_pool;
        const int64_t r      = s - d_pool * plane_pool;
        const int64_t h_pool = r / W_pool;
        const int64_t w_pool = r - h_pool * W_pool;

        // First element of this channel's pooling window in the conv output.
        const int64_t conv_hw  = static_cast<int64_t>(H_conv) * W_conv;
        const int64_t conv_dhw = static_cast<int64_t>(D_conv) * conv_hw;
        const float* const window = input
            + (b * C + c) * conv_dhw
            + d_pool * pool_kernel * conv_hw
            + h_pool * pool_kernel * W_conv
            + w_pool * pool_kernel;

        // Average over the pool_kernel^3 window (summed in d, h, w order).
        float sum = 0.0f;
        for (int pd = 0; pd < pool_kernel; ++pd) {
            const float* const plane = window + pd * conv_hw;
            for (int ph = 0; ph < pool_kernel; ++ph) {
                const float* const line = plane + static_cast<int64_t>(ph) * W_conv;
                for (int pw = 0; pw < pool_kernel; ++pw) {
                    sum += line[pw];
                }
            }
        }
        const int window_size = pool_kernel * pool_kernel * pool_kernel;
        const float avg = sum / window_size;

        clamped = fminf(fmaxf(avg, clamp_min), clamp_max);
        row[c] = clamped;
        out_index = (b * C + c) * spatial_size + s;
    }
    __syncthreads();  // the row of clamped values is complete

    // Every thread scans its own row for the max (same-address reads broadcast),
    // so no thread has to publish the result through shared memory.
    float max_val = 0.0f;
    if (active) {
        max_val = row[0];
        for (int i = 1; i < C; ++i) {
            const float v = row[i];
            if (v > max_val) {
                max_val = v;
            }
        }
    }
    __syncthreads();  // all max scans finish before the row is overwritten below

    // Numerically stable softmax numerator, stored back into the same slot.
    float exp_val = 0.0f;
    if (active) {
        exp_val = expf(clamped - max_val);
        row[c] = exp_val;
    }
    __syncthreads();  // the row of exponentials is complete

    if (active) {
        float sum_exp = 0.0f;
        for (int i = 0; i < C; ++i) {
            sum_exp += row[i];
        }
        output[out_index] = (exp_val / sum_exp) * scaling_factor;
    }
}
// Host entry point: runs the fused pool -> clamp -> channel-softmax -> scale kernel
// on the output of ConvTranspose3d.
//
// input          : CUDA float32 tensor of shape [B, C, D_conv, H_conv, W_conv]. It is
//                  made contiguous here because the kernel assumes dense NCDHW indexing.
// pool_kernel    : cubic average-pool window size (stride == window, no padding).
// clamp_min/max  : clamp range applied to the pooled averages.
// scaling_factor : multiplier applied after the channel softmax.
// Returns a new tensor [B, C, D_conv/pool_kernel, H_conv/pool_kernel, W_conv/pool_kernel]
// with input's dtype and device. Invalid arguments raise c10::Error via TORCH_CHECK.
torch::Tensor fused_forward(torch::Tensor input,
                            int pool_kernel,
                            float clamp_min,
                            float clamp_max,
                            float scaling_factor) {
    TORCH_CHECK(input.is_cuda(), "fused_forward: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 5,
                "fused_forward: expected a 5-D input [B, C, D, H, W], got ", input.dim(), " dims");
    TORCH_CHECK(input.scalar_type() == torch::kFloat,
                "fused_forward: expected a float32 input, got ", input.scalar_type());
    TORCH_CHECK(pool_kernel > 0,
                "fused_forward: pool_kernel must be positive, got ", pool_kernel);

    // Allocate and launch on the input's device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes memory as dense NCDHW; a strided view (permute, slice, ...)
    // would otherwise be read with the wrong strides.
    const torch::Tensor x = input.contiguous();

    const int B      = static_cast<int>(x.size(0));
    const int C      = static_cast<int>(x.size(1));
    const int D_conv = static_cast<int>(x.size(2));
    const int H_conv = static_cast<int>(x.size(3));
    const int W_conv = static_cast<int>(x.size(4));

    // Non-overlapping pooling: partial windows at the tail are dropped.
    const int D_pool = D_conv / pool_kernel;
    const int H_pool = H_conv / pool_kernel;
    const int W_pool = W_conv / pool_kernel;

    auto output = torch::empty({B, C, D_pool, H_pool, W_pool}, x.options());

    // Nothing to compute (empty input or window larger than the volume). A 0-block
    // grid would be an invalid launch configuration.
    if (output.numel() == 0) {
        return output;
    }
    TORCH_CHECK(C <= 1024,
                "fused_forward: channel count ", C,
                " exceeds the 1024-threads-per-block limit (one thread per channel)");

    // Pack several pooled positions per block, targeting ~256 threads (8 warps) so
    // small channel counts do not produce tiny blocks. C * tile_size never exceeds
    // max(C, 256) <= 1024.
    constexpr int kTargetThreadsPerBlock = 256;
    const int tile_size = C >= kTargetThreadsPerBlock ? 1 : kTargetThreadsPerBlock / C;

    const int64_t total_pooled = static_cast<int64_t>(B) * D_pool * H_pool * W_pool;
    const int64_t blocks = (total_pooled + tile_size - 1) / tile_size;
    constexpr int64_t kMaxGridDimX = 2147483647;  // 2^31 - 1
    TORCH_CHECK(blocks <= kMaxGridDimX,
                "fused_forward: too many pooled positions (", total_pooled, ") for a 1-D grid");

    const dim3 block(static_cast<unsigned int>(C), static_cast<unsigned int>(tile_size));
    // One row of C floats per pooled position in the tile.
    const size_t shared_mem_size = static_cast<size_t>(tile_size) * C * sizeof(float);

    fused_kernel_optimized<<<static_cast<unsigned int>(blocks), block, shared_mem_size,
                             c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C, D_conv, H_conv, W_conv,
        pool_kernel, clamp_min, clamp_max, scaling_factor
    );
    // Surface launch-configuration errors here instead of at some later CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Python binding: the extension exposes a single callable, `forward`,
    // which dispatches to the host wrapper fused_forward.
    m.def("forward",
          &fused_forward,
          "Fused transp_conv pool clamp softmax and scale (optimized CUDA)");
}