2.91. Transposed Convolution with Softmax and Sigmoid Activation
Performs a series of operations on an input tensor: first, a transposed convolution for upsampling; then, a softmax over the channel dimension to normalize the outputs; next, the addition of a bias term; then, scaling by a constant factor; and finally, a sigmoid activation to produce the final output.
Progress
Optimization Timeline
Step 1 Details
Reasoning
In this problem we want to fuse several operations that follow a transposed convolution. The original operations are:

1. A transposed convolution (ConvTranspose2d) whose output spatial size is out_dim = (in_dim - 1) * stride - 2 * padding + kernel_size + output_padding (for both height and width). This produces a tensor of shape [B, C, H_out, W_out].
2. A softmax applied along the channel dimension: for each spatial location (b, h, w), the softmax is computed over the C channel values.
3. An addition of a bias term (of shape [C] after flattening from (C, 1, 1)), broadcast to each spatial location.
4. A scaling by a constant scaling_factor.
5. A sigmoid activation.

Fusing these means that once the transposed convolution (which we leave untouched) has been computed, a single custom CUDA kernel:

• Takes the conv_transpose output of shape [B, C, H_out, W_out] (for a 2D ConvTranspose2d there is no depth dimension).
• For each spatial location (b, h, w) of the conv output, computes the softmax along the channel dimension. This requires a reduction to find the maximum, followed by a reduction for the sum of exponentials.
• Then adds the corresponding bias for each channel, scales the result, and applies the sigmoid function.

Since the number of output channels is typically small (e.g., 64), we can launch one block per spatial location with blockDim.x equal to the number of channels; the tree reductions below assume this count is a power of two no larger than the maximum block size. The final output has the same shape as the conv_transpose output. Because the element-wise operations (softmax, bias addition, scaling, and sigmoid) are fused into one kernel, we dramatically reduce memory traffic. We use c10::cuda::CUDAGuard and the stream returned by c10::cuda::getCurrentCUDAStream() to ensure the proper device context. The Python module loads the custom CUDA extension with torch.utils.cpp_extension.load, passing the -O3 and --use_fast_math extra flags.

In the fused op (implemented in the CUDA kernel below), for each (b, h, w) we:

• Load the C channel values into shared memory.
• Compute the maximum over channels via a parallel reduction.
• Compute the sum of exp(value - max) over channels.
• For each channel, compute softmax = exp(value - max) / sum_exp, add bias[c], multiply by scaling_factor, and apply the sigmoid.

All of these steps are fused and performed per output pixel (spatial location). The CUDAModel in Python first performs the transposed convolution (using nn.ConvTranspose2d) and then calls the fused CUDA operator with the conv_transpose output, the bias (reshaped to 1D), and the scaling factor.
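For reference, here is a minimal PyTorch sketch of the unfused post-convolution steps that the kernel replaces; the function name baseline_postprocess is illustrative, and the bias is assumed to have shape (C, 1, 1) as described above.

import torch
import torch.nn.functional as F

def baseline_postprocess(conv_out, bias, scaling_factor):
    # conv_out: [B, C, H, W] output of ConvTranspose2d; bias: [C, 1, 1], broadcast over H and W.
    x = F.softmax(conv_out, dim=1)   # softmax over the channel dimension
    x = x + bias                     # add per-channel bias
    x = x * scaling_factor           # scale by a constant factor
    return torch.sigmoid(x)          # final sigmoid activation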
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
// Fused kernel for softmax (dim=1), addition of bias, scaling, and sigmoid.
// The input 'in' has shape [B, C, H, W].
// We launch one block per (b, h, w) pixel location. Each block has blockDim.x = C.
__global__ void fused_postprocess_kernel(const float* __restrict__ in,
float* __restrict__ out,
const float* __restrict__ bias,
float scaling_factor,
int B, int C, int H, int W) {
// Calculate the index for the pixel location.
// Each block is assigned one pixel location.
int pixel_idx = blockIdx.x; // pixel index in the flattened B*H*W
int thread_c = threadIdx.x; // channel index (0 <= thread_c < C)
// Compute b, h, w from pixel_idx: pixel_idx = b * (H*W) + (h * W + w)
int HW = H * W;
int b = pixel_idx / HW;
int rem = pixel_idx % HW;
int h = rem / W;
int w = rem % W;
// Compute index for the current element: index = ((b * C + c) * H + h) * W + w.
int index = ((b * C + thread_c) * H + h) * W + w;
// Load the input value for this channel.
float val = in[index];
// Allocate dynamic shared memory for reduction, size (C) floats.
extern __shared__ float sdata[];
sdata[thread_c] = val;
__syncthreads();
// Step 1: Compute the maximum value in the channel vector for numerical stability.
// Use a parallel tree reduction (assumes blockDim.x = C is a power of two).
for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
if (thread_c < offset) {
sdata[thread_c] = fmaxf(sdata[thread_c], sdata[thread_c + offset]);
}
__syncthreads();
}
float max_val = sdata[0];
__syncthreads();
// Step 2: Compute exponentials and store back to shared memory.
float exp_val = expf(val - max_val);
sdata[thread_c] = exp_val;
__syncthreads();
// Step 3: Compute sum of exponentials using reduction.
for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
if (thread_c < offset) {
sdata[thread_c] += sdata[thread_c + offset];
}
__syncthreads();
}
float sum_exp = sdata[0];
__syncthreads();
// Step 4: Compute softmax value.
float softmax_val = exp_val / sum_exp;
// Step 5: Add bias (bias is indexed by channel).
float result = softmax_val + bias[thread_c];
// Step 6: Scale.
result *= scaling_factor;
// Step 7: Apply sigmoid activation.
result = 1.0f / (1.0f + expf(-result));
// Write the result to output.
out[index] = result;
}
torch::Tensor forward(torch::Tensor in, torch::Tensor bias, float scaling_factor) {
// in: [B, C, H, W], the conv_transpose output; expected to be a contiguous float32 CUDA tensor.
TORCH_CHECK(in.is_cuda(), "in must be a CUDA tensor");
TORCH_CHECK(in.dim() == 4 && in.is_contiguous(), "in must be a contiguous 4D tensor");
c10::cuda::CUDAGuard device_guard(in.device());
int B = in.size(0);
int C = in.size(1);
int H = in.size(2);
int W = in.size(3);
auto output = torch::empty_like(in);
// Set grid: one block per pixel location (B * H * W)
int grid = B * H * W;
int threads = C; // one thread per channel
// The in-kernel tree reductions assume C is a power of two no larger than the block limit.
TORCH_CHECK(C > 0 && C <= 1024 && (C & (C - 1)) == 0, "C must be a power of two <= 1024");
TORCH_CHECK(bias.is_contiguous() && bias.numel() == C, "bias must be a contiguous tensor with C elements");
size_t shared_mem = threads * sizeof(float);
fused_postprocess_kernel<<<grid, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
in.data_ptr<float>(),
output.data_ptr<float>(),
bias.data_ptr<float>(),
scaling_factor,
B, C, H, W
);
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused softmax (dim=1), add bias, scale, and sigmoid");
}
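A minimal sketch of the Python side described in the reasoning, assuming the C++/CUDA source above is saved as fused_postprocess.cu; the file name, module name, weight initialization, and constructor arguments are illustrative assumptions, while the load flags (-O3, --use_fast_math) and the call forward(conv_out, bias.view(-1), scaling_factor) follow the extension code above.

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Build and load the extension above (file name and module name are assumptions).
fused_ops = load(
    name="fused_postprocess",
    sources=["fused_postprocess.cu"],
    extra_cflags=["-O3"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=False,
)

class CUDAModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 output_padding, scaling_factor):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, output_padding=output_padding)
        # Extra bias of shape (C, 1, 1); the initialization scale is an assumption.
        self.bias = nn.Parameter(torch.randn(out_channels, 1, 1) * 0.02)
        self.scaling_factor = scaling_factor

    def forward(self, x):
        y = self.conv_transpose(x).contiguous()  # [B, C, H_out, W_out]
        # Fused softmax(dim=1) + bias + scale + sigmoid, computed in one kernel.
        return fused_ops.forward(y, self.bias.view(-1).contiguous(),
                                 float(self.scaling_factor))

The bias is flattened to a 1D tensor of length C before the call, matching the kernel's per-channel indexing.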