2.79. 3D Conv-InstanceNorm with Scaling and Clamping
Performs a sequence of operations on a 3D tensor: a 3D convolution, element-wise scaling by a learnable parameter, instance normalization, clamping of the values to a specified range, a second multiplication by the same parameter, and finally a max reduction along the channel dimension.
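For orientation, the unfused pipeline can be written as the following PyTorch sketch; the constructor arguments and the (C, 1, 1, 1) multiplier shape are illustrative assumptions, not the original module's exact signature.

import torch
import torch.nn as nn

class RefModel(nn.Module):
    # Unfused reference pipeline; names and shapes (e.g. a multiplier of
    # shape (C, 1, 1, 1)) are assumptions for illustration.
    def __init__(self, in_channels, out_channels, kernel_size,
                 multiplier_shape, clamp_min, clamp_max):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.multiplier = nn.Parameter(torch.randn(multiplier_shape))
        self.instance_norm = nn.InstanceNorm3d(out_channels)
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

    def forward(self, x):
        x = self.conv(x)                                     # 3D convolution
        x = x * self.multiplier                              # element-wise scale
        x = self.instance_norm(x)                            # instance normalization
        x = torch.clamp(x, self.clamp_min, self.clamp_max)   # clamp to range
        x = x * self.multiplier                              # scale again
        return torch.max(x, dim=1)[0]                        # max over channels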
Step 2 Details
Reasoning
We can further optimize our fused kernel by unrolling the loop over the channel dimension with #pragma unroll and by loading the multiplier through __ldg, which routes the access via the read-only cache. In our kernel each thread processes a single output spatial location and loops over the channel dimension, which is usually small (e.g. 16). We also replace the conditional branch in the max reduction with the fmaxf intrinsic and keep the clamp_min and clamp_max values in local registers. These micro-optimizations reduce branch divergence and improve memory throughput. The Python module remains nearly unchanged; a sketch of it is shown after the kernel code below.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Fused kernel: for each sample n and spatial location (d, h, w) in the input tensor,
// iterate over the channel dimension (C) and compute:
// val = clamp(x[n,c,d,h,w], clamp_min, clamp_max)
// prod = val * multiplier[c]
// and take the maximum over channels, writing the result to the output at (n, d, h, w).
// Input x: shape (N, C, D, H, W). Multiplier is a 1D array of length C (flattened from shape (C,1,1,1)).
// Output: shape (N, D, H, W).
__global__ void fused_chnmx_kernel(const float* __restrict__ x,
                                   const float* __restrict__ multiplier,
                                   float clamp_min, float clamp_max,
                                   float* __restrict__ output,
                                   int N, int C, int D, int H, int W) {
    // Compute a linear index over the output spatial locations.
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * D * H * W;
    if (index < total) {
        // Decode index into (n, d, h, w).
        int w = index % W;
        int tmp = index / W;
        int h = tmp % H;
        tmp = tmp / H;
        int d = tmp % D;
        int n = tmp / D;
        float max_val = -INFINITY;
        // Load clamp bounds into registers.
        float cl_min = clamp_min;
        float cl_max = clamp_max;
        #pragma unroll
        for (int c = 0; c < C; c++) {
            int x_idx = (((n * C + c) * D + d) * H + h) * W + w;
            float val = x[x_idx];
            // Clamp the value.
            val = fminf(fmaxf(val, cl_min), cl_max);
            // Multiply by the multiplier (loaded through the read-only cache via __ldg).
            float prod = val * __ldg(&multiplier[c]);
            max_val = fmaxf(max_val, prod);
        }
        output[index] = max_val;
    }
}
torch::Tensor forward(torch::Tensor x, torch::Tensor multiplier, float clamp_min, float clamp_max) {
    c10::cuda::CUDAGuard device_guard(x.device());
    int N = x.size(0);
    int C = x.size(1);
    int D = x.size(2);
    int H = x.size(3);
    int W = x.size(4);
    // Output shape: (N, D, H, W)
    auto output = torch::empty({N, D, H, W}, x.options());
    int total = N * D * H * W;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    fused_chnmx_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        multiplier.data_ptr<float>(),
        clamp_min,
        clamp_max,
        output.data_ptr<float>(),
        N, C, D, H, W
    );
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused clamp, multiplication, and channel max reduction for 3D conv output");
}
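For completeness, here is a minimal sketch of the Python module mentioned above, assuming the extension is compiled with torch.utils.cpp_extension.load; the file name fused_chnmx.cu and the class name are hypothetical.

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Build the inline extension; the source file name is an assumption.
fused = load(name="fused_chnmx", sources=["fused_chnmx.cu"])

class OptimizedModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 multiplier_shape, clamp_min, clamp_max):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.multiplier = nn.Parameter(torch.randn(multiplier_shape))
        self.instance_norm = nn.InstanceNorm3d(out_channels)
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

    def forward(self, x):
        x = self.conv(x)
        x = x * self.multiplier
        x = self.instance_norm(x)
        # The fused kernel replaces the clamp -> multiply -> channel-max tail.
        # It assumes contiguous float32 input and a flat (C,) multiplier.
        x = x.contiguous()
        mult = self.multiplier.view(-1).contiguous()
        return fused.forward(x, mult, self.clamp_min, self.clamp_max)

Note that the convolution and instance normalization still run through ATen/cuDNN; only the memory-bound clamp, second multiply, and channel-max reduction are fused into the custom kernel.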