2.15. 3D Transposed Convolution with Batch Norm and Mean Subtraction

Performs a 3D transposed convolution with adjustable kernel size, stride, and padding to upsample the input, applies batch normalization to the resulting feature maps, and then computes and subtracts the spatial mean to center the activations.

Progress: Step 2 of 2
Reward: 1.04x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize the fused mean-subtraction step by improving the reduction in compute_mean_kernel and the throughput of subtract_mean_kernel. The previous version used a block size of 256 threads to compute the spatial sum for each (n, c) slice. Because a slice can contain many elements (e.g. 31×63×63 ≈ 123K), more threads per block means fewer iterations per thread in the reduction loop: with 1024 threads each thread accumulates roughly 120 elements instead of about 480. We therefore raise the reduction kernel's block size to 1024, and use 1024 threads in the subtraction kernel as well, which reduces the number of blocks launched and improves occupancy. These changes should give a modest speedup by lowering per-thread loop overhead and making better use of the available parallelism. As before, we guard device selection with c10::cuda::CUDAGuard and launch every kernel on c10::cuda::getCurrentCUDAStream(). Overall, these are genuine optimizations (more thread parallelism in the reduction and subtraction steps) that should reduce runtime relative to the previous two-kernel fused version.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Fused kernel: For each (n,c) slice, compute the mean over spatial dimensions 
// and then subtract that mean from every element in the slice.

// Kernel to compute the mean of a spatial slice (for a given [n,c] slice).
// Each block handles one complete (n,c) slice.
// The input tensor is contiguous with shape [N, C, D, H, W].
__global__ void compute_mean_kernel(const float* __restrict__ input,
                                    float* __restrict__ mean,
                                    int spatial_size) {
    // Each block processes one (n,c) slice.
    int slice_index = blockIdx.x;
    const float* slice = input + slice_index * spatial_size;
    float sum = 0.0f;
    
    // Use a larger block size (1024 threads) to reduce iterations.
    for (int pos = threadIdx.x; pos < spatial_size; pos += blockDim.x) {
        sum += slice[pos];
    }
    // Warp-level reduction using __shfl_down_sync.
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    // Shared memory reduction across warps.
    __shared__ float sdata[32];  // Maximum warps per block with 1024 threads = 1024/32 = 32.
    int lane = threadIdx.x % 32;
    int warp_id = threadIdx.x / 32;
    if (lane == 0) {
        sdata[warp_id] = sum;
    }
    __syncthreads();
    // Let the first warp reduce the warp sums.
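    // With the 1024-thread launch used here, blockDim.x / 32 == 32, so the entire
    // first warp executes the full-mask shuffle below; a smaller or non-multiple-of-32
    // block size would require an additional guard.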
    if (threadIdx.x < blockDim.x / 32) {
        float block_sum = sdata[threadIdx.x];
        for (int offset = 16; offset > 0; offset /= 2) {
            block_sum += __shfl_down_sync(0xffffffff, block_sum, offset);
        }
        if (threadIdx.x == 0) {
            mean[slice_index] = block_sum / float(spatial_size);
        }
    }
}

// Kernel to subtract the computed mean from each element.
// The input tensor is modified in-place.
__global__ void subtract_mean_kernel(float* __restrict__ input,
                                     const float* __restrict__ mean,
                                     int spatial_size,
                                     int total_elements) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Process using a grid-stride loop for robust coverage.
    for (int i = idx; i < total_elements; i += blockDim.x * gridDim.x) {
        int slice_index = i / spatial_size;  // each slice corresponds to one (n,c) combination.
        input[i] -= mean[slice_index];
    }
}

torch::Tensor fused_subtract_mean(torch::Tensor x) {
    // Ensure operations are executed on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());
    
    // The kernels assume a contiguous float32 CUDA tensor of shape [N, C, D, H, W].
    TORCH_CHECK(x.is_cuda(), "fused_subtract_mean: input must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "fused_subtract_mean: input must be contiguous");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "fused_subtract_mean: input must be float32");
    TORCH_CHECK(x.dim() == 5, "fused_subtract_mean: expected a 5-D tensor [N, C, D, H, W]");
    
    int64_t N = x.size(0);
    int64_t C = x.size(1);
    int64_t D = x.size(2);
    int64_t H = x.size(3);
    int64_t W = x.size(4);
    
    // Each (n, c) pair forms one spatial slice of D*H*W elements.
    int spatial_size = D * H * W;
    int num_slices = N * C;
    
    // Allocate temporary tensor to store per-slice means.
    auto options = x.options();
    auto mean = torch::empty({num_slices}, options);
    
    // Launch kernel to compute mean for each (n,c) slice.
    int blockSizeMean = 1024;
    int gridSizeMean = num_slices;  // one block per slice.
    compute_mean_kernel<<<gridSizeMean, blockSizeMean, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        mean.data_ptr<float>(),
        spatial_size
    );
    
    // Launch kernel to subtract the computed mean from each element.
    int total_elements = x.numel();
    int blockSizeSub = 1024;
    int gridSizeSub = (total_elements + blockSizeSub - 1) / blockSizeSub;
    subtract_mean_kernel<<<gridSizeSub, blockSizeSub, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        mean.data_ptr<float>(),
        spatial_size,
        total_elements
    );
    
    // The kernels modified x in place, so return the (now centered) tensor.
    return x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_subtract_mean", &fused_subtract_mean,
          "Fused operation to subtract per-(n,c) spatial mean from input tensor");
}
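
For context, the following is a minimal sketch (not part of the generated extension) of how the fused op might be combined with standard ATen calls for the transposed convolution and batch normalization described in the task. The forward_sketch name, its signature and parameter names, and the use of inference-mode batch-norm statistics are illustrative assumptions, not the benchmark's actual interface.

// Hypothetical host-side forward pass: 3D transposed convolution, batch normalization
// (using stored running statistics), then the fused per-(n, c) spatial mean subtraction
// implemented above. All tensors are assumed to be float32 CUDA tensors on one device.
torch::Tensor forward_sketch(torch::Tensor x,
                             torch::Tensor conv_weight, torch::Tensor conv_bias,
                             torch::Tensor bn_weight, torch::Tensor bn_bias,
                             torch::Tensor bn_running_mean, torch::Tensor bn_running_var,
                             int64_t stride, int64_t padding) {
    // Upsample with a 3D transposed convolution.
    auto y = at::conv_transpose3d(x, conv_weight, conv_bias,
                                  /*stride=*/{stride, stride, stride},
                                  /*padding=*/{padding, padding, padding});
    // Normalize each channel.
    y = at::batch_norm(y, bn_weight, bn_bias, bn_running_mean, bn_running_var,
                       /*training=*/false, /*momentum=*/0.1, /*eps=*/1e-5,
                       /*cudnn_enabled=*/true);
    // Center the activations with the custom fused kernels.
    return fused_subtract_mean(y.contiguous());
}

In practice the convolution and batch normalization would typically remain standard PyTorch modules on the Python side, with only the mean subtraction routed through this extension via the fused_subtract_mean binding.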