2.36. Transposed Convolution with Aggregated Activation

Performs a 2D transposed convolution to upsample input feature maps using configurable stride, padding, and output padding, then computes the minimum across the channel dimension followed by summing over the height dimension. Finally, it applies the Gaussian Error Linear Unit (GELU) activation using its tanh-based approximation and adds a bias to the activated output.
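
For reference, here is a minimal ATen/C++ sketch of the same pipeline. The function name, parameter list, and shapes are illustrative only and are not taken from the original model code; the tanh-based GELU is written out explicitly.

#include <torch/torch.h>

// Illustrative reference for the described pipeline (not the original model).
torch::Tensor reference_pipeline(const torch::Tensor& x,
                                 const torch::Tensor& weight,
                                 const torch::Tensor& conv_bias,
                                 const torch::Tensor& extra_bias,
                                 int64_t stride, int64_t padding, int64_t output_padding) {
    // Transposed convolution upsamples x to [N, out_channels, H, W].
    auto y = torch::conv_transpose2d(x, weight, conv_bias, stride, padding, output_padding);
    // Minimum over the channel dimension -> [N, 1, H, W].
    auto z = std::get<0>(y.min(/*dim=*/1, /*keepdim=*/true));
    // Sum over the height dimension -> [N, 1, 1, W].
    auto s = z.sum(/*dim=*/2, /*keepdim=*/true);
    // Tanh-based GELU approximation.
    auto a = 0.5 * s * (1.0 + torch::tanh(0.7978845608028654 *
                                          (s + 0.044715 * s * s * s)));
    // Adding the [out_channels, 1, 1] bias broadcasts to [N, out_channels, 1, W].
    return a + extra_bias;
}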

Progress: Step 1 of 1
Reward: 1.09x

Optimization Timeline

Step 1 Details

Reasoning

We need to fuse several operations into one custom kernel. The original model performs:

1. A ConvTranspose2d on input x (shape [N, in_channels, H_in, W_in]) producing y of shape [N, out_channels, H, W], where H and W are computed from the conv-transpose parameters.
2. A minimum over the channel dimension, yielding z = min(y, dim=1) with shape [N, 1, H, W].
3. A sum over the height dimension, reducing z to s with shape [N, 1, 1, W].
4. A GELU activation applied elementwise to s.
5. An addition of a bias (of shape [out_channels, 1, 1]). When added, broadcasting replicates the computed scalar across channels, so the final output has shape [N, out_channels, 1, W].

We can fuse steps 2–5 into one CUDA kernel. Our plan is to take the conv-transpose output (computed by PyTorch's fast built-in ConvTranspose2d) as input to our fused kernel along with the bias. Then, for each sample n and each output width position w, each kernel thread will:

• Loop over the height dimension (h = 0, ..., H-1), computing for each h the minimum value over the channel dimension (c = 0, ..., C-1) at that (n, h, w) location.
• Sum these minimum values over h.
• Apply the GELU activation to the sum.
• Write out, for each channel c, the value gelu(sum) + bias[c] into the output tensor, so that the final output shape is [N, out_channels, 1, W].

This fused operator then returns the final output. We implement the kernel with one thread per (n, w) pair (total threads = N*W) and use simple loops over h and c within each thread. The GELU function is implemented in device code. The Python module loads the extension with the -O3 and --use_fast_math flags, and the CUDAModel class uses PyTorch's conv_transpose for the first step and our fused kernel for the rest. Before finishing, we check that the output shape is correct: our kernel produces an output of shape [N, C, 1, W], which matches the broadcasting behavior of the original model. The fusion genuinely helps because it avoids launching a separate kernel for each of the min, sum, activation, and addition steps.
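The broadcasting claim in step 5 is easy to confirm in isolation. The standalone snippet below is illustrative only (it is not part of the extension); it shows that a [N, 1, 1, W] tensor plus a [out_channels, 1, 1] bias broadcasts to [N, out_channels, 1, W].

#include <torch/torch.h>
#include <iostream>

int main() {
    // s stands in for gelu(sum): one value per (n, w), shape [N, 1, 1, W] with N=2, W=5.
    auto s    = torch::zeros({2, 1, 1, 5});
    // bias has shape [out_channels, 1, 1] with out_channels = 3.
    auto bias = torch::zeros({3, 1, 1});
    // Broadcasting replicates s across channels; this prints [2, 3, 1, 5].
    std::cout << (s + bias).sizes() << std::endl;
    return 0;
}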

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>

// Device function: approximate GELU using tanh variant.
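// i.e. gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).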
__device__ float gelu(float x) {
    const float kAlpha = 0.044715f;
    const float kBeta = 0.7978845608028654f;  // sqrt(2/pi)
    return 0.5f * x * (1.0f + tanhf(kBeta * (x + kAlpha * x * x * x)));
}

// Fused kernel: 
// Input: conv_transpose output tensor "input" of shape [N, C, H, W]
// Bias: given as a float array of length C
// Output: tensor of shape [N, C, 1, W]
//
// For each (n, w) pair, the kernel computes:
//    s = sum_{h=0}^{H-1} ( min_{c in 0..C-1} input[n, c, h, w] )
//    a = gelu(s)
// Then, for each channel c, output[n, c, 0, w] = a + bias[c]
__global__ void fused_kernel(const float* __restrict__ input,
                             const float* __restrict__ bias,
                             float* __restrict__ output,
                             int N, int C, int H, int W) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * W; // one thread per (n, w) pair
    if (idx < total) {
        int n = idx / W;
        int w = idx % W;
        float sum = 0.0f;
        // Loop over height dimension.
        for (int h = 0; h < H; h++) {
            float min_val = FLT_MAX;
            // Loop over channel dimension to compute minimum.
            for (int c = 0; c < C; c++) {
                // Index for input[n, c, h, w]:
                int input_index = ((n * C + c) * H + h) * W + w;
                float val = input[input_index];
                if (val < min_val) {
                    min_val = val;
                }
            }
            sum += min_val;
        }
        // Apply GELU to the summed value.
        float activated = gelu(sum);
        // Write the result to each channel of the output for this (n, w) pair.
        // Output tensor shape is [N, C, 1, W], so index = n * C * W + c * W + w.
        for (int c = 0; c < C; c++) {
            int out_index = (n * C + c) * W + w;  // since the 3rd dimension is 1, we omit it.
            output[out_index] = activated + bias[c];
        }
    }
}

torch::Tensor forward(torch::Tensor input, torch::Tensor bias) {
    // input is expected to be the output of conv_transpose with shape [N, C, H, W].
    TORCH_CHECK(input.is_cuda() && bias.is_cuda(), "input and bias must be CUDA tensors");
    TORCH_CHECK(input.dim() == 4, "input must have shape [N, C, H, W]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    c10::cuda::CUDAGuard device_guard(input.device());
    
    // The kernel indexes raw pointers, so both tensors must be contiguous.
    auto input_c = input.contiguous();
    auto bias_c = bias.contiguous();
    
    int N = input_c.size(0);
    int C = input_c.size(1);
    int H = input_c.size(2);
    int W = input_c.size(3);
    
    // Allocate output tensor with shape [N, C, 1, W].
    auto output = torch::empty({N, C, 1, W}, input_c.options());
    
    // Total number of (n, w) pairs; one thread handles one pair.
    int total = N * W;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    fused_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input_c.data_ptr<float>(),
        bias_c.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, H, W
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused ConvTranspose2d min, sum, GELU, and addition");
}
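
One way to sanity-check the fused path is to compare it against an eager ATen reference. The driver below is a sketch under a few assumptions: the extension source above is compiled and linked into the same binary (together with libtorch and the Python development libraries, because of the pybind11 module definition), a CUDA device is available, and the shapes and tolerances are arbitrary.

#include <torch/torch.h>

// Defined in the extension source above.
torch::Tensor forward(torch::Tensor input, torch::Tensor bias);

int main() {
    torch::manual_seed(0);
    // Stand-in for a conv_transpose output of shape [N, C, H, W].
    auto y    = torch::randn({2, 4, 8, 8}, torch::kCUDA);
    auto bias = torch::randn({4, 1, 1},   torch::kCUDA);

    auto fused = forward(y, bias);

    // Eager reference: min over channels, sum over height, tanh-GELU, bias add.
    auto z = std::get<0>(y.min(/*dim=*/1, /*keepdim=*/true));   // [N, 1, H, W]
    auto s = z.sum(/*dim=*/2, /*keepdim=*/true);                 // [N, 1, 1, W]
    auto a = 0.5 * s * (1.0 + torch::tanh(0.7978845608028654 *
                                          (s + 0.044715 * s * s * s)));
    auto ref = a + bias;                                         // [N, C, 1, W]

    TORCH_CHECK(torch::allclose(fused, ref, /*rtol=*/1e-4, /*atol=*/1e-5),
                "fused kernel does not match the eager reference");
    return 0;
}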