2.91. Transposed Convolution with Softmax and Sigmoid Activation

Performs a series of operations on an input tensor: a transposed convolution for upsampling, a softmax to normalize the channel-wise outputs, addition of a bias term, scaling by a constant factor, and finally a sigmoid activation to produce the final output.
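
A minimal PyTorch sketch of this unfused pipeline, for reference (the hyperparameters and tensor sizes below are illustrative, not taken from the original model):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical hyperparameters, chosen only to illustrate the op chain.
conv_transpose = nn.ConvTranspose2d(32, 64, kernel_size=4, stride=2, padding=1)
bias = torch.zeros(64, 1, 1)     # per-channel bias, broadcast over spatial dims
scaling_factor = 2.0

x = torch.randn(8, 32, 16, 16)
y = conv_transpose(x)            # upsample: [8, 32, 16, 16] -> [8, 64, 32, 32]
y = F.softmax(y, dim=1)          # normalize across channels at each pixel
y = (y + bias) * scaling_factor  # add bias, then scale
y = torch.sigmoid(y)             # final activation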

Progress

Step 1 of 1
Reward
1.20x

Optimization Timeline

Step 1 Details

Reasoning

In this problem we want to fuse the operations that follow a transposed convolution. The original sequence is:

1. A transposed convolution (ConvTranspose2d) whose output spatial size, for both height and width, is
   out_dim = (in_dim - 1) * stride - 2 * padding + kernel_size + output_padding,
   producing a tensor of shape [B, C, H_out, W_out].
2. A softmax along the channel dimension: for each spatial location (b, h, w), the softmax is computed over the C channel values.
3. Addition of a bias term of shape [C] (flattened from (C, 1, 1)), broadcast to every spatial location.
4. Scaling by a constant scaling_factor.
5. A sigmoid activation.

Fusing these means that once the transposed convolution (which we leave untouched) has been computed, a single custom CUDA kernel:

- takes the conv_transpose output of shape [B, C, H_out, W_out];
- for each spatial location (b, h, w), computes the softmax along the channel dimension, which requires a reduction for the maximum followed by a reduction for the sum of exponentials;
- then adds the per-channel bias, scales the result, and applies the sigmoid.

Since the number of channels (out_channels) is typically small (e.g., 64), we launch one block per spatial location with blockDim.x equal to the number of channels. The final output has the same shape as the conv_transpose output. Because the element-wise operations (softmax, bias addition, scaling, and sigmoid) are fused into one kernel, we dramatically reduce memory traffic. We use c10::cuda::CUDAGuard and the stream returned by c10::cuda::getCurrentCUDAStream() to ensure the proper device context. The Python module loads the custom CUDA extension using torch.utils.cpp_extension.load with the -O3 and --use_fast_math extra flags.

In the fused kernel (implemented in the CUDA code below), for each (b, h, w) we:

- load the C channel values into shared memory;
- compute the maximum over channels via a parallel reduction;
- compute the sum of exp(value - max) over channels;
- for each channel, compute softmax = exp(value - max) / sum_exp, add bias[c], multiply by scaling_factor, and apply the sigmoid.

All of these steps are fused and performed per output pixel (spatial location). The CUDAModel in Python first performs conv_transpose (using nn.ConvTranspose2d) and then calls the fused CUDA operator with the conv_transpose output, the bias reshaped to 1D, and the scaling factor.
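
A minimal sketch of that Python side, assuming the CUDA source below is saved as fused_postprocess.cu and that the ConvTranspose2d hyperparameters are passed in by the caller (file and argument names here are assumptions, not taken from the original code):

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Build the extension with the flags mentioned above (-O3, --use_fast_math).
fused = load(
    name="fused_postprocess",
    sources=["fused_postprocess.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=False,
)

class CUDAModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, output_padding, scaling_factor):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, output_padding=output_padding)
        self.bias = nn.Parameter(torch.zeros(out_channels, 1, 1))
        self.scaling_factor = scaling_factor

    def forward(self, x):
        x = self.conv_transpose(x)  # ConvTranspose2d left untouched
        # Fused softmax(dim=1) + bias + scale + sigmoid; bias is reshaped to [C].
        return fused.forward(x.contiguous(),
                             self.bias.view(-1).contiguous(),
                             self.scaling_factor)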

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>
#include <math.h>

// Fused kernel for softmax (dim=1), addition of bias, scaling, and sigmoid.
// The input 'in' has shape [B, C, H, W].
// We launch one block per (b, h, w) pixel location. Each block has blockDim.x = C.
__global__ void fused_postprocess_kernel(const float* __restrict__ in,
                                         float* __restrict__ out,
                                         const float* __restrict__ bias,
                                         float scaling_factor,
                                         int B, int C, int H, int W) {
    // Calculate the index for the pixel location.
    // Each block is assigned one pixel location.
    int pixel_idx = blockIdx.x; // pixel index in the flattened B*H*W
    int thread_c = threadIdx.x; // channel index (0 <= thread_c < C)
    
    // Compute b, h, w from pixel_idx: pixel_idx = b * (H*W) + (h * W + w)
    int HW = H * W;
    int b = pixel_idx / HW;
    int rem = pixel_idx % HW;
    int h = rem / W;
    int w = rem % W;
    
    // Compute index for the current element: index = ((b * C + c) * H + h) * W + w.
    int index = ((b * C + thread_c) * H + h) * W + w;
    
    // Load the input value for this channel.
    float val = in[index];
    
    // Allocate dynamic shared memory for reduction, size (C) floats.
    extern __shared__ float sdata[];
    sdata[thread_c] = val;
    __syncthreads();
    
    // Step 1: Compute the maximum value in the channel vector for numerical stability.
    // Use parallel reduction.
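    // Note: this tree reduction (and the sum reduction below) assumes blockDim.x,
    // i.e. the channel count C, is a power of two.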
    for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
        if (thread_c < offset) {
            sdata[thread_c] = fmaxf(sdata[thread_c], sdata[thread_c + offset]);
        }
        __syncthreads();
    }
    float max_val = sdata[0];
    __syncthreads();
    
    // Step 2: Compute exponentials and store back to shared memory.
    float exp_val = expf(val - max_val);
    sdata[thread_c] = exp_val;
    __syncthreads();
    
    // Step 3: Compute sum of exponentials using reduction.
    for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
        if (thread_c < offset) {
            sdata[thread_c] += sdata[thread_c + offset];
        }
        __syncthreads();
    }
    float sum_exp = sdata[0];
    __syncthreads();
    
    // Step 4: Compute softmax value.
    float softmax_val = exp_val / sum_exp;
    
    // Step 5: Add bias (bias is indexed by channel).
    float result = softmax_val + bias[thread_c];
    
    // Step 6: Scale.
    result *= scaling_factor;
    
    // Step 7: Apply sigmoid activation.
    result = 1.0f / (1.0f + expf(-result));
    
    // Write the result to output.
    out[index] = result;
}

torch::Tensor forward(torch::Tensor in, torch::Tensor bias, float scaling_factor) {
    // in: [B, C, H, W] (contiguous, float32); bias: [C].
    TORCH_CHECK(in.is_cuda() && bias.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(in.is_contiguous() && bias.is_contiguous(), "inputs must be contiguous");
    c10::cuda::CUDAGuard device_guard(in.device());
    
    int B = in.size(0);
    int C = in.size(1);
    int H = in.size(2);
    int W = in.size(3);
    
    auto output = torch::empty_like(in);
    
    // Set grid: one block per pixel location (B * H * W); blockDim.x = C.
    // C is assumed to be small (e.g., 64) and at most 1024, the maximum block size.
    int grid = B * H * W;
    int threads = C;
    size_t shared_mem = threads * sizeof(float);
    
    fused_postprocess_kernel<<<grid, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
        in.data_ptr<float>(),
        output.data_ptr<float>(),
        bias.data_ptr<float>(),
        scaling_factor,
        B, C, H, W
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused softmax (dim=1), add bias, scale, and sigmoid");
}
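
As a quick sanity check, the fused operator can be compared against the equivalent unfused PyTorch ops (a sketch: fused refers to the extension loaded as in the Python module above, and C is chosen as a power of two to match the kernel's reduction assumption):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 64, 32, 32
x = torch.randn(B, C, H, W, device="cuda")
bias = torch.randn(C, device="cuda")
scale = 2.0

ref = torch.sigmoid((F.softmax(x, dim=1) + bias.view(1, C, 1, 1)) * scale)
out = fused.forward(x, bias, scale)
print(torch.allclose(out, ref, atol=1e-5))  # expected to print True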