3.25. ShuffleNet Unit
Computes a feature transformation on an input tensor by first applying a grouped 1x1 convolution to reduce dimensionality, then a channel shuffle operation to mix channels across groups, followed by a depthwise 3x3 convolution for spatial filtering and a second grouped 1x1 convolution to restore dimensionality. Batch normalization, ReLU activations, and a shortcut connection that preserves the input identity (when shapes allow) complete the unit.
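For orientation, the channel shuffle at the heart of this unit is the reshape-transpose-flatten permutation from the ShuffleNet paper. A minimal CPU reference sketch of that permutation (hypothetical helper name channel_shuffle_ref; contiguous NCHW float32 layout assumed) looks like this:

#include <vector>

// Illustrative CPU reference for the channel shuffle permutation (NCHW).
// Output channel c = i * groups + g copies input channel src_c = g * cpg + i.
void channel_shuffle_ref(const std::vector<float>& in, std::vector<float>& out,
                         int B, int C, int H, int W, int groups) {
    const int cpg = C / groups;  // channels per group
    const int HW = H * W;
    for (int b = 0; b < B; ++b) {
        for (int c = 0; c < C; ++c) {
            const int i = c / groups;       // position within the source group
            const int g = c % groups;       // source group
            const int src_c = g * cpg + i;  // inverse of reshape-transpose-flatten
            for (int p = 0; p < HW; ++p)
                out[(b * C + c) * HW + p] = in[(b * C + src_c) * HW + p];
        }
    }
}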
Step 5 Details
Reasoning
Our current implementation already matches the reference almost exactly (reward ≈ 1.006), which indicates that the custom channel shuffle kernel is close to optimal: it uses vectorized float4 loads and an inlined index-decomposition helper, so its arithmetic overhead is minimal and the operation is memory-bound. One remaining micro-optimization is to force inlining of the index-decomposition function with __forceinline__, which can help the compiler schedule the surrounding loop better; anything beyond that is likely to yield diminishing returns. We therefore add __forceinline__ to decompose_index_vec and leave the rest unchanged. This does not affect correctness and may give a slight additional speedup.
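To make the memory-bound claim concrete (with an assumed, illustrative shape): a [16, 240, 28, 28] float32 activation holds 16 · 240 · 28 · 28 ≈ 3.0M elements, about 12 MB; the shuffle reads and writes each element once, roughly 24 MB of traffic, which at an assumed ~900 GB/s of device bandwidth is on the order of 27 µs. Savings from eliminating a few integer instructions per element are therefore bounded to a small fraction of that.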
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // for c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
// Scalar kernel for channel shuffle (used when W is not divisible by 4).
__global__ void channel_shuffle_kernel(const float* __restrict__ input,
                                       float* __restrict__ output,
                                       int B, int C, int H, int W,
                                       int groups) {
    int total = B * C * H * W;
    int channels_per_group = C / groups;
    int HW = H * W;
    int CHW = C * HW;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int index = idx; index < total; index += stride) {
        int b = index / CHW;
        int rem = index % CHW;
        int c = rem / HW;
        int pos = rem % HW;                      // h * W + w; no need to split further
        int i = c / groups;                      // position within the source group
        int g = c % groups;                      // source group
        int src_c = g * channels_per_group + i;  // inverse of reshape-transpose-flatten
        int src_index = ((b * C + src_c) * HW) + pos;
        output[index] = input[src_index];
    }
}
// Inline device helper to decompose a linear index for the vectorized kernel.
// Assumes an input tensor of shape [B, C, H, W] vectorized along W (WH = W / 4).
__forceinline__ __device__ void decompose_index_vec(int index, int C, int H, int WH,
                                                    int &b, int &c, int &h, int &w_vec) {
    int tot_per_batch = C * H * WH;  // float4 elements per batch item
    b = index / tot_per_batch;
    int rem = index % tot_per_batch;
    int HWH = H * WH;
    c = rem / HWH;
    int rem2 = rem % HWH;
    h = rem2 / WH;
    w_vec = rem2 % WH;
}
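// Worked example with illustrative sizes: for C = 240, H = 28, WH = 7 (W = 28),
// tot_per_batch = 240 * 28 * 7 = 47040, so index = 50000 decomposes as
// b = 1, rem = 2960; HWH = 196, so c = 15, rem2 = 20, h = 2, w_vec = 6.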
// Vectorized kernel for channel shuffle, using float4 loads and stores.
// 'WH' is W / 4.
__global__ void channel_shuffle_vec_kernel(const float4* __restrict__ input,
                                           float4* __restrict__ output,
                                           int B, int C, int H, int WH,
                                           int groups) {
    int total = B * C * H * WH;
    int channels_per_group = C / groups;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int index = idx; index < total; index += stride) {
        int b, c, h, w_vec;
        decompose_index_vec(index, C, H, WH, b, c, h, w_vec);
        int i = c / groups;
        int g = c % groups;
        int src_c = g * channels_per_group + i;
        int src_index = ((b * C + src_c) * H + h) * WH + w_vec;
        output[index] = input[src_index];
    }
}
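// Note: vectorizing along W is valid because the shuffle permutes whole channels,
// so each output float4 corresponds to one contiguous float4 at the same (h, w)
// offset in the source channel. The float4 reinterpretation also assumes 16-byte
// alignment of the data pointer, which holds for freshly allocated contiguous
// PyTorch CUDA tensors (an assumption for arbitrary strided views).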
torch::Tensor channel_shuffle(torch::Tensor input, int groups) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(input.size(1) % groups == 0, "C must be divisible by groups");
    c10::cuda::CUDAGuard device_guard(input.device());
    int B = input.size(0);
    int C = input.size(1);
    int H = input.size(2);
    int W = input.size(3);
    auto output = torch::empty_like(input);
    if (W % 4 == 0) {
        // Vectorized path: reinterpret each run of 4 floats along W as one float4.
        int WH = W / 4;
        int total = B * C * H * WH;
        int threads = 256;
        int blocks = (total + threads - 1) / threads;
        const float4* input_vec = reinterpret_cast<const float4*>(input.data_ptr<float>());
        float4* output_vec = reinterpret_cast<float4*>(output.data_ptr<float>());
        channel_shuffle_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input_vec, output_vec, B, C, H, WH, groups);
    } else {
        // Scalar fallback for widths not divisible by 4.
        int total = B * C * H * W;
        int threads = 256;
        int blocks = (total + threads - 1) / threads;
        channel_shuffle_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(), output.data_ptr<float>(), B, C, H, W, groups);
    }
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("channel_shuffle", &channel_shuffle,
          "Custom optimized CUDA kernel for channel shuffle (vectorized float4 path with forced inlining)");
}
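As a usage note, a file like this can be built as a standard PyTorch C++/CUDA extension (e.g., with torch.utils.cpp_extension.load) and called as channel_shuffle(x, groups) on a contiguous float32 NCHW CUDA tensor; the dispatch above picks the float4 path automatically whenever W is divisible by 4.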