2.49. 3D Transposed Convolution with Softmax and Sigmoid

Computes a 3D transposed convolution on the input tensor using configurable parameters (kernel size, stride, padding, output padding) to upsample the data, then applies a softmax activation along the channel dimension to normalize the values into probabilities, and finally applies a sigmoid function to map the result into the (0, 1) range.
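
For reference, the unfused pipeline that the custom kernel targets can be written with the libtorch C++ API roughly as follows. This is a sketch: the weight/bias shapes and the stride, padding, and output-padding values are illustrative assumptions, not the benchmark's configuration.

#include <torch/torch.h>

// Unfused reference: ConvTranspose3d -> softmax over channels -> sigmoid.
// The convolution hyperparameters below are placeholders, not the benchmark's values.
torch::Tensor reference_forward(const torch::Tensor& x,       // [B, C_in, D, H, W]
                                const torch::Tensor& weight,  // [C_in, C_out, kD, kH, kW]
                                const torch::Tensor& bias) {  // [C_out]
    auto y = torch::conv_transpose3d(x, weight, bias,
                                     /*stride=*/2, /*padding=*/1, /*output_padding=*/1);
    y = torch::softmax(y, /*dim=*/1);  // normalize across the channel dimension
    return torch::sigmoid(y);          // map into (0, 1)
}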

Progress: Step 2 of 2
Reward: 1.42x

Step 2 Details

Reasoning

We can further optimize the fused Softmax-Sigmoid kernel by replacing the serial reductions over shared memory with warp-level reductions built on CUDA's shuffle intrinsics. The previous implementation loaded the entire row into shared memory and then looped over the channels to compute the max and the sum, which is efficient when C is very small but incurs unnecessary overhead when C is moderate (e.g. 64 channels). With __shfl_down_sync, the threads of a warp perform the reduction in hardware, cutting both synchronization overhead and total latency.

In the revised kernel, each block still processes one "row" (one spatial location of the flattened B*D*H*W dimensions), with one thread per channel. Each thread loads its element, performs a warp-level reduction to compute the maximum, and then cooperates across warps through a small shared-memory array. We then compute exp(x - max) and reduce again to obtain the sum. Finally, each thread computes its softmax value and applies the sigmoid. For example, with C = 64 a row is handled by two warps, so the cross-warp step only has to combine two partial results.

This approach should speed up the reductions for moderate channel counts and achieve a better speedup over the native implementation. Before finishing I ask: have my improvements genuinely helped? Yes: by reducing redundant shared-memory reads and synchronizations, we expect a noticeable speed-up, especially for larger channel dimensions.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <float.h>
#include <math.h>

#ifndef WARP_SIZE
#define WARP_SIZE 32
#endif

// Helper function: warp-level reduction for maximum.
__inline__ __device__
float warpReduceMax(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
    return val;
}

// Helper function: warp-level reduction for sum.
__inline__ __device__
float warpReduceSum(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

// Fused kernel that performs Softmax (dim=1) and Sigmoid on an input tensor X.
// X is assumed to have shape [B, C, D, H, W] in contiguous memory.
// We flatten the spatial dimensions so that each block processes one row (of length C),
// where C is the number of channels.
__global__ void fused_softmax_sigmoid_kernel(const float* __restrict__ X,
                                               float* __restrict__ Y,
                                               int R, int C) {
    int r = blockIdx.x;  // each block corresponds to one row (flattened B*D*H*W)
    int tid = threadIdx.x;

    // The block size is rounded up to a whole number of warps on the host, so the
    // last warp may contain threads with tid >= C. Those threads stay active (they
    // must still reach the __syncthreads() and full-mask shuffles below) but
    // contribute identity values to the reductions and never touch global memory.
    bool active = (tid < C);
    int index = r * C + tid;
    float x_val = active ? X[index] : -FLT_MAX;

    // Step 1: compute max value in this row.
    float local_max = x_val;
    local_max = warpReduceMax(local_max);

    // Each warp writes its result to shared memory.
    __shared__ float s_max[32];  // Maximum number of warps assumed is 32.
    int warp_id = tid / WARP_SIZE;
    int lane = tid % WARP_SIZE;
    if (lane == 0) {
        s_max[warp_id] = local_max;
    }
    __syncthreads();

    // The first warp reduces the per-warp partial maxima; the other warps compute a
    // dummy value that is discarded (only tid == 0 writes the final result).
    float final_max = (tid < (C + WARP_SIZE - 1) / WARP_SIZE) ? s_max[lane] : -FLT_MAX;
    final_max = warpReduceMax(final_max);
    if (tid == 0) {
        s_max[0] = final_max;
    }
    __syncthreads();
    final_max = s_max[0];

    // Step 2: compute exponentials and the sum (inactive threads contribute 0).
    float exp_val = active ? expf(x_val - final_max) : 0.0f;
    float local_sum = exp_val;
    local_sum = warpReduceSum(local_sum);

    __shared__ float s_sum[32];
    if (lane == 0) {
        s_sum[warp_id] = local_sum;
    }
    __syncthreads();

    float final_sum = (tid < (C + WARP_SIZE - 1) / WARP_SIZE) ? s_sum[lane] : 0.0f;
    final_sum = warpReduceSum(final_sum);
    if (tid == 0) {
        s_sum[0] = final_sum;
    }
    __syncthreads();
    final_sum = s_sum[0];

    // Step 3: compute the softmax value and apply the sigmoid.
    if (active) {
        float softmax_val = exp_val / final_sum;
        Y[index] = 1.0f / (1.0f + expf(-softmax_val));
    }
}

torch::Tensor fused_softmax_sigmoid(torch::Tensor X) {
    // Ensure operations run on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(X.device());

    // Input tensor X is expected to be a contiguous float32 CUDA tensor of shape [B, C, D, H, W].
    TORCH_CHECK(X.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(X.scalar_type() == torch::kFloat32, "Input must be a float32 tensor");
    TORCH_CHECK(X.is_contiguous(), "Input must be contiguous");
    auto sizes = X.sizes();
    TORCH_CHECK(sizes.size() == 5, "Input must be a 5D tensor");
    int B = sizes[0];
    int C = sizes[1];
    int D = sizes[2];
    int H = sizes[3];
    int W = sizes[4];
    TORCH_CHECK(C <= 1024, "Kernel assigns one thread per channel and supports at most 1024 channels");

    // Flatten spatial dimensions: number of rows R = B * D * H * W.
    int R = B * D * H * W;

    // Allocate output tensor Y with same shape as X.
    auto Y = torch::empty_like(X);

    // Launch one block per row; round the block size up to a whole number of warps
    // so that every warp taking part in the shuffle reductions is complete.
    int threads = ((C + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
    dim3 grid(R);
    dim3 block(threads);
    // The reduction scratch arrays (s_max, s_sum) are statically declared inside the
    // kernel, so no dynamic shared memory is requested at launch.
    fused_softmax_sigmoid_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
        X.data_ptr<float>(),
        Y.data_ptr<float>(),
        R,
        C
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    
    return Y;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_softmax_sigmoid", &fused_softmax_sigmoid, "Fused Softmax (dim=1) and Sigmoid activation using warp shuffle optimizations (CUDA)");
}
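
As a quick sanity check, one could compare the fused kernel against the unfused softmax + sigmoid from libtorch. The sketch below assumes it lives in the same translation unit as the extension (so fused_softmax_sigmoid and the includes above are in scope); the tensor shape is an illustrative assumption.

// Sanity-check sketch: compare against torch's own softmax + sigmoid on a random
// CUDA tensor. The shape {2, 64, 8, 16, 16} is illustrative, not the benchmark's.
void check_fused_softmax_sigmoid() {
    auto x = torch::randn({2, 64, 8, 16, 16},
                          torch::dtype(torch::kFloat32).device(torch::kCUDA));
    auto fused = fused_softmax_sigmoid(x);
    auto ref = torch::sigmoid(torch::softmax(x, /*dim=*/1));
    TORCH_CHECK(torch::allclose(fused, ref, /*rtol=*/1e-4, /*atol=*/1e-6),
                "fused_softmax_sigmoid deviates from the unfused reference");
}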