3.17. Squeeze and Expand Convolution Module
Performs a computation that first squeezes the input using a 1x1 convolution with ReLU activation, then expands the resulting features by processing them in parallel through a 1x1 convolution and a 3x3 convolution (both with ReLU activations; the 3x3 convolution uses zero padding of 1 so both branches keep the input's spatial size), and finally concatenates these outputs along the channel dimension.
Progress
Optimization Timeline
Step 3 Details
Reasoning
We can further optimize our shared-memory fused fire module kernel by reducing redundant index arithmetic and improving memory locality. In our original shared-memory kernel, each thread computed the expand3x3 branch by iterating first over the 3×3 window and then over the squeeze channels. We improve this by reordering the loops: for each expand3x3 output channel, we first loop over the squeeze channels and then over the 3×3 window. This lets us compute the shared-memory pointer for the current squeeze channel once and reuse it for all nine window positions. That removes the channel-offset multiplication from the innermost loop and leaves a fixed-trip-count 3×3 loop that the compiler can fully unroll. We also add #pragma unroll directives to loops over small dimensions (like in_channels and the 3×3 window). Note that a bare #pragma unroll fully unrolls only loops whose trip count is a compile-time constant, such as the 3×3 window. For loops bounded by a runtime value such as in_channels, it does not unroll them at all. These changes reduce redundant computation and improve register reuse, so the fused kernel uses shared memory more efficiently and should run faster. The Python module remains largely unchanged, as it only loads and calls the CUDA extension.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <climits>
#include <math.h>
// Compile-time upper bounds on channel counts (squeeze, expand1x1, expand3x3).
// Any fixed-size per-thread array sized with one of these macros is only safe when the
// matching runtime channel count does not exceed the macro's value.
// NOTE(review): nothing in this block enforces these limits at runtime.
#define MAX_SQUEEZE 128
#define MAX_EXPAND1X1 256
#define MAX_EXPAND3X3 256
// Thread-block tile dimensions: each block covers a BLOCK_SIZE_X x BLOCK_SIZE_Y tile of
// output pixels, and the host launches blocks of exactly these dimensions.
constexpr int BLOCK_SIZE_X = 16;
constexpr int BLOCK_SIZE_Y = 16;
// Fused fire module kernel.
//
// Computes, for one tile of output pixels of one batch element:
//   s  = ReLU(squeeze 1x1 conv(x))
//   y1 = ReLU(expand 1x1 conv(s))
//   y3 = ReLU(expand 3x3 conv(s), zero padding 1)
// and writes concat(y1, y3) along the channel dimension.
//
// Launch layout expected:
//   blockDim = (BLOCK_SIZE_X, BLOCK_SIZE_Y, 1)
//   gridDim  = (ceil(W / BLOCK_SIZE_X), ceil(H / BLOCK_SIZE_Y), B); blockIdx.z is the batch index.
// Dynamic shared memory: squeeze_channels * (BLOCK_SIZE_X + 2) * (BLOCK_SIZE_Y + 2) floats.
// Preconditions: all pointers address dense (contiguous) float32 data in the layouts noted on the
// parameters, and every tensor has fewer than 2^31 elements (indices are 32-bit ints).
//
// Phase 1: the block cooperatively computes the squeeze outputs for its tile plus a 1-pixel halo
//          into shared memory. Halo positions outside the image are stored as 0, which is the
//          zero padding of the 3x3 branch.
// Phase 2: each thread owning an in-bounds output pixel computes every expand channel for that
//          pixel directly from shared memory and writes it straight to global memory. No
//          fixed-size per-thread arrays are used, so channel counts are bounded only by the
//          shared-memory capacity (squeeze) and 32-bit indexing.
extern "C" __global__ void fused_fire_module_shared_kernel(
    const float* __restrict__ x,                 // [B, in_channels, H, W]
    const float* __restrict__ squeeze_weight,    // [squeeze_channels, in_channels]
    const float* __restrict__ squeeze_bias,      // [squeeze_channels]
    const float* __restrict__ expand1x1_weight,  // [expand1x1_channels, squeeze_channels]
    const float* __restrict__ expand1x1_bias,    // [expand1x1_channels]
    const float* __restrict__ expand3x3_weight,  // [expand3x3_channels, squeeze_channels, 3, 3]
    const float* __restrict__ expand3x3_bias,    // [expand3x3_channels]
    float* __restrict__ output,                  // [B, (expand1x1_channels+expand3x3_channels), H, W]
    int B, int in_channels, int H, int W,
    int squeeze_channels, int expand1x1_channels, int expand3x3_channels
) {
    const int b = blockIdx.z;
    const int global_y = static_cast<int>(blockIdx.y) * BLOCK_SIZE_Y + static_cast<int>(threadIdx.y);
    const int global_x = static_cast<int>(blockIdx.x) * BLOCK_SIZE_X + static_cast<int>(threadIdx.x);

    // Shared tile: the block's output tile plus a 1-pixel halo on every side, one plane per
    // squeeze channel.
    constexpr int tile_w = BLOCK_SIZE_X + 2;
    constexpr int tile_h = BLOCK_SIZE_Y + 2;
    constexpr int tile_size = tile_w * tile_h;  // floats per channel plane
    constexpr int block_threads = BLOCK_SIZE_X * BLOCK_SIZE_Y;
    extern __shared__ float shared_squeeze[];   // squeeze_channels * tile_size floats

    // Global coordinates of the tile's top-left (halo) element; may be -1.
    const int tile_origin_y = static_cast<int>(blockIdx.y) * BLOCK_SIZE_Y - 1;
    const int tile_origin_x = static_cast<int>(blockIdx.x) * BLOCK_SIZE_X - 1;

    const int HW = H * W;
    const float* x_batch = x + b * in_channels * HW;

    // ---- Phase 1: squeeze 1x1 conv + ReLU into shared memory (block-strided over all slots). ----
    const int total_shared = squeeze_channels * tile_size;
    const int tid = static_cast<int>(threadIdx.y) * BLOCK_SIZE_X + static_cast<int>(threadIdx.x);
    for (int idx = tid; idx < total_shared; idx += block_threads) {
        const int c = idx / tile_size;
        const int pos = idx - c * tile_size;
        const int y = tile_origin_y + pos / tile_w;
        const int x_coord = tile_origin_x + pos % tile_w;
        float val = 0.0f;  // zero padding for halo positions outside the image
        if (y >= 0 && y < H && x_coord >= 0 && x_coord < W) {
            const float* x_pixel = x_batch + y * W + x_coord;
            const float* w_row = squeeze_weight + c * in_channels;
            float sum = 0.0f;
            // Runtime trip count: a bare `#pragma unroll` here would disable unrolling, so the
            // compiler's default heuristics are left in charge.
            for (int k = 0; k < in_channels; k++) {
                sum += x_pixel[k * HW] * w_row[k];
            }
            sum += squeeze_bias[c];
            val = fmaxf(sum, 0.0f);
        }
        shared_squeeze[idx] = val;  // idx == c * tile_size + pos
    }
    // Reached by every thread of the block (no divergent exit before this point).
    __syncthreads();

    // ---- Phase 2: expand branches for this thread's output pixel. ----
    if (global_y >= H || global_x >= W) {
        return;  // no further barriers follow, so exiting here is safe
    }

    const int total_out_channels = expand1x1_channels + expand3x3_channels;
    float* out_pixel = output + b * total_out_channels * HW + global_y * W + global_x;

    // Expand 1x1: reads the squeeze vector at the pixel's own position (offset +1 for the halo).
    const int center = (static_cast<int>(threadIdx.y) + 1) * tile_w + (static_cast<int>(threadIdx.x) + 1);
    for (int j = 0; j < expand1x1_channels; j++) {
        const float* w_row = expand1x1_weight + j * squeeze_channels;
        float sum = 0.0f;
        for (int c = 0; c < squeeze_channels; c++) {
            sum += shared_squeeze[c * tile_size + center] * w_row[c];
        }
        sum += expand1x1_bias[j];
        out_pixel[j * HW] = fmaxf(sum, 0.0f);
    }

    // Expand 3x3: the window's top-left in the shared tile is (threadIdx.y, threadIdx.x) because
    // the tile starts one pixel above/left of the block's first output pixel.
    const int window_origin = static_cast<int>(threadIdx.y) * tile_w + static_cast<int>(threadIdx.x);
    for (int k = 0; k < expand3x3_channels; k++) {
        const float* w_k = expand3x3_weight + k * squeeze_channels * 9;
        float sum = 0.0f;
        for (int c = 0; c < squeeze_channels; c++) {
            const float* window = shared_squeeze + c * tile_size + window_origin;
            const float* w_kc = w_k + c * 9;
            // Fixed 3x3 trip counts: fully unrolled.
            #pragma unroll
            for (int r = 0; r < 3; r++) {
                #pragma unroll
                for (int s = 0; s < 3; s++) {
                    sum += window[r * tile_w + s] * w_kc[r * 3 + s];
                }
            }
        }
        sum += expand3x3_bias[k];
        out_pixel[(expand1x1_channels + k) * HW] = fmaxf(sum, 0.0f);
    }
}
// Host entry point for the fused fire module.
//
// Validates the inputs and makes every tensor contiguous (the kernel indexes raw pointers as
// dense arrays). It then launches fused_fire_module_shared_kernel on the current CUDA stream of
// x's device and returns the concatenated expand outputs.
//
// Expected inputs (float32, all on x's CUDA device):
//   x                : [B, in_channels, H, W]
//   squeeze_weight   : [squeeze_channels, in_channels]          (trailing 1x1 dims allowed)
//   squeeze_bias     : [squeeze_channels]
//   expand1x1_weight : [expand1x1_channels, squeeze_channels]   (trailing 1x1 dims allowed)
//   expand1x1_bias   : [expand1x1_channels]
//   expand3x3_weight : [expand3x3_channels, squeeze_channels, 3, 3]
//   expand3x3_bias   : [expand3x3_channels]
// Returns a new tensor [B, expand1x1_channels + expand3x3_channels, H, W] with x's options.
// Throws c10::Error (RuntimeError in Python) in these cases:
//   - the inputs are invalid;
//   - the required dynamic shared memory exceeds the device limit;
//   - the kernel launch fails.
torch::Tensor forward(torch::Tensor x,
                      torch::Tensor squeeze_weight,
                      torch::Tensor squeeze_bias,
                      torch::Tensor expand1x1_weight,
                      torch::Tensor expand1x1_bias,
                      torch::Tensor expand3x3_weight,
                      torch::Tensor expand3x3_bias) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 4, "x must be 4-D [B, C, H, W], got ", x.dim(), " dims");
    TORCH_CHECK(x.scalar_type() == torch::kFloat, "x must be float32");

    // Enforce computation on correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());

    auto check_param = [&x](const torch::Tensor& t, const char* name, int64_t min_dim) {
        TORCH_CHECK(t.device() == x.device(), name, " must be on the same device as x");
        TORCH_CHECK(t.scalar_type() == torch::kFloat, name, " must be float32");
        TORCH_CHECK(t.dim() >= min_dim, name, " must have at least ", min_dim, " dims");
    };
    check_param(squeeze_weight, "squeeze_weight", 2);
    check_param(squeeze_bias, "squeeze_bias", 0);
    check_param(expand1x1_weight, "expand1x1_weight", 2);
    check_param(expand1x1_bias, "expand1x1_bias", 0);
    check_param(expand3x3_weight, "expand3x3_weight", 4);
    check_param(expand3x3_bias, "expand3x3_bias", 0);

    const int64_t B = x.size(0);
    const int64_t in_channels = x.size(1);
    const int64_t H = x.size(2);
    const int64_t W = x.size(3);
    const int64_t squeeze_channels = squeeze_weight.size(0);
    const int64_t expand1x1_channels = expand1x1_weight.size(0);
    const int64_t expand3x3_channels = expand3x3_weight.size(0);

    // The kernel reads the weights as dense [out, in] and [out, in, 3, 3] arrays.
    TORCH_CHECK(squeeze_weight.size(1) == in_channels &&
                    squeeze_weight.numel() == squeeze_channels * in_channels,
                "squeeze_weight must be [squeeze_channels, in_channels] (optionally with trailing 1x1 dims)");
    TORCH_CHECK(squeeze_bias.numel() == squeeze_channels,
                "squeeze_bias must have squeeze_channels elements");
    TORCH_CHECK(expand1x1_weight.size(1) == squeeze_channels &&
                    expand1x1_weight.numel() == expand1x1_channels * squeeze_channels,
                "expand1x1_weight must be [expand1x1_channels, squeeze_channels] (optionally with trailing 1x1 dims)");
    TORCH_CHECK(expand1x1_bias.numel() == expand1x1_channels,
                "expand1x1_bias must have expand1x1_channels elements");
    TORCH_CHECK(expand3x3_weight.dim() == 4 && expand3x3_weight.size(1) == squeeze_channels &&
                    expand3x3_weight.size(2) == 3 && expand3x3_weight.size(3) == 3,
                "expand3x3_weight must be [expand3x3_channels, squeeze_channels, 3, 3]");
    TORCH_CHECK(expand3x3_bias.numel() == expand3x3_channels,
                "expand3x3_bias must have expand3x3_channels elements");

    // Raw-pointer indexing in the kernel assumes contiguous (NCHW / row-major) storage.
    x = x.contiguous();
    squeeze_weight = squeeze_weight.contiguous();
    squeeze_bias = squeeze_bias.contiguous();
    expand1x1_weight = expand1x1_weight.contiguous();
    expand1x1_bias = expand1x1_bias.contiguous();
    expand3x3_weight = expand3x3_weight.contiguous();
    expand3x3_bias = expand3x3_bias.contiguous();

    // Allocate output tensor: [B, (expand1x1_channels + expand3x3_channels), H, W]
    auto output = torch::empty({B, expand1x1_channels + expand3x3_channels, H, W}, x.options());
    if (output.numel() == 0) {
        return output;  // a zero-sized grid would be an invalid launch configuration
    }

    // The kernel computes all offsets in 32-bit int.
    TORCH_CHECK(x.numel() <= INT_MAX && output.numel() <= INT_MAX,
                "tensors are too large for the kernel's 32-bit indexing");

    // Grid and block dimensions; gridDim.y and gridDim.z are limited to 65535.
    const int64_t grid_x = (W + BLOCK_SIZE_X - 1) / BLOCK_SIZE_X;
    const int64_t grid_y = (H + BLOCK_SIZE_Y - 1) / BLOCK_SIZE_Y;
    TORCH_CHECK(grid_y <= 65535 && B <= 65535,
                "H too large or batch size exceeds 65535 for this launch configuration");
    dim3 block(BLOCK_SIZE_X, BLOCK_SIZE_Y);
    dim3 grid(static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y), static_cast<unsigned>(B));

    // Shared memory size: tile_w * tile_h * squeeze_channels * sizeof(float)
    constexpr int tile_w = BLOCK_SIZE_X + 2;
    constexpr int tile_h = BLOCK_SIZE_Y + 2;
    const size_t shared_mem_size = static_cast<size_t>(tile_w) * tile_h *
                                   static_cast<size_t>(squeeze_channels) * sizeof(float);

    // Above the default per-block limit (48 KB), dynamic shared memory must be explicitly
    // opted into; without this the launch fails and the output stays uninitialized.
    const int device_index = x.get_device();
    int smem_default = 0;
    int smem_optin = 0;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(&smem_default, cudaDevAttrMaxSharedMemoryPerBlock, device_index));
    C10_CUDA_CHECK(cudaDeviceGetAttribute(&smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_index));
    const size_t smem_limit = static_cast<size_t>(smem_optin > smem_default ? smem_optin : smem_default);
    TORCH_CHECK(shared_mem_size <= smem_limit,
                "squeeze_channels=", squeeze_channels, " needs ", shared_mem_size,
                " bytes of shared memory per block, but the device allows at most ", smem_limit);
    if (shared_mem_size > static_cast<size_t>(smem_default)) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(fused_fire_module_shared_kernel,
                                            cudaFuncAttributeMaxDynamicSharedMemorySize,
                                            static_cast<int>(shared_mem_size)));
    }

    fused_fire_module_shared_kernel<<<grid, block, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        squeeze_weight.data_ptr<float>(),
        squeeze_bias.data_ptr<float>(),
        expand1x1_weight.data_ptr<float>(),
        expand1x1_bias.data_ptr<float>(),
        expand3x3_weight.data_ptr<float>(),
        expand3x3_bias.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(B), static_cast<int>(in_channels), static_cast<int>(H), static_cast<int>(W),
        static_cast<int>(squeeze_channels), static_cast<int>(expand1x1_channels),
        static_cast<int>(expand3x3_channels)
    );
    // Surface launch-configuration errors immediately instead of returning garbage.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Python bindings: registers forward() under the Python name "forward".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def(
        "forward",
        &forward,
        "Fused fire module with shared memory (optimized): squeeze conv, expand (1x1 & 3x3 with ReLU) and concatenation");
}