2.50. 3D Transposed Convolution, Scaling, Pooling & Bias Addition

Performs a series of operations on 3D input data: calculates a transposed convolution to upsample the input, scales the result, applies average pooling to reduce dimensions, adds a bias, and then applies a final scaling transformation.

Progress

Step 4 of 4
Reward
1.03x

Optimization Timeline

Step 4 Details

Reasoning

We want to reduce the fused kernel's runtime further. Our "direct" fused kernel, which computes each output element by processing its 2×2×2 pooling window, is already correct but still slower than the native implementation.

To improve memory access, we move the convolution weights out of global memory and into constant memory. For a typical 3×3×3 kernel with small channel counts, the weight tensor is small enough to fit. We define a constant-memory array of 4096 floats. If the weight tensor fits within that limit, we copy it into constant memory and pass a null pointer to the kernel in place of the weight pointer. Inside the kernel, constant memory is used whenever the weight pointer passed in is null.

In addition, we add `#pragma unroll` directives to the small fixed-trip loops (the pooling-window loops and the kernel loops) to encourage loop unrolling. These changes should reduce memory latency and loop/branch overhead.

Concretely, the host forward() function in the CUDA extension computes weight_count = C_in * C_out * kernel_size^3. If that count is at most 4096, it copies the weight tensor into constant memory with cudaMemcpyToSymbol and passes nullptr in its place. Now we provide the updated Python and CUDA blocks.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

// Capacity, in floats, of the constant-memory weight buffer
// (4 * 1024 floats = 16 KiB).
#define MAX_WEIGHT (4 * 1024)

// Module-scope __constant__ buffer that receives a copy of the convolution
// weights when they fit within MAX_WEIGHT floats.
__constant__ float const_weight[MAX_WEIGHT];

// Fused direct kernel computes one final output element per thread.
// Each thread aggregates contributions from its corresponding 2x2x2 pooling window.
// Loops are unrolled where possible.
// Sums, over one 2x2x2 average-pooling window, the transposed-convolution
// outputs after the conv bias and scale1 have been applied. The caller divides
// the result by 8 to obtain the window average.
//
// Template parameters:
//   kConstWeight - true: weights are read from the __constant__ array
//                  const_weight; false: from the global pointer `weight`.
//                  Resolving this at compile time lets each path index its own
//                  memory space directly instead of going through a pointer
//                  whose target is only known at run time.
//   kKernel      - compile-time kernel size, or 0 to use the runtime
//                  `kernel_size` argument. A fixed size lets the tap loops
//                  unroll completely.
//
// Layouts, as indexed below: x is dense [B, C_in, D_in, H_in, W_in] and the
// weight is dense [C_in, C_out, K, K, K]. Requires stride > 0.
template <bool kConstWeight, int kKernel>
__device__ __forceinline__ float fused_pool_window_sum(
    const float* __restrict__ x,
    const float* __restrict__ weight,
    const float* __restrict__ conv_bias,
    const float scale1,
    const int stride,
    const int padding,
    const int kernel_size,
    const int D_in, const int H_in, const int W_in,
    const int C_in, const int C_out,
    const int b, const int c,
    const int d_out, const int h_out, const int w_out) {
    constexpr int kPool = 2;
    const int K = (kKernel > 0) ? kKernel : kernel_size;
    const float bias_c = conv_bias[c];
    float pool_sum = 0.0f;

    #pragma unroll
    for (int pd = 0; pd < kPool; ++pd) {
        const int d_conv = d_out * kPool + pd;
        #pragma unroll
        for (int ph = 0; ph < kPool; ++ph) {
            const int h_conv = h_out * kPool + ph;
            #pragma unroll
            for (int pw = 0; pw < kPool; ++pw) {
                const int w_conv = w_out * kPool + pw;
                float conv_val = 0.0f;
                for (int ic = 0; ic < C_in; ++ic) {
                    // 64-bit base offset so large inputs cannot overflow int.
                    const long long x_chan = ((long long)b * C_in + ic) * D_in;
                    const int w_chan = (ic * C_out + c) * K;
                    // Transposed-conv gather: conv output position p receives
                    // input position i through tap k iff p + padding - k == i * stride.
                    #pragma unroll
                    for (int kd = 0; kd < K; ++kd) {
                        const int id_mul = d_conv + padding - kd;
                        if (id_mul < 0 || id_mul % stride != 0) continue;
                        const int id = id_mul / stride;
                        if (id >= D_in) continue;
                        #pragma unroll
                        for (int kh = 0; kh < K; ++kh) {
                            const int ih_mul = h_conv + padding - kh;
                            if (ih_mul < 0 || ih_mul % stride != 0) continue;
                            const int ih = ih_mul / stride;
                            if (ih >= H_in) continue;
                            #pragma unroll
                            for (int kw = 0; kw < K; ++kw) {
                                const int iw_mul = w_conv + padding - kw;
                                if (iw_mul < 0 || iw_mul % stride != 0) continue;
                                const int iw = iw_mul / stride;
                                if (iw >= W_in) continue;
                                const long long x_index =
                                    ((x_chan + id) * H_in + ih) * W_in + iw;
                                const int w_index = ((w_chan + kd) * K + kh) * K + kw;
                                const float w_val =
                                    kConstWeight ? const_weight[w_index] : weight[w_index];
                                conv_val += x[x_index] * w_val;
                            }
                        }
                    }
                }
                conv_val += bias_c;
                conv_val *= scale1;
                pool_sum += conv_val;
            }
        }
    }
    return pool_sum;
}

