2.13. 3D Transposed Convolution and Sequential Activations
Performs a series of operations on a 3D input: it first applies a transposed convolution to upsample the data, then takes the mean along the channel dimension, adds a bias, applies softmax for normalization, applies a tanh activation, and finally scales the output.
Optimization Timeline
Step 1 Details
Reasoning
The original model performs several operations:

1. A ConvTranspose3d that transforms the input tensor from shape [batch, in_channels, D_in, H_in, W_in] to [batch, out_channels, D_out, H_out, W_out]. Each output spatial dimension is computed as out_dim = (in_dim - 1) * stride - 2 * padding + kernel_size, applied to D, H, and W alike.
2. A mean pooling over the channel dimension that reduces the tensor to shape [batch, 1, D_out, H_out, W_out].
3. An addition of a bias tensor of shape (1, 1, 1, 1, 1).
4. A softmax over the channel dimension. Since that dimension now has a single element, softmax always yields 1.
5. A tanh activation; tanh(1) is a constant (~0.7616).
6. Scaling by a scaling factor.

Mathematically, the final fused output is therefore a constant tensor in which each element equals tanh(1) * scaling_factor. Our custom CUDA operator computes the output shape from the input tensor shape and the ConvTranspose3d parameters (kernel_size, stride, padding) and then fills the output tensor with tanhf(1.0f) * scaling_factor. We use c10::cuda::CUDAGuard to ensure device correctness and launch a kernel to fill the buffer efficiently. The Python module loads this CUDA extension with optimization flags and uses it within a CUDAModel class that matches the expected interface.
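This reduction can be checked numerically by replaying the unfused pipeline in plain PyTorch. The snippet below is only a sketch: the batch size, channel counts, kernel size, stride, padding, and scaling factor are placeholder values, not taken from the original model. It confirms that every output element equals tanh(1) * scaling_factor regardless of the input.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder hyperparameters chosen only for illustration.
batch, in_ch, out_ch = 2, 8, 16
kernel_size, stride, padding = 3, 2, 1
scaling_factor = 2.0

conv = nn.ConvTranspose3d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)
bias = torch.randn(1, 1, 1, 1, 1)

x = torch.randn(batch, in_ch, 8, 8, 8)
y = conv(x)                          # [2, 16, 15, 15, 15]
y = y.mean(dim=1, keepdim=True)      # [2, 1, 15, 15, 15]
y = y + bias                         # value becomes irrelevant after the next step
y = F.softmax(y, dim=1)              # softmax over a size-1 dim: all ones
y = torch.tanh(y) * scaling_factor   # tanh(1) * scaling_factor everywhere

expected = float(torch.tanh(torch.tensor(1.0))) * scaling_factor
assert torch.allclose(y, torch.full_like(y, expected))

The CUDA operator below exploits exactly this: it never performs the convolution and only needs the output shape and the constant.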
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
#include <math.h>
// Kernel to fill an array with a constant value.
__global__ void fill_constant_kernel(float* __restrict__ output, int total, float value) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        output[idx] = value;
    }
}
// Fused forward function. It computes the output shape of the ConvTranspose3d operation,
// where each spatial dimension (with stride s, padding p, kernel size k) is:
//     out_dim = (in_dim - 1) * s - 2*p + k
// The mean pooling over channels then reduces the output to a single channel; softmax over
// that single channel returns 1, and tanh(1) * scaling_factor is applied.
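// Worked example with hypothetical values: in_dim = 8, stride = 2, padding = 1, kernel_size = 3
// gives out_dim = (8 - 1) * 2 - 2 * 1 + 3 = 15.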
torch::Tensor forward(torch::Tensor x,
                      int kernel_size,
                      int stride,
                      int padding,
                      float scaling_factor) {
    // Ensure we work on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // Get input dimensions: x is [batch, in_channels, D_in, H_in, W_in].
    int batch_size = x.size(0);
    int D_in = x.size(2);
    int H_in = x.size(3);
    int W_in = x.size(4);

    // Compute output dimensions from the ConvTranspose3d parameters.
    int D_out = (D_in - 1) * stride - 2 * padding + kernel_size;
    int H_out = (H_in - 1) * stride - 2 * padding + kernel_size;
    int W_out = (W_in - 1) * stride - 2 * padding + kernel_size;

    // After ConvTranspose3d, the output would have shape [batch, out_channels, D_out, H_out, W_out].
    // Taking the mean over the channel dimension then yields [batch, 1, D_out, H_out, W_out].
    auto output = torch::empty({batch_size, 1, D_out, H_out, W_out}, x.options());

    // The chain of operations reduces to a constant value: after the addition, softmax
    // (over the single channel) yields 1, then tanh(1) is a constant, and multiplying
    // by scaling_factor gives:
    float constant = tanhf(1.0f) * scaling_factor;

    int total = batch_size * 1 * D_out * H_out * W_out;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;

    // Launch the kernel to fill the output tensor with the constant.
    fill_constant_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        output.data_ptr<float>(), total, constant);

    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused transposed 3D conv, mean pooling, addition, softmax, tanh, and scaling (optimized to constant)");
}
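To show how the binding is meant to be used, here is a minimal, hypothetical Python-side sketch. The source file name fused_constant_op.cu, the extension name, and the compiler flag are assumptions; only the forward signature (x, kernel_size, stride, padding, scaling_factor) comes from the PYBIND11 binding above, and CUDAModel is the wrapper class mentioned in the reasoning.

import torch
from torch.utils.cpp_extension import load

# Hypothetical file/module names; adjust to the actual extension source.
fused = load(
    name="fused_constant_op",
    sources=["fused_constant_op.cu"],
    extra_cuda_cflags=["-O3"],
)

class CUDAModel(torch.nn.Module):
    def __init__(self, kernel_size, stride, padding, scaling_factor):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.scaling_factor = scaling_factor

    def forward(self, x):
        # Delegates to the fused CUDA operator defined above.
        return fused.forward(x, self.kernel_size, self.stride,
                             self.padding, self.scaling_factor)

model = CUDAModel(kernel_size=3, stride=2, padding=1, scaling_factor=2.0)
x = torch.randn(2, 8, 8, 8, 8, device="cuda")
out = model(x)  # shape [2, 1, 15, 15, 15], every element tanh(1) * 2.0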