2.44. Transposed Convolution with Scaling and Global Pooling

Performs a transposed convolution on input data using defined stride, padding, and output padding. The result is then scaled by a constant multiplier and subsequently processed by two sequential mean operations over the spatial dimensions, effectively applying global average pooling twice.

Progress

Step 3 of 3
Reward
10.97x

Optimization Timeline

Step 3 Details

Reasoning

We can further optimize our fused kernel by vectorizing the inner loop that sums over the spatial dimensions of each (b, ic) slice. Instead of iterating one element at a time, we check if the size (H*W) is divisible by 4 and use float4 loads to reduce the number of iterations by a factor of 4. This should improve memory throughput and reduce the per-thread computation load. The rest of the algorithm remains unchanged. The Python module remains the same, as we only changed the CUDA kernel.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

// This kernel fuses ConvTranspose2d, scalar multiplication, and double global average pooling.
// For each output element corresponding to batch index b and output channel oc:
//   For each input channel ic, compute S(b,ic) = sum_{i,j} x[b,ic,i,j] reading the contiguous H*W elements.
//   Then compute acc = Σ_{ic} (S(b,ic) * weight_sum[ic,oc]).
//   Finally, compute avg = conv_bias[oc] + acc/(H_out*W_out) and output[b,oc] = multiplier * avg.
__global__ void fused_forward_kernel(const float* __restrict__ x,
                                       int B,
                                       int in_channels,
                                       int H,
                                       int W,
                                       int out_channels,
                                       int H_out,
                                       int W_out,
                                       float multiplier,
                                       const float* __restrict__ weight_sum, // shape: [in_channels, out_channels]
                                       const float* __restrict__ conv_bias,    // shape: [out_channels]
                                       float* __restrict__ output) {           // shape: [B * out_channels]
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = B * out_channels;
    if (idx >= total) {
        return;
    }

    int b = idx / out_channels;
    int oc = idx % out_channels;
    float acc = 0.0f;
    int size = H * W;
    // Check if we can vectorize the accumulation.
    bool useVectorized = (size % 4 == 0);
    
    // Loop over input channels.
    for (int ic = 0; ic < in_channels; ic++) {
        float sum_ic = 0.0f;
        int base = (b * in_channels + ic) * size;
        if (useVectorized) {
            int vsize = size / 4;
            const float4* x4 = reinterpret_cast<const float4*>(x);
            int base4 = (b * in_channels + ic) * vsize;
            for (int k = 0; k < vsize; k++) {
                float4 val = x4[base4 + k];
                sum_ic += val.x + val.y + val.z + val.w;
            }
        } else {
            for (int k = 0; k < size; k++) {
                sum_ic += x[base + k];
            }
        }
        // Multiply by the precomputed kernel sum for the (ic, oc) pair.
        acc += sum_ic * weight_sum[ic * out_channels + oc];
    }
    float avg = conv_bias[oc] + acc / (H_out * W_out);
    output[idx] = multiplier * avg;
}

torch::Tensor forward(torch::Tensor x,
                      int kernel_size,
                      int stride,
                      int padding,
                      int output_padding,
                      float multiplier,
                      torch::Tensor weight_sum,
                      torch::Tensor conv_bias) {
    // Enforce correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());
    
    // x shape: [B, in_channels, H, W]
    int B = x.size(0);
    int in_channels = x.size(1);
    int H = x.size(2);
    int W = x.size(3);
    
    // Compute output dimensions for ConvTranspose2d.
    int H_out = (H - 1) * stride - 2 * padding + kernel_size + output_padding;
    int W_out = (W - 1) * stride - 2 * padding + kernel_size + output_padding;
    
    // Determine out_channels from weight_sum tensor (shape: [in_channels, out_channels]).
    int out_channels = weight_sum.size(1);
    
    // Allocate output tensor with shape [B, out_channels, 1, 1].
    auto options = x.options();
    torch::Tensor output = torch::empty({B, out_channels, 1, 1}, options);
    
    int total_threads = B * out_channels;
    int threads = 256;
    int blocks = (total_threads + threads - 1) / threads;
    
    fused_forward_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        B,
        in_channels,
        H,
        W,
        out_channels,
        H_out,
        W_out,
        multiplier,
        weight_sum.data_ptr<float>(),
        conv_bias.data_ptr<float>(),
        output.data_ptr<float>()
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused Transposed Conv2d, Global Average Pooling, and Scaling with vectorized reduction");
}