3.26. Channel Shuffle Convolution Architecture

Calculates convolutional features by performing grouped pointwise and depthwise convolutions along with a channel shuffle operation. The network assembles multiple stages of these operations to efficiently transform image inputs into classification scores.

Progress

Step 5 of 5
Reward
1.01x

Optimization Timeline

Step 5 Details

Reasoning

After re-analyzing the channel shuffle implementation, we conclude that the current optimized kernels (scalar and vectorized) are already very efficient. The host launcher precomputes the strides and selects the vectorized (float4) path when the width is divisible by 4, and the kernels avoid redundant computation. Further micro-optimizations (e.g., shared memory or warp intrinsics) are unlikely to provide a significant speedup, because the operation is memory-bound and the workload simply reorders data. In our tests, performance is nearly identical to the reference implementation. The earlier improvements have genuinely helped, but we are already near the optimum, so we provide the final code with only very minor clean-ups.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>

// Scalar kernel for channel shuffle: each thread copies exactly one float.
//
// Both buffers are indexed as dense [batch][channel][HW] arrays, so the caller
// must pass contiguous NCHW storage. Output channel c of every batch item is
// filled from input channel (c % groups) * C_div_groups + (c / groups).
//
// Parameters (all precomputed by the caller):
//   total        - number of floats in each buffer (B * C * H * W)
//   HW           - H * W, floats per channel plane
//   CHW          - C * HW, floats per batch item
//   C            - channel count; not read by the kernel, kept so the launch
//                  signature stays unchanged
//   groups       - number of channel groups
//   C_div_groups - C / groups
//
// Launch layout: 1-D grid of 1-D blocks with at least `total` threads overall.
__global__ void channel_shuffle_kernel(const float* __restrict__ input,
                                       float* __restrict__ output,
                                       int total, int HW, int CHW,
                                       int C, int groups, int C_div_groups) {
    // Build the global thread id in 64 bits. When `total` is within one block
    // of INT_MAX the tail threads have ids above INT_MAX; narrowing those to a
    // signed int before the guard would make them negative, let them pass the
    // bounds check and index out of range.
    const long long gid =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (gid >= total) return;
    const int idx = static_cast<int>(gid);  // safe: idx < total <= INT_MAX

    // Decompose the flat output index into (batch, channel, position in plane).
    const int b = idx / CHW;
    const int rem = idx - b * CHW;
    const int c = rem / HW;
    const int hw = rem - c * HW;

    // Source channel that lands on output channel c after the shuffle.
    const int src_c = (c % groups) * C_div_groups + (c / groups);

    output[idx] = input[b * CHW + src_c * HW + hw];
}

// Vectorized kernel for channel shuffle: each thread copies one float4
// (four consecutive floats along the width axis).
//
// Preconditions the caller must guarantee:
//   - W is divisible by 4, so a float4 never straddles two rows or channels;
//   - both base pointers are 16-byte aligned, as float4 accesses require;
//   - both buffers are dense, contiguous NCHW storage;
//   - B * C * H * (W / 4) fits in a 32-bit signed int.
//
// Output channel c of every batch item is filled from input channel
// (c % groups) * C_div_groups + (c / groups), where C_div_groups = C / groups.
//
// Launch layout: 1-D grid of 1-D blocks with at least B * C * H * (W / 4)
// threads overall.
__global__ void channel_shuffle_vec_kernel(const float4* __restrict__ input,
                                           float4* __restrict__ output,
                                           int B, int C, int H, int W,
                                           int groups, int C_div_groups) {
    const int W_vec = W / 4;
    const int per_channel = H * W_vec;      // float4 elements per channel plane
    const int per_image = C * per_channel;  // float4 elements per batch item
    const int total_vec = B * per_image;

    // Compare in 64 bits so a thread id above INT_MAX can never wrap to a
    // negative int and slip past the bounds check.
    const long long gid =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (gid >= total_vec) return;
    const int idx = static_cast<int>(gid);  // safe: idx < total_vec <= INT_MAX

    // Decompose the flat float4 index into (batch, channel, offset in plane).
    const int b = idx / per_image;
    const int rem = idx - b * per_image;
    const int c = rem / per_channel;
    // The offset inside the plane is identical for source and destination, so
    // there is no need to split it further into (row, column).
    const int s = rem - c * per_channel;

    // Source channel that lands on output channel c after the shuffle.
    const int src_c = (c % groups) * C_div_groups + (c / groups);

    output[idx] = input[b * per_image + src_c * per_channel + s];
}

