3.25. ShuffleNet Unit

Computes the ShuffleNet unit transformation on an input tensor: a grouped 1x1 convolution first reduces dimensionality, a depthwise 3x3 convolution performs spatial filtering, a channel shuffle then mixes channels across groups, and a second grouped 1x1 convolution produces the output channels. Batch normalization and ReLU activations are applied along the way, and a shortcut connection preserves the input identity when applicable.
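
As a minimal illustration of the shuffle itself, the sketch below (a standalone host-side C++ program; the helper name shuffle_permutation is hypothetical and not part of the extension) prints which source channel each output channel reads from:

#include <cstdio>
#include <vector>

// Channel shuffle viewed as a permutation: treat the C channels as a
// (groups, C/groups) matrix, transpose it, and flatten. Output channel c
// therefore reads from source channel (c % groups) * (C / groups) + (c / groups),
// which is the same mapping the CUDA kernels below compute per element.
std::vector<int> shuffle_permutation(int C, int groups) {
    int channels_per_group = C / groups;
    std::vector<int> src(C);
    for (int c = 0; c < C; ++c) {
        int i = c / groups;   // position within the transposed row
        int g = c % groups;   // group the source channel belongs to
        src[c] = g * channels_per_group + i;
    }
    return src;
}

int main() {
    // For C = 6 channels and groups = 3 this prints: 0 2 4 1 3 5
    for (int c : shuffle_permutation(6, 3)) std::printf("%d ", c);
    std::printf("\n");
    return 0;
}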

Progress: Step 5 of 5
Reward: 1.01x

Step 5 Details

Reasoning

Our current implementation already performs essentially on par with the reference (reward ≈1.006), which indicates that the custom channel shuffle kernel is highly optimized: it uses vectorized float4 loads and an inlined index-decomposition helper, and since every element is read and written exactly once, the operation is memory-bound and the arithmetic overhead is already minimal. One remaining micro-optimization is to force inlining of the index-decomposition helper by marking it __forceinline__, giving the compiler a better chance to optimize the surrounding loop. Improvements beyond that are likely to show diminishing returns, so we simply add __forceinline__ to decompose_index_vec and leave the rest unchanged. This does not affect correctness and may yield a slightly better speedup.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

// Scalar kernel for channel shuffle (if width is not divisible by 4)
__global__ void channel_shuffle_kernel(const float* __restrict__ input,
                                       float* __restrict__ output,
                                       int B, int C, int H, int W,
                                       int groups) {
    int total = B * C * H * W;
    int channels_per_group = C / groups;
    int HW = H * W;
    int CHW = C * HW;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    
    for (int index = idx; index < total; index += stride) {
        int b = index / CHW;
        int rem = index % CHW;
        int c = rem / HW;
        int pos = rem % HW;
        
        int i = c / groups;
        int g = c % groups;
        int src_c = g * channels_per_group + i;
        
        int src_index = ((b * C + src_c) * HW) + pos;
        output[index] = input[src_index];
    }
}

// Inline device function to decompose a linear index for the vectorized kernel.
// Assumes input tensor shape [B, C, H, W] vectorized along W (W_div4 = W/4)
__forceinline__ __device__ void decompose_index_vec(int index, int C, int H, int WH, int &b, int &c, int &h, int &w_vec) {
    int tot_per_batch = C * H * WH; // WH here is W/4.
    b = index / tot_per_batch;
    int rem = index % tot_per_batch;
    int HWH = H * WH;
    c = rem / HWH;
    int rem2 = rem % HWH;
    h = rem2 / WH;
    w_vec = rem2 % WH;
}

// Vectorized kernel for channel shuffle, using float4 loads and stores.
// 'WH' is W/4.
__global__ void channel_shuffle_vec_kernel(const float4* __restrict__ input,
                                           float4* __restrict__ output,
                                           int B, int C, int H, int WH,
                                           int groups) {
    int total = B * C * H * WH;
    int channels_per_group = C / groups;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    
    for (int index = idx; index < total; index += stride) {
        int b, c, h, w_vec;
        decompose_index_vec(index, C, H, WH, b, c, h, w_vec);
        
        int i = c / groups;
        int g = c % groups;
        int src_c = g * channels_per_group + i;
        
        int src_index = ((b * C + src_c) * H + h) * WH + w_vec;
        output[index] = input[src_index];
    }
}

torch::Tensor channel_shuffle(torch::Tensor input, int groups) {
    TORCH_CHECK(input.is_cuda(), "channel_shuffle: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "channel_shuffle: input must be float32");
    TORCH_CHECK(input.dim() == 4 && input.is_contiguous(), "channel_shuffle: input must be a contiguous 4D tensor");
    TORCH_CHECK(input.size(1) % groups == 0, "channel_shuffle: channel count must be divisible by groups");
    c10::cuda::CUDAGuard device_guard(input.device());
    int B = input.size(0);
    int C = input.size(1);
    int H = input.size(2);
    int W = input.size(3);
    
    auto output = torch::empty_like(input);
    
    if (W % 4 == 0) {
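        // Vectorized path: with W divisible by 4, every float4 access falls on a
        // 16-byte boundary (assuming the usual aligned contiguous allocation), so the
        // rows can be reinterpreted as float4 elements along W.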
        int WH = W / 4;
        int total = B * C * H * WH;
        int threads = 256;
        int blocks = (total + threads - 1) / threads;
        const float4* input_vec = reinterpret_cast<const float4*>(input.data_ptr<float>());
        float4* output_vec = reinterpret_cast<float4*>(output.data_ptr<float>());
        channel_shuffle_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input_vec, output_vec, B, C, H, WH, groups);
    } else {
        int total = B * C * H * W;
        int threads = 256;
        int blocks = (total + threads - 1) / threads;
        channel_shuffle_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(), output.data_ptr<float>(), B, C, H, W, groups);
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("channel_shuffle", &channel_shuffle, "Custom optimized CUDA kernel for channel shuffle (with vectorized support and forced inlining)");
}
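
As a sanity check, the custom kernel can be compared against the standard reshape/transpose formulation of channel shuffle written with the ATen C++ API. The sketch below is a hypothetical test helper, not part of the extension: reference_shuffle and check_channel_shuffle are illustrative names, and the code assumes it is compiled in the same translation unit as channel_shuffle above, with the same contiguous float32 CUDA input.

// Reference channel shuffle via reshape/transpose:
// [B, C, H, W] -> [B, g, C/g, H, W] -> transpose(1, 2) -> [B, C, H, W].
torch::Tensor reference_shuffle(torch::Tensor x, int64_t groups) {
    auto B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3);
    return x.view({B, groups, C / groups, H, W})
            .transpose(1, 2)
            .contiguous()
            .view({B, C, H, W});
}

// Returns true when the custom kernel reproduces the reference result.
bool check_channel_shuffle(torch::Tensor x, int groups) {
    return torch::allclose(channel_shuffle(x, groups), reference_shuffle(x, groups));
}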