3.27. Multi-Stage Image Feature Extraction and Classification

Computes classification outputs by extracting and aggregating image features over multiple stages. Each stage applies a convolution, normalization, a non-linear activation, and pooling; the resulting features are then globally averaged and passed through a fully connected layer to produce prediction scores.

Progress

Step 5 of 5
Reward
1.06x

Optimization Timeline

Step 5 Details

Reasoning

Our fused GAP+FC kernel was already improved by vectorizing the inner loop and using a warp-level reduction. However, its speedup over the native PyTorch implementation is only marginal (~2.5% faster; reward ~1.025). We can optimize the inner spatial-summation loop further by unrolling its vectorized form. When the number of spatial elements (H*W) is divisible by 4 — and the input buffer is 16-byte aligned, which float4 loads require — the data can be read as float4 values; we then unroll that loop by a factor of 2 to reduce loop overhead and improve instruction-level parallelism (ILP). Issuing two independent loads per iteration may help overlap memory latency with arithmetic, although both loaded vectors are still accumulated into a single variable, so the additions themselves remain a serial dependency chain. The Python module remains essentially the same; only the underlying CUDA kernel (version v3) is modified to unroll the spatial summation loop. We continue to use the c10::cuda utilities and build the extension with the extra CUDA flags -O3 and --use_fast_math. The final updated code follows.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#ifndef WARPSIZE
#define WARPSIZE 32
#endif

// Fused GAP+FC kernel (version 3, unrolled) for a dense row-major input of shape [B, C, H, W].
// fc_weight has shape [out_classes, C] stored in row-major order (index = j * C + c).
// Computes for each sample b and output class j:
//   y[b,j] = (1/(H*W)) * sum_{c=0}^{C-1} ( sum_{h,w} x[b,c,h,w] ) * fc_weight[j*C + c] + fc_bias[j]
//
// Launch layout expected:
//   grid  = (B, out_classes)  -- one block per (b, j) output element
//   block = 1-D; blockDim.x must be a multiple of WARPSIZE and at most
//           WARPSIZE * WARPSIZE (every shuffle below uses a full-warp mask)
//   dynamic shared memory = ceil(blockDim.x / WARPSIZE) * sizeof(float)
//
// The float4 fast path is taken only when H*W is a multiple of 4 AND `input` is
// 16-byte aligned; otherwise a scalar loop is used.
// NOTE(perf): each block re-reads its sample's whole [C, H, W] slice, so the input is
// read out_classes times; and adjacent threads walk different channels, so the
// loads of a warp are not coalesced.
__global__ void fused_gap_fc_v3_kernel(const float* __restrict__ input,
                                         const float* __restrict__ fc_weight,
                                         const float* __restrict__ fc_bias,
                                         float* __restrict__ output,
                                         int B, int C, int H, int W, int out_classes) {
    const int b = blockIdx.x;  // batch index
    const int j = blockIdx.y;  // output class index
    const int spatial = H * W;

    // 64-bit offsets: B * C * H * W can exceed INT_MAX for large activations.
    const long long sample_offset = static_cast<long long>(b) * C * spatial;
    const float* w_row = fc_weight + static_cast<long long>(j) * C;

    // float4 loads require 16-byte alignment. Every channel row starts a multiple of
    // `spatial` floats past `input`, so when spatial % 4 == 0 it is enough that
    // `input` itself is 16-byte aligned (a view with a storage offset may not be).
    const bool use_vec = ((spatial & 3) == 0) &&
                         ((reinterpret_cast<unsigned long long>(input) & 15ull) == 0ull);

    // Each thread processes a strided subset of channels.
    float partial = 0.0f;
    for (int c = threadIdx.x; c < C; c += blockDim.x) {
        const float* row = input + sample_offset + static_cast<long long>(c) * spatial;
        float gap = 0.0f;
        if (use_vec) {
            const int vecCount = spatial >> 2;  // spatial / 4
            const float4* ptr = reinterpret_cast<const float4*>(row);
            int i = 0;
            // Unroll two float4 loads per iteration.
            for (; i + 1 < vecCount; i += 2) {
                float4 temp1 = ptr[i];
                float4 temp2 = ptr[i + 1];
                gap += temp1.x + temp1.y + temp1.z + temp1.w;
                gap += temp2.x + temp2.y + temp2.z + temp2.w;
            }
            if (i < vecCount) {  // odd tail
                float4 temp = ptr[i];
                gap += temp.x + temp.y + temp.z + temp.w;
            }
        } else {
            for (int idx = 0; idx < spatial; ++idx) {
                gap += row[idx];
            }
        }
        partial += gap * w_row[c];
    }

    // Stage 1: warp-level reduction (all lanes of every warp participate).
    const unsigned int full_mask = 0xffffffffu;
    for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1) {
        partial += __shfl_down_sync(full_mask, partial, offset);
    }

    // Publish one partial sum per warp.
    extern __shared__ float sdata[];
    const int lane = threadIdx.x & (WARPSIZE - 1);
    const int warpId = threadIdx.x / WARPSIZE;
    if (lane == 0) {
        sdata[warpId] = partial;
    }
    __syncthreads();

    // Stage 2: reduce the per-warp sums in warp 0. The entire warp executes the
    // shuffle because a full mask requires every named lane to call it; lanes past
    // numWarps contribute 0 instead of reading uninitialized shared memory.
    if (warpId == 0) {
        const int numWarps = (blockDim.x + WARPSIZE - 1) / WARPSIZE;
        float total = (lane < numWarps) ? sdata[lane] : 0.0f;
        for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1) {
            total += __shfl_down_sync(full_mask, total, offset);
        }
        if (lane == 0) {
            output[static_cast<long long>(b) * out_classes + j] = total / float(spatial) + fc_bias[j];
        }
    }
}

