2.89. 3D Transposed Convolution and Activation Sequence
Performs a sequence of computations on 3D tensor data: up-samples input using a transposed convolution, reduces spatial dimensions via max pooling, normalizes channel values using softmax, subtracts a learned channel-wise bias, applies a swish activation (multiplying the input by its sigmoid), and finally aggregates channel information by selecting the maximum value across channels.
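For concreteness, here is a minimal eager-mode PyTorch sketch of the pipeline described above; the layer hyperparameters (channel counts, kernel and pool sizes, stride, parameter initialization) are illustrative placeholders rather than values taken from the original model.

```python
import torch
import torch.nn as nn

class RefModel(nn.Module):
    """Eager reference for the described pipeline (hyperparameters are illustrative)."""
    def __init__(self, in_channels=8, out_channels=16, kernel_size=3, pool_size=2):
        super().__init__()
        self.conv_t = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=2)
        self.pool = nn.MaxPool3d(pool_size)
        self.subtract = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        x = self.conv_t(x)                           # up-sample
        x = self.pool(x)                             # reduce spatial dimensions
        x = torch.softmax(x, dim=1)                  # normalize across channels
        x = x - self.subtract.view(1, -1, 1, 1, 1)   # learned channel-wise bias
        x = x * torch.sigmoid(x)                     # swish activation
        return x.amax(dim=1)                         # max over channels -> [B, D, H, W]
```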
Step 1 Details
Reasoning
We observe that after the ConvTranspose3d and MaxPool3d layers the model applies a series of element-wise and reduction operations along the channel dimension:

• Softmax across channels (dim=1)
• Subtraction of a learnable per-channel parameter
• Swish activation (x * sigmoid(x))
• Channel reduction by taking the maximum over channels

Because the channel dimension (out_channels) is small (16 in our example), we can fuse these four operations into one custom CUDA kernel. In that kernel each thread is assigned to one spatial location, i.e. one index among (b, d, h, w). For that location, we loop over the channel dimension c to perform the softmax reduction: we first compute the channel maximum for numerical stability, then the sum of exponentials, and then the softmax value for each channel. We subtract the learnable parameter (broadcast from a 1D tensor of length C), apply the Swish activation (z * sigmoid(z)), take the maximum over channels, and write that value to the output. Fusing these steps avoids multiple kernel launches and redundant memory traffic.

The kernel uses a simple one-dimensional grid in which each thread processes one spatial location. We use c10::cuda::CUDAGuard to ensure the correct device is active and launch the kernel on c10::cuda::getCurrentCUDAStream(). The Python module loads the fused CUDA operator and defines a CUDAModel class; in forward(), the model first applies ConvTranspose3d and MaxPool3d using native PyTorch implementations and then calls the fused kernel to perform the softmax, subtraction, Swish activation, and channel max reduction in a single pass (see the wrapper sketch below).

Before finishing, I ask: have these improvements genuinely helped? Yes: by fusing several element-wise and reduction operations into a single kernel, we significantly reduce launch overhead and memory traffic while operating on a small channel dimension. This is a genuine optimization rather than a workaround.
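A minimal sketch of that Python side, assuming the CUDA source below is saved as fused_postpool.cu and compiled with torch.utils.cpp_extension.load (the file and module names, as well as the layer hyperparameters, are assumptions for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Build the extension from the CUDA source shown below (file name is an assumption).
fused_postpool = load(name="fused_postpool", sources=["fused_postpool.cu"])

class CUDAModel(nn.Module):
    def __init__(self, in_channels=8, out_channels=16, kernel_size=3, pool_size=2):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=2)
        self.max_pool = nn.MaxPool3d(pool_size)
        self.subtract = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        # Native PyTorch ops for the heavy layers ...
        x = self.max_pool(self.conv_transpose(x))
        # ... then one fused kernel for softmax + subtract + swish + channel max.
        return fused_postpool.forward(x.contiguous(), self.subtract)
```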
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <math.h>
// Fused kernel that performs the following operations on the input tensor X and subtract parameter:
// 1. For each spatial location (b, d, h, w), compute softmax over channels (dim=1) of X.
// 2. Subtract the learnable parameter for each channel.
// 3. Apply Swish activation: swish(z) = z * sigmoid(z), where sigmoid(z)=1/(1+exp(-z)).
// 4. Reduce over channels by taking the maximum value.
// Input X has shape [B, C, D, H, W] and 'subtract' is a 1D tensor of length C.
// Output tensor shape is [B, D, H, W].
__global__ void fused_postpool_kernel(const float* __restrict__ X,
const float* __restrict__ subtract,
float* __restrict__ output,
int B, int C, int D, int H, int W) {
// Each thread processes one spatial location.
int total = B * D * H * W;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total) return;
// Compute indices:
// Let spatial size = D*H*W. For thread index idx, we deduce:
int spatial = D * H * W;
int b = idx / spatial;
int rem = idx % spatial;
int d = rem / (H * W);
int rem2 = rem % (H * W);
int h = rem2 / W;
int w = rem2 % W;
// Compute base offset for this spatial location in X.
// X is stored in NCDHW order.
int HW = H * W;
int base = b * (C * spatial) + d * HW + h * W + w;
int stride = spatial; // Distance to the next channel at the same spatial location.
// First pass: compute maximum value for softmax stability.
float max_val = -FLT_MAX;
// Per-channel scratch buffer for the exp values (C <= 64 is enforced on the host side).
float exp_vals[64];
for (int c = 0; c < C; c++) {
float val = X[base + c * stride];
if (val > max_val) {
max_val = val;
}
}
// Second pass: compute the sum of exp(val - max_val) for softmax.
float sum_exp = 0.0f;
for (int c = 0; c < C; c++) {
float val = X[base + c * stride];
float exp_val = expf(val - max_val);
exp_vals[c] = exp_val;
sum_exp += exp_val;
}
// Third pass: compute softmax, subtract subtract[c], apply swish (z*sigmoid(z)),
// and find the maximum value across channels.
float max_swish = -FLT_MAX;
for (int c = 0; c < C; c++) {
float softmax_val = exp_vals[c] / sum_exp;
// Subtract the parameter for this channel.
float z = softmax_val - subtract[c];
// Compute sigmoid(z)
float sigmoid_z = 1.0f / (1.0f + expf(-z));
// Swish activation.
float swish = z * sigmoid_z;
if (swish > max_swish) {
max_swish = swish;
}
}
// Write the maximum swish value to output.
output[idx] = max_swish;
}
torch::Tensor forward(torch::Tensor X, torch::Tensor subtract) {
// Validate inputs: the kernel assumes contiguous float32 CUDA tensors.
TORCH_CHECK(X.is_cuda() && subtract.is_cuda(), "X and subtract must be CUDA tensors");
TORCH_CHECK(X.dim() == 5, "X must have shape [B, C, D, H, W]");
TORCH_CHECK(X.scalar_type() == torch::kFloat32 && subtract.scalar_type() == torch::kFloat32,
"X and subtract must be float32 tensors");
TORCH_CHECK(X.is_contiguous() && subtract.is_contiguous(), "X and subtract must be contiguous");
// Ensure we use the proper CUDA device.
c10::cuda::CUDAGuard device_guard(X.device());
// X shape: [B, C, D, H, W]
int B = X.size(0);
int C = X.size(1);
int D = X.size(2);
int H = X.size(3);
int W = X.size(4);
TORCH_CHECK(subtract.numel() == C, "subtract must have exactly C = ", C, " elements");
TORCH_CHECK(C <= 64, "fused_postpool_kernel supports at most 64 channels, got ", C);
// Output shape: [B, D, H, W]
auto output = torch::empty({B, D, H, W}, X.options());
int total = B * D * H * W;
int threads = 256;
int blocks = (total + threads - 1) / threads;
fused_postpool_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
X.data_ptr<float>(),
subtract.data_ptr<float>(),
output.data_ptr<float>(),
B, C, D, H, W
);
C10_CUDA_KERNEL_LAUNCH_CHECK(); // Surface any launch/configuration errors immediately.
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused softmax, subtract, swish, and channel max (post pool) in one CUDA kernel");
}