2.60. 3D Transposed Convolution with Swish and HardSwish

Performs a 3D transposed convolution to upsample the input, applies a Swish activation (x multiplied by sigmoid(x)), then applies group normalization, and finally applies a HardSwish activation to produce the output.

Progress

Step 3 of 3
Reward
1.20x

Optimization Timeline

Step 3 Details

Reasoning

We can further optimize the kernel by vectorizing the heavy per-group loops. In many practical cases the group size (channels per group × D × H × W) is divisible by 4, so loads and stores are vectorized with float4.

In the first pass, each thread processes 4 floats at a time. It computes swish (x * sigmoid(x)) for each component and accumulates the sum and sum of squares over those 4 components. Groups whose size is not divisible by 4 fall back to the original scalar loop.

The second pass likewise uses vectorized loads and stores when possible. The warp-level reduction using __shfl_down_sync is unchanged. This reduces the number of loop iterations and memory transactions whenever vectorization applies. The Python module remains essentially the same.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>

// Swish activation: x * sigmoid(x).
static __device__ __forceinline__ float fused_swish(float x) {
    return x * (1.0f / (1.0f + expf(-x)));
}

// HardSwish activation: x * clamp(x + 3, 0, 6) / 6.
static __device__ __forceinline__ float fused_hardswish(float x) {
    const float gate = fminf(fmaxf(x + 3.0f, 0.0f), 6.0f);
    return x * gate / 6.0f;
}

// Block-wide sum of two values; on return every thread holds both totals.
// Preconditions: blockDim.x is a multiple of 32 and <= 1024 (full-warp shuffle
// mask, 32 shared slots). Contains __syncthreads(), so every thread of the block
// must call it. It has no trailing barrier, so call it at most once per kernel.
static __device__ __forceinline__ void fused_block_sum2(float& a, float& b) {
    __shared__ float s_a[32];  // one slot per warp
    __shared__ float s_b[32];
    const unsigned full_mask = 0xffffffffu;

    // Reduce within each warp.
    #pragma unroll
    for (int off = 16; off > 0; off >>= 1) {
        a += __shfl_down_sync(full_mask, a, off);
        b += __shfl_down_sync(full_mask, b, off);
    }

    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    if (lane == 0) {
        s_a[warp] = a;
        s_b[warp] = b;
    }
    __syncthreads();

    // The first warp reduces the per-warp partials.
    if (warp == 0) {
        const int num_warps = (blockDim.x + 31) >> 5;
        a = (lane < num_warps) ? s_a[lane] : 0.0f;
        b = (lane < num_warps) ? s_b[lane] : 0.0f;
        #pragma unroll
        for (int off = 16; off > 0; off >>= 1) {
            a += __shfl_down_sync(full_mask, a, off);
            b += __shfl_down_sync(full_mask, b, off);
        }
        if (lane == 0) {
            s_a[0] = a;
            s_b[0] = b;
        }
    }
    __syncthreads();
    a = s_a[0];
    b = s_b[0];
}