// Fused kernel:
//   ConvTranspose3d(+conv_bias) -> *scale1 -> AvgPool3d(2) -> +extra_bias -> *scale2
//
// Launch: 1-D grid of 1-D blocks, no dynamic shared memory. Each thread walks
// a grid-stride loop over the B*C_out*D_out*H_out*W_out output elements, so
// any grid size yields correct results. D_out/H_out/W_out are the pooled
// output sizes.
//
// Pass weight == nullptr to read weights from const_weight, which the host
// must have filled beforehand. Otherwise weights are read from `weight`.
// Preconditions (not checked here): stride > 0; conv_bias and extra_bias hold
// at least C_out values; tensors are dense.
__global__ void fused_direct_kernel(
    const float* __restrict__ x,
    const float* __restrict__ weight,  // if weight == nullptr, use const_weight
    const float* __restrict__ conv_bias,
    const float scale1,
    const float scale2,
    const float* __restrict__ extra_bias,
    const int stride,
    const int padding,
    const int kernel_size,
    const int D_in, const int H_in, const int W_in,
    const int B, const int C_in, const int C_out,
    const int D_out, const int H_out, const int W_out,
    float* __restrict__ output) {
    // 64-bit element count and indices: the product can exceed INT_MAX.
    const long long total = (long long)B * C_out * D_out * H_out * W_out;
    const long long grid_stride = (long long)gridDim.x * blockDim.x;
    const bool use_const = (weight == nullptr);

    for (long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
         idx < total; idx += grid_stride) {
        // Decode the flat NCDHW output index.
        long long rest = idx;
        const int w_out = (int)(rest % W_out); rest /= W_out;
        const int h_out = (int)(rest % H_out); rest /= H_out;
        const int d_out = (int)(rest % D_out); rest /= D_out;
        const int c = (int)(rest % C_out);
        const int b = (int)(rest / C_out);

        float window_sum;
        if (kernel_size == 3) {
            // Common case: compile-time K = 3 keeps the fully unrolled taps.
            window_sum = use_const
                ? fused_pool_window_sum<true, 3>(x, weight, conv_bias, scale1,
                      stride, padding, kernel_size, D_in, H_in, W_in,
                      C_in, C_out, b, c, d_out, h_out, w_out)
                : fused_pool_window_sum<false, 3>(x, weight, conv_bias, scale1,
                      stride, padding, kernel_size, D_in, H_in, W_in,
                      C_in, C_out, b, c, d_out, h_out, w_out);
        } else {
            window_sum = use_const
                ? fused_pool_window_sum<true, 0>(x, weight, conv_bias, scale1,
                      stride, padding, kernel_size, D_in, H_in, W_in,
                      C_in, C_out, b, c, d_out, h_out, w_out)
                : fused_pool_window_sum<false, 0>(x, weight, conv_bias, scale1,
                      stride, padding, kernel_size, D_in, H_in, W_in,
                      C_in, C_out, b, c, d_out, h_out, w_out);
        }

        float avg = window_sum / 8.0f;
        avg += extra_bias[c];
        avg *= scale2;
        // idx is exactly the NCDHW offset of (b, c, d_out, h_out, w_out).
        output[idx] = avg;
    }
}

