2.78. 3D Transposed Convolution with Cascaded Max Pooling and Summation

Performs a 3D transposed convolution with configurable parameters to upsample the input, followed by two successive max pooling operations with different kernel sizes that downsample the spatial dimensions, and finally sums the result across the channel dimension.

Progress

Step 4 of 4
Reward
1.13x

Optimization Timeline

Step 4 Details

Reasoning

We can further reduce memory latency by using the __ldg intrinsic to load read-only data from global memory through the read-only data cache. In our fused kernel, the inner loop repeatedly loads elements of Y to compute each 2x2x2 max-pooling window. By replacing each load (e.g. Y[idx0]) with __ldg(&Y[idx0]), we can improve cache hit rates and reduce memory latency. This optimization is valid because the kernel never writes to Y. We also verified that all index computations remain correct and that the kernel still uses precomputed strides and unrolled loops. The Python module remains unchanged because all of the heavy lifting happens in CUDA. These changes improve performance without altering the algorithm.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>

// Fused kernel that performs two successive 3D max pooling operations and then
// sums the results across channels.
//
// Input Y: dense (contiguous) float tensor of shape (N, C, D, H, W).
// First pooling:  kernel=2, stride=2 -> intermediate tensor P (never materialized).
// Second pooling: kernel=3, stride=3 -> tensor Q of shape (N, C, D2, H2, W2).
// Output: S(n, d2, h2, w2) = sum_{c=0}^{C-1} Q(n, c, d2, h2, w2), written
// densely as (N, D2, H2, W2).
//
// Each Q element covers a 6x6x6 window of Y starting at (6*d2, 6*h2, 6*w2).
// The caller must therefore guarantee 6*D2 <= D, 6*H2 <= H and 6*W2 <= W.
// The kernel itself performs no bounds checks on Y.
//
// Launch: 1D grid of 1D blocks, no shared memory. A grid-stride loop covers all
// N*D2*H2*W2 outputs, so any grid size produces a complete result.
// Offsets into Y are computed in 64-bit so inputs with more than 2^31 elements
// do not overflow.
__global__ void fused_pool_sum_kernel(const float* __restrict__ Y,
                                      float* __restrict__ S,
                                      int N, int C, int D, int H, int W,
                                      int D2, int H2, int W2) {
    const long long total = static_cast<long long>(N) * D2 * H2 * W2;
    const long long HW = static_cast<long long>(H) * W;   // stride of the d dimension
    const long long DHW = static_cast<long long>(D) * HW; // stride of the c dimension
    const long long grid_stride = static_cast<long long>(gridDim.x) * blockDim.x;

    for (long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < total; idx += grid_stride) {
        // Decompose idx into output indices (n, d2, h2, w2); w2 varies fastest,
        // so idx is also the flat offset of S(n, d2, h2, w2).
        const int w2 = static_cast<int>(idx % W2);
        long long rest = idx / W2;
        const int h2 = static_cast<int>(rest % H2);
        rest /= H2;
        const int d2 = static_cast<int>(rest % D2);
        const int n = static_cast<int>(rest / D2);

        // Offset of Y(n, 0, 6*d2, 6*h2, 6*w2). This is the corner of this output's
        // 6x6x6 input window and is invariant across the channel and window loops.
        const long long window_origin = static_cast<long long>(n) * C * DHW
                                      + (6LL * d2) * HW
                                      + (6LL * h2) * W
                                      + 6LL * w2;

        float sum = 0.0f;
        for (int c = 0; c < C; ++c) {
            const float* channel_window = Y + window_origin + c * DHW;
            float max_val = -FLT_MAX;
            // Second pooling window (3x3x3) over first-pooling outputs.
            #pragma unroll
            for (int p = 0; p < 3; ++p) {
                #pragma unroll
                for (int q = 0; q < 3; ++q) {
                    #pragma unroll
                    for (int r = 0; r < 3; ++r) {
                        // Corner of the 2x2x2 first-pooling window; Y is never
                        // written by this kernel, so __ldg's read-only path is safe.
                        const float* v = channel_window + (2 * p) * HW + (2 * q) * W + 2 * r;
                        float local_max = __ldg(v);
                        local_max = fmaxf(local_max, __ldg(v + 1));
                        local_max = fmaxf(local_max, __ldg(v + W));
                        local_max = fmaxf(local_max, __ldg(v + W + 1));
                        local_max = fmaxf(local_max, __ldg(v + HW));
                        local_max = fmaxf(local_max, __ldg(v + HW + 1));
                        local_max = fmaxf(local_max, __ldg(v + HW + W));
                        local_max = fmaxf(local_max, __ldg(v + HW + W + 1));
                        max_val = fmaxf(max_val, local_max);
                    }
                }
            }
            sum += max_val;
        }
        S[idx] = sum;
    }
}

