2.13. 3D Transposed Convolution and Sequential Activations

Computes a series of operations on a 3D input: it first applies a transposed convolution to upsample the data, then calculates the mean along the channel dimension, adds a bias, applies softmax for normalization, applies a tanh activation, and finally scales the output.
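For reference, this chain of operations corresponds to a standard PyTorch module along the following lines. The sketch is a reconstruction from the description above; the parameter names and default values are assumptions, not the original source.

# Hypothetical reference module reconstructed from the description above;
# parameter names and defaults are assumptions.
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 bias_shape=(1, 1, 1, 1, 1), scaling_factor=2.0):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                                 stride=stride, padding=padding)
        self.bias = nn.Parameter(torch.randn(bias_shape))
        self.scaling_factor = scaling_factor

    def forward(self, x):
        x = self.conv_transpose(x)               # upsample: [B, C_out, D_out, H_out, W_out]
        x = torch.mean(x, dim=1, keepdim=True)   # channel mean: [B, 1, D_out, H_out, W_out]
        x = x + self.bias                        # add bias
        x = torch.softmax(x, dim=1)              # softmax over the singleton channel dim
        x = torch.tanh(x)                        # tanh activation
        return x * self.scaling_factor           # final scaling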

Progress: Step 1 of 1
Reward: 100.64x

Optimization Timeline

Step 1 Details

Reasoning

The original model performs several operations:

1. A ConvTranspose3d that transforms the input tensor from shape [batch, in_channels, D_in, H_in, W_in] to [batch, out_channels, D_out, H_out, W_out]. Each output spatial dimension is computed as:
   D_out = (D_in - 1) * stride - 2 * padding + kernel_size
   H_out = (H_in - 1) * stride - 2 * padding + kernel_size
   W_out = (W_in - 1) * stride - 2 * padding + kernel_size
2. A mean pooling over the channel dimension that reduces the tensor to shape [batch, 1, D_out, H_out, W_out].
3. An addition of a bias tensor of shape (1, 1, 1, 1, 1).
4. A softmax over the channel dimension. Since that dimension now has a single element, softmax always yields 1.
5. A tanh activation; tanh(1) is a constant (~0.7616).
6. Scaling by the scaling factor.

Thus, mathematically, the final fused output is a constant tensor: every element equals tanh(1) * scaling_factor. Our custom CUDA operator computes the output shape from the input tensor shape and the ConvTranspose3d parameters (kernel_size, stride, padding), then fills the output tensor with tanhf(1.0f) * scaling_factor. We use c10::cuda::CUDAGuard to ensure device correctness and launch a kernel to fill the buffer efficiently. The Python module loads this CUDA extension with optimization flags and uses it within a CUDAModel class that matches the expected interface.
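As a quick sanity check of this argument, the post-convolution chain can be evaluated eagerly on random data and compared against the closed-form constant. The snippet below is a minimal sketch; the tensor shapes and the scaling factor of 2.0 are arbitrary assumptions.

# Minimal numerical check (assumed shapes/values): mean -> +bias -> softmax(dim=1)
# -> tanh -> scale collapses to tanh(1) * scaling_factor for every element.
import math
import torch

scaling_factor = 2.0
x = torch.randn(2, 16, 8, 8, 8)          # stand-in for the ConvTranspose3d output
bias = torch.randn(1, 1, 1, 1, 1)

y = torch.mean(x, dim=1, keepdim=True)   # [2, 1, 8, 8, 8]
y = y + bias
y = torch.softmax(y, dim=1)              # singleton channel dim -> every element is exactly 1
y = torch.tanh(y) * scaling_factor

expected = math.tanh(1.0) * scaling_factor   # ~1.5232 for scaling_factor = 2.0
assert torch.allclose(y, torch.full_like(y, expected))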

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>

// Kernel to fill an array with a constant value.
__global__ void fill_constant_kernel(float* __restrict__ output, int total, float value) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        output[idx] = value;
    }
}

// Fused forward function that computes the output shape based on a ConvTranspose3d operation.
// For each spatial dimension, using the formula (assuming stride s, padding p, kernel k):
//   out_dim = (in_dim - 1) * s - 2*p + k
// Then, the mean pooling over channels reduces the output to have 1 channel.
// Finally, softmax over a single channel returns 1, and tanh(1) * scaling_factor is applied.
torch::Tensor forward(torch::Tensor x,
                      int kernel_size,
                      int stride,
                      int padding,
                      float scaling_factor) {
    // Ensure we work on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // Basic input validation: expect a 5-D CUDA tensor.
    TORCH_CHECK(x.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 5, "input must have shape [batch, in_channels, D_in, H_in, W_in]");

    // Get input dimensions: x is [batch, in_channels, D_in, H_in, W_in].
    int batch_size = x.size(0);
    int D_in = x.size(2);
    int H_in = x.size(3);
    int W_in = x.size(4);
    
    // Compute output dimensions from the ConvTranspose3d parameters.
    int D_out = (D_in - 1) * stride - 2 * padding + kernel_size;
    int H_out = (H_in - 1) * stride - 2 * padding + kernel_size;
    int W_out = (W_in - 1) * stride - 2 * padding + kernel_size;
    
    // After ConvTranspose3d, the output would have shape [batch, out_channels, D_out, H_out, W_out].
    // Then, taking the mean over the channel dimension yields [batch, 1, D_out, H_out, W_out].
    auto output = torch::empty({batch_size, 1, D_out, H_out, W_out}, x.options());
    
    // The chain of operations reduces to a constant value:
    // After addition and softmax (over the 1 channel) we have 1, then tanh(1) is computed, which is constant.
    // Multiply by scaling_factor gives:
    float constant = tanhf(1.0f) * scaling_factor;
    
    int total = batch_size * 1 * D_out * H_out * W_out;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    // Launch the kernel to fill the output tensor with the constant.
    // Skip the launch for empty tensors to avoid an invalid zero-block configuration.
    if (total > 0) {
        fill_constant_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            output.data_ptr<float>(), total, constant);
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused transposed 3D conv, mean pooling, addition, softmax, tanh, and scaling (optimized to constant)");
}
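
The reasoning above mentions a Python module that loads this extension and wraps it in a CUDAModel class. A minimal sketch of such a wrapper is shown below; the source file name, build flags, and constructor signature are assumptions for illustration.

# Hypothetical loader/wrapper; file name, flags, and class interface are assumptions.
import torch
from torch.utils.cpp_extension import load

fused_op = load(
    name="fused_constant_op",
    sources=["fused_constant_op.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=False,
)

class CUDAModel(torch.nn.Module):
    def __init__(self, kernel_size, stride, padding, scaling_factor):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.scaling_factor = scaling_factor

    def forward(self, x):
        # x: [batch, in_channels, D_in, H_in, W_in], float32, on a CUDA device
        return fused_op.forward(x, self.kernel_size, self.stride,
                                self.padding, self.scaling_factor)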