// Host entry point: fused global-average-pool + fully-connected layer.
//
// Args:
//   input     : CUDA float32 tensor of shape [B, C, H, W]
//   fc_weight : CUDA float32 tensor of shape [out_classes, C]
//   fc_bias   : CUDA float32 tensor of shape [out_classes]
// Returns:
//   float32 tensor of shape [B, out_classes] on input's device.
// Raises:
//   c10::Error (via TORCH_CHECK) on wrong device/dtype/shape, too many output
//   classes for the grid, or a failed kernel launch.
// The kernel is enqueued on the current CUDA stream; the call does not synchronize.
torch::Tensor fused_gap_fc(torch::Tensor input, torch::Tensor fc_weight, torch::Tensor fc_bias) {
    TORCH_CHECK(input.is_cuda() && fc_weight.is_cuda() && fc_bias.is_cuda(),
                "fused_gap_fc: all tensors must be CUDA tensors");
    TORCH_CHECK(fc_weight.device() == input.device() && fc_bias.device() == input.device(),
                "fused_gap_fc: all tensors must be on the same device");
    TORCH_CHECK(input.scalar_type() == torch::kFloat &&
                fc_weight.scalar_type() == torch::kFloat &&
                fc_bias.scalar_type() == torch::kFloat,
                "fused_gap_fc: only float32 tensors are supported");
    TORCH_CHECK(input.dim() == 4, "fused_gap_fc: input must be 4-D [B, C, H, W]");
    TORCH_CHECK(fc_weight.dim() == 2 && fc_bias.dim() == 1,
                "fused_gap_fc: fc_weight must be 2-D and fc_bias 1-D");
    TORCH_CHECK(fc_weight.size(1) == input.size(1),
                "fused_gap_fc: fc_weight.size(1) must equal the channel count of input");
    TORCH_CHECK(fc_weight.size(0) == fc_bias.size(0),
                "fused_gap_fc: fc_weight.size(0) must equal fc_bias.size(0)");

    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes raw pointers assuming dense row-major layout, so
    // non-contiguous inputs (e.g. channels_last) must be materialized first.
    auto x = input.contiguous();
    auto w = fc_weight.contiguous();
    auto bias = fc_bias.contiguous();

    const int B = static_cast<int>(x.size(0));
    const int C = static_cast<int>(x.size(1));
    const int H = static_cast<int>(x.size(2));
    const int W = static_cast<int>(x.size(3));
    const int out_classes = static_cast<int>(bias.size(0));

    auto output = torch::empty({B, out_classes}, x.options());

    // A zero-sized grid is an invalid launch configuration; nothing to compute.
    if (B == 0 || out_classes == 0) {
        return output;
    }
    TORCH_CHECK(out_classes <= 65535,
                "fused_gap_fc: out_classes exceeds the gridDim.y limit (65535)");

    // Launch one block per (b, j) pair.
    dim3 grid(B, out_classes);
    const int threads = 256;  // must be a multiple of WARPSIZE (kernel precondition)
    const size_t shared_mem_size = ((threads + WARPSIZE - 1) / WARPSIZE) * sizeof(float);

    fused_gap_fc_v3_kernel<<<grid, threads, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        w.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C, H, W, out_classes
    );
    // Surface launch-configuration errors here instead of at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: exposes fused_gap_fc to Python as `forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def("forward",
               &fused_gap_fc,
               "Fused Global Average Pooling and Fully Connected operator (v3 optimized, unrolled)");
}