// Host entry point: validates Y and launches fused_pool_sum_kernel on Y's device
// using the current CUDA stream.
//
// Y must be a CUDA float32 tensor whose shape equals (N, C, D, H, W). A
// non-contiguous Y is copied to a contiguous tensor first, because the kernel
// indexes it with dense NCDHW strides.
// D2/H2/W2 are the spatial sizes after both poolings. Each output reads a 6x6x6
// input window, so 6*D2 <= D, 6*H2 <= H and 6*W2 <= W are required to keep
// every read inside Y.
// Returns a new tensor S of shape (N, D2, H2, W2) with Y's dtype and device.
// Throws (via TORCH_CHECK) on invalid input or on a kernel launch error.
torch::Tensor forward(torch::Tensor Y,
                      int N, int C, int D, int H, int W,
                      int D2, int H2, int W2) {
    TORCH_CHECK(Y.is_cuda(), "forward: Y must be a CUDA tensor");
    TORCH_CHECK(Y.scalar_type() == torch::kFloat,
                "forward: Y must be float32, got ", Y.scalar_type());
    TORCH_CHECK(Y.dim() == 5,
                "forward: Y must be 5-D (N, C, D, H, W), got ", Y.dim(), " dims");
    TORCH_CHECK(Y.size(0) == N && Y.size(1) == C && Y.size(2) == D &&
                Y.size(3) == H && Y.size(4) == W,
                "forward: Y has shape ", Y.sizes(), " but (N, C, D, H, W) = (",
                N, ", ", C, ", ", D, ", ", H, ", ", W, ")");
    TORCH_CHECK(D2 >= 0 && H2 >= 0 && W2 >= 0,
                "forward: output sizes must be non-negative, got (",
                D2, ", ", H2, ", ", W2, ")");
    TORCH_CHECK(6LL * D2 <= D && 6LL * H2 <= H && 6LL * W2 <= W,
                "forward: output size (", D2, ", ", H2, ", ", W2,
                ") would read past the input; need 6*D2 <= D, 6*H2 <= H, 6*W2 <= W");

    c10::cuda::CUDAGuard device_guard(Y.device());

    // The kernel assumes dense NCDHW layout; this is a no-op for contiguous input.
    const torch::Tensor Yc = Y.contiguous();
    auto S = torch::empty({N, D2, H2, W2}, Yc.options());

    const int64_t total = S.numel();
    if (total == 0) {
        // A launch with zero blocks is an invalid configuration; nothing to compute.
        return S;
    }

    const int threads = 256;
    // Every output reads 216 input elements per channel, so total is far below
    // the size at which this block count could exceed gridDim.x's limit.
    const int blocks = static_cast<int>((total + threads - 1) / threads);

    fused_pool_sum_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        Yc.data_ptr<float>(),
        S.data_ptr<float>(),
        N, C, D, H, W,
        D2, H2, W2
    );

    // Launch-configuration errors are only reported through cudaGetLastError().
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "fused_pool_sum_kernel launch failed: ", cudaGetErrorString(launch_err));
    return S;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // Register the host launcher so Python can call it as `forward`.
    ext.def("forward",
            &forward,
            "Fused 3D max pooling (2x then 3x) and channel sum with optimized index computation and __ldg loads");
}