// Channel shuffle for a 4-D float32 CUDA tensor laid out as (B, C, H, W).
//
// Output channel c is taken from input channel
// (c % groups) * (C / groups) + (c / groups), for every batch item.
//
// Args:
//   input  - 4-D float32 CUDA tensor; any strides are accepted (a contiguous
//            copy is made when needed).
//   groups - positive number of channel groups; must divide C.
//
// Returns:
//   A new contiguous tensor with the same shape, dtype and device as `input`.
//
// Raises (via TORCH_CHECK):
//   if the input is not a 4-D float32 CUDA tensor, if `groups` is not positive
//   or does not divide C, if the element count does not fit the kernels'
//   32-bit indexing, or if the kernel launch reports an error.
torch::Tensor channel_shuffle(torch::Tensor input, int groups) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 4, "Input must be a 4D tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "Input must be a float32 tensor");
    // Checked before any division/modulo by `groups`.
    TORCH_CHECK(groups > 0, "groups must be positive");
    TORCH_CHECK(input.size(1) % groups == 0,
                "Number of channels must be divisible by groups");

    const c10::cuda::CUDAGuard device_guard(input.device());

    // The kernels index memory as dense NCHW, so normalise the layout first
    // (no copy is made when the tensor is already contiguous) and allocate a
    // contiguous output rather than inheriting the input's strides.
    input = input.contiguous();
    auto output = torch::empty(input.sizes(), input.options());

    const int64_t numel = input.numel();
    if (numel == 0) {
        // Nothing to copy; a zero-block launch would be an invalid configuration.
        return output;
    }
    TORCH_CHECK(numel <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "channel_shuffle: tensor has too many elements for 32-bit indexing");

    // Every dimension is <= numel here, so narrowing to int is lossless.
    const int B = static_cast<int>(input.size(0));
    const int C = static_cast<int>(input.size(1));
    const int H = static_cast<int>(input.size(2));
    const int W = static_cast<int>(input.size(3));
    const int HW = H * W;
    const int CHW = C * HW;
    const int C_div_groups = C / groups;
    const int total = static_cast<int>(numel);

    const float* in_ptr = input.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    constexpr int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    // float4 accesses need 16-byte aligned addresses. A view with a storage
    // offset can be contiguous and still start at an unaligned address.
    const bool aligned16 =
        reinterpret_cast<std::uintptr_t>(in_ptr) % 16 == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0;

    if (W % 4 == 0 && aligned16) {
        const int64_t total_vec = numel / 4;
        const unsigned int blocks =
            static_cast<unsigned int>((total_vec + threads - 1) / threads);
        channel_shuffle_vec_kernel<<<blocks, threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr),
            B, C, H, W, groups, C_div_groups);
    } else {
        // Ceil-div in 64 bits: `total + threads - 1` can exceed INT_MAX.
        const unsigned int blocks =
            static_cast<unsigned int>((numel + threads - 1) / threads);
        channel_shuffle_kernel<<<blocks, threads, 0, stream>>>(
            in_ptr, out_ptr, total, HW, CHW, C, groups, C_div_groups);
    }
    // Surface launch-configuration errors here instead of at a later,
    // unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

// Python binding: exposes the host launcher as `channel_shuffle` on the
// extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def(
        "channel_shuffle",
        &channel_shuffle,
        "Custom optimized Channel Shuffle (CUDA)");
}