3.29. Hierarchical Image Processing with Shifted Window MLP

Calculates an image representation by first embedding the image into patches and then processing these patches through a series of local, shifted window operations. The procedure partitions feature maps into windows, applies grouped convolution as a spatial MLP to each window, performs patch merging for downsampling, and finally aggregates features for classification.
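
As a rough illustration, the following PyTorch sketch shows one such window-local spatial MLP block, patterned after a Swin-MLP-style design: the feature map is partitioned into non-overlapping windows and a grouped 1x1 convolution mixes the spatial positions inside each window. The module and parameter names here (WindowSpatialMLP, num_heads) are illustrative, and the cyclic shift used for the "shifted" variant is omitted for brevity.

import torch
import torch.nn as nn

class WindowSpatialMLP(nn.Module):
    """Illustrative window-local spatial MLP (names are hypothetical).

    Partitions a (B, H, W, C) feature map into window_size x window_size
    windows and mixes the window_size**2 spatial positions of each head with
    a grouped 1x1 convolution -- the kind of operation the CUDA kernel in
    this step accelerates. The cyclic shift for "shifted" windows is omitted.
    """

    def __init__(self, dim, num_heads, window_size=7):
        super().__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        # One group per head: each group mixes its own window_size**2 positions.
        self.spatial_mlp = nn.Conv1d(num_heads * window_size ** 2,
                                     num_heads * window_size ** 2,
                                     kernel_size=1, groups=num_heads)

    def forward(self, x):
        B, H, W, C = x.shape                     # H, W assumed divisible by window_size
        ws, nh = self.window_size, self.num_heads
        # Partition into non-overlapping windows: (B * num_windows, ws*ws, C).
        x = x.reshape(B, H // ws, ws, W // ws, ws, C)
        win = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, C)
        # Regroup channels per head so the grouped conv mixes spatial positions.
        nw = win.shape[0]
        w = win.reshape(nw, ws * ws, nh, C // nh).permute(0, 2, 1, 3)
        w = w.reshape(nw, nh * ws * ws, C // nh)
        w = self.spatial_mlp(w)                  # grouped 1x1 conv over L = C // num_heads
        w = w.reshape(nw, nh, ws * ws, C // nh).permute(0, 2, 1, 3).reshape(nw, ws * ws, C)
        # Reverse the window partition back to (B, H, W, C).
        x = w.reshape(B, H // ws, W // ws, ws, ws, C)
        return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)

In this layout the grouped convolution's input has shape (num_windows * B, num_heads * window_size**2, C / num_heads), which is the (N, C, L) layout assumed by the CUDA kernel below, with group_channels = window_size**2.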

Progress: Step 2 of 2
Reward: 1.25x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize the spatial MLP CUDA kernel by reducing redundant global memory accesses. The original kernel loops over "group_channels" to accumulate the weighted sum; reading the inputs and weights through the __ldg intrinsic routes these read-only values through the read-only data cache and lowers memory latency. We also add a branch that fully unrolls the loop when group_channels equals 49 (the typical value when window_size is 7), since a fixed trip count lets the compiler optimize the loop further. This should yield an additional speedup. The Python module is unchanged apart from being compiled with the same flags. The updated CUDA kernel below reflects these improvements.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Optimized kernel for spatial MLP (group convolution with kernel size 1)
// Input tensor shape: (N, C, L), where C = num_groups * group_channels (group_channels = window_size^2)
// Each output element is computed as:
//   output[n, c, l] = sum_{ic=0}^{group_channels-1} input[n, group_id*group_channels + ic, l] * weight[c, ic] + bias[c]
// where group_id = c / group_channels.
__global__ void mlp_forward_kernel(const float* __restrict__ input,
                                   const float* __restrict__ weight,
                                   const float* __restrict__ bias,
                                   float* __restrict__ output,
                                   int N, int C, int L, int group_channels) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * C * L;
    if (idx >= total) return;
    
    int l = idx % L;
    int tmp = idx / L;
    int c = tmp % C;
    int n = tmp / C;
    
    int group_id = c / group_channels;
    float sum_val = 0.f;
    
    // Optimize for common case where group_channels == 49 (for window_size == 7)
    if (group_channels == 49) {
        #pragma unroll
        for (int ic = 0; ic < 49; ic++) {
            int input_idx = n * (C * L) + (group_id * group_channels + ic) * L + l;
            int weight_idx = c * group_channels + ic;
            sum_val += __ldg(&input[input_idx]) * __ldg(&weight[weight_idx]);
        }
    } else {
        for (int ic = 0; ic < group_channels; ic++) {
            int input_idx = n * (C * L) + (group_id * group_channels + ic) * L + l;
            int weight_idx = c * group_channels + ic;
            sum_val += __ldg(&input[input_idx]) * __ldg(&weight[weight_idx]);
        }
    }
    
    sum_val += __ldg(&bias[c]);
    output[idx] = sum_val;
}

// C++ interface for the mlp_forward operator.
// Expects:
//  - input: tensor of shape (N, C, L)
//  - weight: tensor of shape (C, group_channels)
//  - bias: tensor of shape (C)
//  - group_channels: int (window_size^2)
torch::Tensor mlp_forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias, int group_channels) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 3, "input must have shape (N, C, L)");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");

    input = input.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();
    
    int N = input.size(0);
    int C = input.size(1);
    int L = input.size(2);
    int total_elements = N * C * L;
    
    auto output = torch::empty_like(input);
    
    c10::cuda::CUDAGuard device_guard(input.device());
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    
    int threads = 256;
    int blocks = (total_elements + threads - 1) / threads;
    
    mlp_forward_kernel<<<blocks, threads, 0, stream>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        N, C, L, group_channels
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("mlp_forward", &mlp_forward, "Optimized Spatial MLP Forward (CUDA)");
}
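
For completeness, here is a hypothetical Python-side sketch of building the extension and checking the kernel against the grouped 1x1 convolution it replaces. The extension name mlp_cuda and the source filename are assumptions; the build flags should match those actually used for the project.

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Hypothetical build; adjust the source path and extra flags to the real setup.
mlp_cuda = load(name="mlp_cuda", sources=["mlp_forward_kernel.cu"])

N, num_heads, window_size, L = 8, 4, 7, 32
group_channels = window_size ** 2              # 49 for window_size == 7
C = num_heads * group_channels

x = torch.randn(N, C, L, device="cuda")
weight = torch.randn(C, group_channels, device="cuda")
bias = torch.randn(C, device="cuda")

out = mlp_cuda.mlp_forward(x, weight, bias, group_channels)

# Reference: a grouped conv1d with kernel size 1 computes the same weighted sum.
ref = F.conv1d(x, weight.unsqueeze(-1), bias, groups=num_heads)
print(torch.allclose(out, ref, atol=1e-4))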