2.27. 3D Conv with Activation and Aggregation
Performs a 3D convolution on volumetric input data, then applies a HardSwish activation followed by ReLU, normalizes the output with a Softmax across channels, and finally computes the mean over the spatial dimensions.
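For reference, the unfused operator chain corresponds roughly to the following PyTorch module (a minimal sketch; the class name, constructor arguments, and shapes are assumptions, not taken from the original benchmark code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvActAggregate(nn.Module):
    # Unfused reference chain: Conv3d -> HardSwish -> ReLU -> Softmax over channels -> spatial mean.
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)

    def forward(self, x):                # x: [B, C_in, D, H, W]
        x = self.conv(x)                 # [B, C_out, D', H', W']
        x = F.hardswish(x)
        x = F.relu(x)
        x = F.softmax(x, dim=1)          # normalize across the channel dimension
        return x.mean(dim=(2, 3, 4))     # mean over spatial dims -> [B, C_out]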
Optimization Timeline
Step 3 Details
Reasoning
Previously, the fused kernel reduced each sample's softmax values across spatial positions using a shared-memory loop and atomic adds. Although this gave a speedup, it still incurred overhead from the shared-memory reductions. To optimize further, we assign one block per sample and use warp-level reductions (via __shfl_down_sync) on per-thread partial sums held in registers. Each thread iterates over its assigned spatial positions, accumulating softmax values (after applying HardSwish and ReLU) locally, then reduces within its warp using shuffle instructions. Warp leaders write their results to shared memory, and thread 0 performs the final reduction across warps. This approach minimizes shared-memory traffic and synchronization overhead and should yield even lower runtimes than the previous version. The Python interface remains unchanged.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#define MAX_CHANNELS 128 // Maximum channels supported; adjust as needed.
#define WARP_SIZE 32
// Optimized fused kernel: one block per sample.
// Each thread processes a subset of spatial positions and accumulates softmax values (after applying HardSwish and ReLU)
// for each channel into a register array. Then, warp-level reductions using __shfl_down_sync are applied,
// and warp leaders write to shared memory, where a final reduction is performed.
extern "C" __global__ void fused_kernel(const float* __restrict__ input,
float* __restrict__ output,
int B, int C, int S) {
// Each block processes one sample.
int b = blockIdx.x;
int tid = threadIdx.x;
int blockSize = blockDim.x;
// Each thread will accumulate softmax sums over its assigned spatial positions.
float local_sum[MAX_CHANNELS];
for (int c = 0; c < C; c++) {
local_sum[c] = 0.0f;
}
// Pointer to sample data in NCDHW layout, where S = D*H*W.
const float* sample_ptr = input + b * (C * S);
// Iterate over spatial positions using grid-stride loop.
for (int i = tid; i < S; i += blockSize) {
float act[MAX_CHANNELS];
// For each channel, load input and apply HardSwish then ReLU.
for (int c = 0; c < C; c++) {
float x = sample_ptr[c * S + i];
// HardSwish: x * relu6(x+3)/6.
float hs = x * fminf(fmaxf(x + 3.0f, 0.0f), 6.0f) / 6.0f;
act[c] = fmaxf(hs, 0.0f);
}
// Compute maximum across channels for numerical stability.
float max_val = act[0];
for (int c = 1; c < C; c++) {
if (act[c] > max_val) max_val = act[c];
}
// Compute exponentials and sum.
float sum_exp = 0.0f;
for (int c = 0; c < C; c++) {
act[c] = expf(act[c] - max_val);
sum_exp += act[c];
}
// Compute softmax for each channel and accumulate.
for (int c = 0; c < C; c++) {
float softmax_val = act[c] / sum_exp;
local_sum[c] += softmax_val;
}
} // end spatial loop
// Warp-level reduction using shuffle operations.
unsigned int mask = 0xffffffff; // full mask
for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
for (int c = 0; c < C; c++) {
local_sum[c] += __shfl_down_sync(mask, local_sum[c], offset);
}
}
// Allocate shared memory for warp-level partial sums.
// Shared memory size should be (number of warps) * C floats.
extern __shared__ float shared[];
int warpId = tid / WARP_SIZE;
int lane = tid % WARP_SIZE;
if (lane == 0) {
for (int c = 0; c < C; c++) {
shared[warpId * C + c] = local_sum[c];
}
}
__syncthreads();
// Final reduction across warps: done by thread 0.
if (tid == 0) {
int warpCount = (blockSize + WARP_SIZE - 1) / WARP_SIZE;
for (int c = 0; c < C; c++) {
float sum_result = 0.0f;
for (int w = 0; w < warpCount; w++) {
sum_result += shared[w * C + c];
}
output[b * C + c] = sum_result / ((float) S);
}
}
}
torch::Tensor fused_forward(torch::Tensor input) {
    // Ensure proper CUDA device usage.
    c10::cuda::CUDAGuard device_guard(input.device());
    // Input: [B, C, D, H, W]
    TORCH_CHECK(input.is_cuda(), "fused_forward: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 5, "fused_forward: input must be 5-D [B, C, D, H, W]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "fused_forward: input must be float32");
    int B = input.size(0);
    int C = input.size(1);
    int D = input.size(2);
    int H = input.size(3);
    int W = input.size(4);
    int S = D * H * W;
    TORCH_CHECK(C <= MAX_CHANNELS, "fused_forward: channel count exceeds MAX_CHANNELS");
    // Ensure contiguous memory layout.
    auto input_contig = input.contiguous();
    // Allocate output tensor of shape [B, C].
    auto output = torch::empty({B, C}, input.options());
    // Launch one block per sample.
    int threads = 256;
    int blocks = B;
    int warpCount = (threads + WARP_SIZE - 1) / WARP_SIZE;
    // Shared memory required: warpCount * C * sizeof(float).
    size_t shared_size = warpCount * C * sizeof(float);
    fused_kernel<<<blocks, threads, shared_size, c10::cuda::getCurrentCUDAStream()>>>(
        input_contig.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C, S);
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_forward", &fused_forward,
          "Fused post-conv chain: HardSwish, ReLU, Softmax (dim=1), and spatial mean using warp-level reduction");
}
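For completeness, here is a hypothetical Python-side usage sketch. The file name fused_kernel.cu, the extension name, and the tensor shapes are placeholders rather than part of the original source; the 3D convolution itself stays in PyTorch and only the post-conv chain goes through the fused kernel.

import torch
from torch.utils.cpp_extension import load

# Build the extension inline; "fused_kernel.cu" is assumed to contain the listing above.
fused = load(name="fused_conv_chain", sources=["fused_kernel.cu"], verbose=False)

conv = torch.nn.Conv3d(3, 16, kernel_size=3).cuda()
x = torch.randn(8, 3, 32, 32, 32, device="cuda")
y = fused.fused_forward(conv(x))  # shape [8, 16]: softmax over channels, averaged over D*H*W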