2.60. 3D Transposed Convolution with Swish and HardSwish
Performs a 3D transposed convolution to upsample the input, applies a Swish activation computed as x multiplied by sigmoid(x), follows it with group normalization, and finally applies the HardSwish activation to produce the output.
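For reference, the unfused pipeline can be written as a plain PyTorch module. This is only a sketch of the operation described above; the hyperparameters (channels, kernel size, stride, number of groups, eps) and the absence of learnable GroupNorm affine parameters are assumptions chosen to match the fused kernel shown later in this step.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RefModel(nn.Module):
    # Hypothetical hyperparameters; the fused kernel below normalizes without
    # learnable GroupNorm weight/bias, so none are used here either.
    def __init__(self, in_ch=16, out_ch=32, k=3, stride=2, groups=4, eps=1e-5):
        super().__init__()
        self.conv_t = nn.ConvTranspose3d(in_ch, out_ch, k, stride=stride)
        self.groups, self.eps = groups, eps

    def forward(self, x):
        x = self.conv_t(x)                              # 3D transposed convolution (upsampling)
        x = x * torch.sigmoid(x)                        # Swish
        x = F.group_norm(x, self.groups, eps=self.eps)  # GroupNorm, no affine
        return F.hardswish(x)                           # HardSwish: x * clamp(x+3, 0, 6) / 6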
Optimization Timeline
Step 3 Details
Reasoning
We can further optimize the kernel by vectorizing the heavy per-group loop. In many real cases the group size (channels per group × D × H × W) is divisible by 4, so loads and stores can be vectorized with float4. In the first pass, each thread then processes chunks of 4 floats, computing swish (x * sigmoid(x)) for each component and accumulating the sum and sum of squares over those 4 components. For groups whose size is not divisible by 4 we fall back to the original scalar loop. The second pass likewise uses vectorized loads/stores when possible. The warp-level reduction using __shfl_down_sync is unchanged. This optimization reduces the number of loop iterations and memory transactions whenever the vectorized path applies. The Python module remains essentially the same (sketched below).
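A minimal sketch of that Python module, assuming the compiled extension is importable as fused_ops and that the ConvTranspose3d hyperparameters are supplied by the caller; the class name, argument list, and default eps are assumptions, not part of this step's code:

import torch
import torch.nn as nn
import fused_ops  # assumed name of the compiled extension defined below

class ModelNew(nn.Module):
    def __init__(self, in_ch, out_ch, k, stride, groups, eps=1e-5):
        super().__init__()
        self.conv_t = nn.ConvTranspose3d(in_ch, out_ch, k, stride=stride)
        self.groups, self.eps = groups, eps

    def forward(self, x):
        x = self.conv_t(x)  # cuDNN transposed convolution
        # Fused Swish + GroupNorm (no affine) + HardSwish, one block per (sample, group)
        return fused_ops.forward(x.contiguous(), self.groups, self.eps)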
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
// Fused kernel that applies Swish, GroupNorm, and HardSwish
// to the output of ConvTranspose3d.
// For input tensor x of shape [B, C, D, H, W] with C divisible by groups,
// each block processes one (sample, group).
// The algorithm:
//   1. First pass: each thread computes swish = x * sigmoid(x)
//      and accumulates partial sums and sums of squares over the group.
//      The code is vectorized if (group_size % 4 == 0).
//   2. Warp-level and shared-memory reduction is used to compute the mean and variance.
//   3. Second pass: each thread normalizes its data and applies HardSwish:
//      hard_swish(norm) = norm * clamp(norm + 3, 0, 6) / 6.
__global__ void fused_post_kernel(const float* __restrict__ input, float* __restrict__ output,
                                  int B, int C, int D, int H, int W,
                                  int groups, float eps) {
    // Determine sample and group indices.
    int block_id = blockIdx.x;
    int g = block_id % groups;  // group index
    int b = block_id / groups;  // sample index
    int channels_per_group = C / groups;
    int group_size = channels_per_group * D * H * W;  // number of floats in the group
    int sample_offset = b * (C * D * H * W);
    int group_offset = g * channels_per_group * D * H * W;
    int base = sample_offset + group_offset;

    float partial_sum = 0.0f;
    float partial_sum_sq = 0.0f;

    // Determine if we can vectorize: if group_size is divisible by 4.
    bool use_vec = ((group_size & 3) == 0);  // equivalent to (group_size % 4 == 0)
    if (use_vec) {
        int vec_length = group_size / 4;
        const float4* input_vec = reinterpret_cast<const float4*>(input + base);
        float4* output_vec = reinterpret_cast<float4*>(output + base);
        for (int i = threadIdx.x; i < vec_length; i += blockDim.x) {
            float4 in_val = input_vec[i];
            float4 swish_val;
            // Process each component with swish = x * sigmoid(x)
            swish_val.x = in_val.x * (1.0f / (1.0f + expf(-in_val.x)));
            swish_val.y = in_val.y * (1.0f / (1.0f + expf(-in_val.y)));
            swish_val.z = in_val.z * (1.0f / (1.0f + expf(-in_val.z)));
            swish_val.w = in_val.w * (1.0f / (1.0f + expf(-in_val.w)));
            output_vec[i] = swish_val;
            float s = swish_val.x + swish_val.y + swish_val.z + swish_val.w;
            float s_sq = swish_val.x * swish_val.x + swish_val.y * swish_val.y +
                         swish_val.z * swish_val.z + swish_val.w * swish_val.w;
            partial_sum += s;
            partial_sum_sq += s_sq;
        }
    } else {
        for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
            int idx = base + i;
            float val = input[idx];
            float sig = 1.0f / (1.0f + expf(-val));
            float swish = val * sig;
            output[idx] = swish;
            partial_sum += swish;
            partial_sum_sq += swish * swish;
        }
    }
    // Warp-level reduction using __shfl_down_sync.
    unsigned int mask = 0xffffffff;
    float sum = partial_sum;
    float sum_sq = partial_sum_sq;
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        sum += __shfl_down_sync(mask, sum, offset);
        sum_sq += __shfl_down_sync(mask, sum_sq, offset);
    }

    // Each warp's lane 0 writes its reduced sum to shared memory.
    __shared__ float warp_sums[32];  // sufficient for up to 1024 threads per block
    __shared__ float warp_sums_sq[32];
    int lane = threadIdx.x & (warpSize - 1);
    int warp_id = threadIdx.x >> 5;
    if (lane == 0) {
        warp_sums[warp_id] = sum;
        warp_sums_sq[warp_id] = sum_sq;
    }
    __syncthreads();

    // First warp reduces the warp sums.
    float block_sum = 0.0f;
    float block_sum_sq = 0.0f;
    int num_warps = (blockDim.x + warpSize - 1) / warpSize;
    if (threadIdx.x < num_warps) {
        block_sum = warp_sums[threadIdx.x];
        block_sum_sq = warp_sums_sq[threadIdx.x];
    }
    if (threadIdx.x < warpSize) {
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            block_sum += __shfl_down_sync(mask, block_sum, offset);
            block_sum_sq += __shfl_down_sync(mask, block_sum_sq, offset);
        }
        if (threadIdx.x == 0) {
            warp_sums[0] = block_sum;        // block-wide sum of the swish values
            warp_sums_sq[0] = block_sum_sq;  // block-wide sum of squares
        }
    }
    __syncthreads();

    float mean = warp_sums[0] / group_size;
    float var = warp_sums_sq[0] / group_size - mean * mean;
    float inv_std = rsqrtf(var + eps);
    // Second pass: normalize and apply HardSwish.
    // HardSwish: hswish(x) = x * clamp(x+3, 0, 6)/6.
    if (use_vec) {
        int vec_length = group_size / 4;
        float4* output_vec = reinterpret_cast<float4*>(output + base);
        for (int i = threadIdx.x; i < vec_length; i += blockDim.x) {
            float4 val = output_vec[i];
            // Process and update each component.
            float norm_x = (val.x - mean) * inv_std;
            float hs_x = norm_x + 3.0f;
            hs_x = fminf(fmaxf(hs_x, 0.0f), 6.0f);
            val.x = norm_x * hs_x / 6.0f;

            float norm_y = (val.y - mean) * inv_std;
            float hs_y = norm_y + 3.0f;
            hs_y = fminf(fmaxf(hs_y, 0.0f), 6.0f);
            val.y = norm_y * hs_y / 6.0f;

            float norm_z = (val.z - mean) * inv_std;
            float hs_z = norm_z + 3.0f;
            hs_z = fminf(fmaxf(hs_z, 0.0f), 6.0f);
            val.z = norm_z * hs_z / 6.0f;

            float norm_w = (val.w - mean) * inv_std;
            float hs_w = norm_w + 3.0f;
            hs_w = fminf(fmaxf(hs_w, 0.0f), 6.0f);
            val.w = norm_w * hs_w / 6.0f;

            output_vec[i] = val;
        }
    } else {
        for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
            int idx = base + i;
            float swish = output[idx];
            float norm = (swish - mean) * inv_std;
            float hs = norm + 3.0f;
            hs = fminf(fmaxf(hs, 0.0f), 6.0f);
            output[idx] = norm * hs / 6.0f;
        }
    }
}
torch::Tensor fused_forward(torch::Tensor x, int groups, float eps) {
    // Ensure operations occur on the proper CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());
    TORCH_CHECK(x.is_cuda() && x.scalar_type() == torch::kFloat32,
                "fused_forward expects a float32 CUDA tensor");
    TORCH_CHECK(x.dim() == 5, "fused_forward expects input of shape [B, C, D, H, W]");
    x = x.contiguous();  // the kernel assumes a contiguous NCDHW layout
    // Input tensor x: [B, C, D, H, W]
    int B = x.size(0);
    int C = x.size(1);
    int D = x.size(2);
    int H = x.size(3);
    int W = x.size(4);
    auto output = torch::empty_like(x);
    // Launch one block per (sample, group).
    int total_blocks = B * groups;
    int threads = 256;
    // The reduction buffers are statically allocated inside the kernel,
    // so no dynamic shared memory is requested at launch.
    fused_post_kernel<<<total_blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C, D, H, W, groups, eps
    );
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fused_forward,
          "Fused Swish, GroupNorm, and HardSwish after ConvTranspose3d (vectorized optimization)");
}
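A quick way to exercise the extension is to JIT-compile it with torch.utils.cpp_extension.load and compare the fused post-processing against the unfused PyTorch ops. The file name fused_ops.cu, the tensor shape, and the tolerance below are assumptions for illustration only.

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Assumed file name; any .cu file containing the extension code above will do.
fused_ops = load(name="fused_ops", sources=["fused_ops.cu"], verbose=False)

x = torch.randn(2, 16, 8, 8, 8, device="cuda")  # [B, C, D, H, W]; group_size divisible by 4 here
y = fused_ops.forward(x, 4, 1e-5)               # groups=4, eps=1e-5

# Reference: Swish, then GroupNorm without affine, then HardSwish
ref = F.hardswish(F.group_norm(x * torch.sigmoid(x), 4, eps=1e-5))
print(torch.allclose(y, ref, atol=1e-4))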