// Fused epilogue for ConvTranspose3d: Swish -> GroupNorm (no affine
// weight/bias is applied) -> HardSwish.
//
// Launch contract:
//   * input/output are dense float32 buffers laid out as [B, C, D, H, W], and
//     C must be divisible by `groups`.
//   * Grid: one block per (sample, group), i.e. gridDim.x == B * groups, with
//     the group index varying fastest.
//   * Block: blockDim.x must be a multiple of 32 and <= 1024.
//   * Shared memory: static only (2 x 32 floats); no dynamic shared memory.
//   * `output` is used as scratch between the two passes and must not alias
//     `input` (both are __restrict__).
//
// Algorithm:
//   1. First pass: write swish = x * sigmoid(x) into `output` while
//      accumulating per-thread sum and sum of squares. float4 loads/stores are
//      used when the group size is a multiple of 4 AND both group base
//      pointers are 16-byte aligned; otherwise a scalar loop is used.
//   2. A block-wide reduction gives the group mean and biased variance
//      (E[x^2] - E[x]^2, clamped at 0 to absorb rounding error).
//   3. Second pass: normalize in place and apply HardSwish,
//      hard_swish(n) = n * clamp(n + 3, 0, 6) / 6.
// Element offsets are 64-bit, so tensors with more than 2^31 elements index correctly.
__global__ void fused_post_kernel(const float* __restrict__ input, float* __restrict__ output,
                                  int B, int C, int D, int H, int W,
                                  int groups, float eps) {
    const int g = blockIdx.x % groups;  // group index
    const int b = blockIdx.x / groups;  // sample index

    const int channels_per_group = C / groups;
    const int64_t spatial = static_cast<int64_t>(D) * H * W;
    const int64_t group_size = channels_per_group * spatial;  // floats in this group
    const int64_t base = static_cast<int64_t>(b) * C * spatial + g * group_size;

    const float* in_g = input + base;
    float* out_g = output + base;

    float partial_sum = 0.0f;
    float partial_sum_sq = 0.0f;

    // float4 access needs both a multiple-of-4 length and 16-byte alignment.
    // Alignment is not implied by the length when the tensor has a storage offset.
    const bool use_vec = (group_size % 4 == 0) &&
                         (reinterpret_cast<uintptr_t>(in_g) % 16 == 0) &&
                         (reinterpret_cast<uintptr_t>(out_g) % 16 == 0);

    if (use_vec) {
        const int64_t vec_length = group_size / 4;
        const float4* in_vec = reinterpret_cast<const float4*>(in_g);
        float4* out_vec = reinterpret_cast<float4*>(out_g);
        for (int64_t i = threadIdx.x; i < vec_length; i += blockDim.x) {
            const float4 v = in_vec[i];
            float4 s;
            s.x = fused_swish(v.x);
            s.y = fused_swish(v.y);
            s.z = fused_swish(v.z);
            s.w = fused_swish(v.w);
            out_vec[i] = s;
            partial_sum += s.x + s.y + s.z + s.w;
            partial_sum_sq += s.x * s.x + s.y * s.y + s.z * s.z + s.w * s.w;
        }
    } else {
        for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
            const float s = fused_swish(in_g[i]);
            out_g[i] = s;
            partial_sum += s;
            partial_sum_sq += s * s;
        }
    }

    // All threads must reach this call (it contains __syncthreads()).
    fused_block_sum2(partial_sum, partial_sum_sq);

    const float count = static_cast<float>(group_size);
    const float mean = partial_sum / count;
    // E[x^2] - E[x]^2 can round slightly below zero for near-constant groups;
    // clamp it so rsqrtf never receives a negative argument.
    const float var = fmaxf(partial_sum_sq / count - mean * mean, 0.0f);
    const float inv_std = rsqrtf(var + eps);

    // Second pass: each thread revisits exactly the elements it wrote above.
    if (use_vec) {
        const int64_t vec_length = group_size / 4;
        float4* out_vec = reinterpret_cast<float4*>(out_g);
        for (int64_t i = threadIdx.x; i < vec_length; i += blockDim.x) {
            float4 v = out_vec[i];
            v.x = fused_hardswish((v.x - mean) * inv_std);
            v.y = fused_hardswish((v.y - mean) * inv_std);
            v.z = fused_hardswish((v.z - mean) * inv_std);
            v.w = fused_hardswish((v.w - mean) * inv_std);
            out_vec[i] = v;
        }
    } else {
        for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
            out_g[i] = fused_hardswish((out_g[i] - mean) * inv_std);
        }
    }
}

// Host entry point: applies Swish -> GroupNorm (no affine) -> HardSwish to x.
//
// Parameters:
//   x      - CUDA float32 tensor of shape [B, C, D, H, W]; C % groups == 0.
//            Non-contiguous inputs are copied to a contiguous layout first,
//            because the kernel indexes with dense NCDHW offsets.
//   groups - number of normalization groups (must be > 0).
//   eps    - value added to the variance before the reciprocal square root.
// Returns a new contiguous tensor with the same shape and options as x.
// Throws (via TORCH_CHECK) on invalid input or if the kernel launch fails.
// The kernel is enqueued on the current CUDA stream and is not synchronized here.
torch::Tensor fused_forward(torch::Tensor x, int groups, float eps) {
    TORCH_CHECK(x.is_cuda(), "fused_forward: x must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat, "fused_forward: x must be float32");
    TORCH_CHECK(x.dim() == 5, "fused_forward: x must be 5-D [B, C, D, H, W]");
    TORCH_CHECK(groups > 0, "fused_forward: groups must be positive");
    TORCH_CHECK(x.size(1) % groups == 0,
                "fused_forward: channel count must be divisible by groups");

    // Ensure operations occur on the proper CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel assumes dense NCDHW strides.
    x = x.contiguous();
    // Explicitly contiguous output (empty_like could preserve another memory format).
    auto output = torch::empty(x.sizes(), x.options());

    // A zero-sized grid is an invalid launch configuration; nothing to compute.
    if (x.numel() == 0) {
        return output;
    }

    const int B = x.size(0);
    const int C = x.size(1);
    const int D = x.size(2);
    const int H = x.size(3);
    const int W = x.size(4);

    // One block per (sample, group). The kernel requires a multiple of 32
    // threads and uses only static shared memory, so no dynamic smem is passed.
    const int total_blocks = B * groups;
    constexpr int threads = 256;

    fused_post_kernel<<<total_blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C, D, H, W, groups, eps
    );
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "fused_forward: kernel launch failed: ", cudaGetErrorString(launch_err));

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    // Exposes the fused post-ConvTranspose3d epilogue to Python as `forward`.
    module.def(
        "forward",
        &fused_forward,
        "Fused Swish, GroupNorm, and HardSwish after ConvTranspose3d (vectorized optimization)");
}