// Host entry point: launches fused_direct_kernel on the current CUDA stream
// and returns a new [B, C_out, D_out, H_out, W_out] tensor.
//
// Expected inputs, matching how the kernel indexes them:
//   x           float32 CUDA tensor [B, C_in, D_in, H_in, W_in]
//   weight      float32 CUDA tensor [C_in, C_out, k, k, k], with k == kernel_size
//   conv_bias   float32 CUDA tensor with at least C_out values
//   extra_bias  float32 CUDA tensor with at least C_out values
//   scale1_tensor / scale2_tensor: single-element tensors read on the host
//   D_out/H_out/W_out: the pooled output sizes, supplied by the caller.
// Non-contiguous inputs are made contiguous first. Inconsistent shapes or a
// non-positive stride/kernel_size raise a c10::Error via TORCH_CHECK.
//
// When the weight has at most MAX_WEIGHT elements, it is staged in the
// module-global const_weight array. That copy is enqueued on the current
// stream, so it is ordered after earlier work on that stream (including
// previous launches reading const_weight). Concurrent calls issued on
// *different* streams would still race on the single const_weight buffer.
torch::Tensor forward(torch::Tensor x,
                      torch::Tensor weight,
                      torch::Tensor conv_bias,
                      torch::Tensor scale1_tensor,
                      torch::Tensor scale2_tensor,
                      torch::Tensor extra_bias,
                      int stride,
                      int padding,
                      int kernel_size,
                      int D_in, int H_in, int W_in,
                      int D_out, int H_out, int W_out) {
    TORCH_CHECK(x.is_cuda() && weight.is_cuda() && conv_bias.is_cuda() && extra_bias.is_cuda(),
                "forward: x, weight, conv_bias and extra_bias must be CUDA tensors");
    TORCH_CHECK(stride > 0, "forward: stride must be positive, got ", stride);
    TORCH_CHECK(kernel_size > 0, "forward: kernel_size must be positive, got ", kernel_size);
    TORCH_CHECK(x.dim() == 5, "forward: x must be 5-D [B, C_in, D, H, W], got ", x.sizes());
    TORCH_CHECK(weight.dim() == 5,
                "forward: weight must be 5-D [C_in, C_out, k, k, k], got ", weight.sizes());

    c10::cuda::CUDAGuard device_guard(x.device());

    // .item() copies each scalar to the host (a device sync if it lives on the GPU).
    const float scale1 = scale1_tensor.item<float>();
    const float scale2 = scale2_tensor.item<float>();

    // The kernel indexes raw pointers assuming dense row-major storage.
    const torch::Tensor x_c = x.contiguous();
    const torch::Tensor weight_c = weight.contiguous();
    const torch::Tensor conv_bias_c = conv_bias.contiguous();
    const torch::Tensor extra_bias_c = extra_bias.contiguous();

    const int B = static_cast<int>(x_c.size(0));
    const int C_in = static_cast<int>(weight_c.size(0));
    const int C_out = static_cast<int>(weight_c.size(1));

    TORCH_CHECK(x_c.size(1) == C_in && x_c.size(2) == D_in && x_c.size(3) == H_in &&
                    x_c.size(4) == W_in,
                "forward: x shape ", x_c.sizes(), " does not match C_in=", C_in,
                ", D_in=", D_in, ", H_in=", H_in, ", W_in=", W_in);
    TORCH_CHECK(weight_c.size(2) == kernel_size && weight_c.size(3) == kernel_size &&
                    weight_c.size(4) == kernel_size,
                "forward: weight shape ", weight_c.sizes(),
                " does not match kernel_size=", kernel_size);
    TORCH_CHECK(conv_bias_c.numel() >= C_out && extra_bias_c.numel() >= C_out,
                "forward: conv_bias and extra_bias need at least C_out=", C_out, " values");

    auto output = torch::empty({B, C_out, D_out, H_out, W_out}, x.options());
    const int64_t total = output.numel();
    if (total == 0) {
        // A zero-block grid is an invalid launch configuration; nothing to do.
        return output;
    }

    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    const float* weight_ptr = weight_c.data_ptr<float>();
    const int64_t weight_count = weight_c.numel();
    if (weight_count <= MAX_WEIGHT) {
        // The source is device memory, so the copy kind must be device-to-device.
        // Enqueue it on the kernel's stream so it is ordered with that stream's work.
        const cudaError_t copy_err = cudaMemcpyToSymbolAsync(
            const_weight, weight_ptr, static_cast<size_t>(weight_count) * sizeof(float),
            0, cudaMemcpyDeviceToDevice, stream);
        TORCH_CHECK(copy_err == cudaSuccess,
                    "forward: copying weights to constant memory failed: ",
                    cudaGetErrorString(copy_err));
        weight_ptr = nullptr;  // tells the kernel to read const_weight
    }

    const int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    constexpr int64_t kMaxGridX = 2147483647;  // gridDim.x upper limit (2^31 - 1)
    TORCH_CHECK(blocks <= kMaxGridX, "forward: output too large for a 1-D grid (", total,
                " elements)");

    fused_direct_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
        x_c.data_ptr<float>(),
        weight_ptr,
        conv_bias_c.data_ptr<float>(),
        scale1,
        scale2,
        extra_bias_c.data_ptr<float>(),
        stride,
        padding,
        kernel_size,
        D_in, H_in, W_in,
        B, C_in, C_out,
        D_out, H_out, W_out,
        output.data_ptr<float>()
    );
    // Launches do not return errors directly; surface bad configurations here.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess, "forward: fused_direct_kernel launch failed: ",
                cudaGetErrorString(launch_err));

    return output;
}

// Python bindings: exposes the host-side forward() entry point as
// `<extension>.forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def("forward", forward, "Fused 3D Transposed Convolution + Scaling + AvgPool + BiasAdd + Scaling (direct, constant memory)");
}