2.79. 3D Conv-InstanceNorm with Scaling and Clamping

Computes a sequence of operations on a 5D input of shape (N, C, D, H, W): a 3D convolution, element-wise scaling by a learnable per-channel multiplier, instance normalization, clamping to a fixed range, a second multiplication by the same multiplier, and finally a max reduction over the channel dimension.
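As a plain-CPU illustration of the steps after the convolution, the sketch below applies them in order to an already-computed convolution output. It is a hypothetical reference (`reference_post_conv` is not part of the original code) and assumes InstanceNorm3d-style defaults: no affine parameters, eps = 1e-5, and biased variance over the spatial positions.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical CPU reference for the post-convolution steps.
// Assumptions: instance norm without affine parameters, eps = 1e-5,
// biased variance (divide by the number of spatial positions S = D*H*W).
// conv_out is flattened (N, C, S); multiplier has C entries.
// Returns (N, S): max over c of clamp(norm(conv * m[c])) * m[c].
std::vector<float> reference_post_conv(const std::vector<float>& conv_out,
                                       const std::vector<float>& multiplier,
                                       int N, int C, int S,
                                       float clamp_min, float clamp_max,
                                       float eps = 1e-5f) {
    std::vector<float> out(static_cast<size_t>(N) * S, -INFINITY);
    std::vector<float> scaled(S);
    for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
            const float* src = &conv_out[(static_cast<size_t>(n) * C + c) * S];
            // Step 1: scale by the per-channel multiplier and accumulate the mean.
            double mean = 0.0;
            for (int s = 0; s < S; ++s) {
                scaled[s] = src[s] * multiplier[c];
                mean += scaled[s];
            }
            mean /= S;
            // Step 2: instance-norm variance over this (n, c) slice.
            double var = 0.0;
            for (int s = 0; s < S; ++s) {
                double diff = scaled[s] - mean;
                var += diff * diff;
            }
            var /= S;
            const double inv_std = 1.0 / std::sqrt(var + eps);
            for (int s = 0; s < S; ++s) {
                float v = static_cast<float>((scaled[s] - mean) * inv_std);
                // Steps 3-5: clamp, multiply again, running max over channels.
                v = std::min(std::max(v, clamp_min), clamp_max);
                float& o = out[static_cast<size_t>(n) * S + s];
                o = std::max(o, v * multiplier[c]);
            }
        }
    }
    return out;
}
```

Only steps 3 to 5 correspond to what the fused kernel computes; steps 1 and 2 happen before it is called.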

Step 2 of 2. Reward: 1.39x

Reasoning

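We can optimize the fused kernel slightly further. The kernel fuses only the last three steps of the pipeline (clamp, second multiplication, and channel max). The convolution, first scaling, and instance normalization run before it, outside this extension, and the Python module remains nearly unchanged. Each thread processes one output spatial location (n, d, h, w) and loops over the channel dimension, which is usually small (e.g. 16). Because consecutive threads handle consecutive w positions, the loads of x for a given channel are coalesced across a warp. Inside the loop, the multiplier is loaded with __ldg so that it goes through the read-only data cache. Both the clamp and the running maximum use fminf/fmaxf, which compile to branch-free min/max instructions rather than a conditional branch. The clamp_min and clamp_max values are copied into local variables. Kernel arguments already live in constant memory and are loaded into registers by the compiler, so this copy likely has no measurable effect. Finally, the channel loop is marked with #pragma unroll. Because C is a runtime argument, however, nvcc cannot fully unroll it: a bare #pragma unroll only fully unrolls loops whose trip count is a compile-time constant, so C would have to become, for example, a template parameter.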
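The coalescing claim can be checked with the kernel's own index arithmetic. The helper below is a hypothetical host-side copy of that arithmetic (`input_offset` is not in the original). It maps a flat output index and a channel c to the offset of x[n, c, d, h, w] in a contiguous NCDHW tensor:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring fused_chnmx_kernel's indexing: decode a flat
// output index over (N, D, H, W) into (n, d, h, w), then compute the flat
// offset of x[n, c, d, h, w] in a contiguous (N, C, D, H, W) tensor.
int64_t input_offset(int64_t index, int c, int C, int D, int H, int W) {
    int64_t w = index % W;
    int64_t tmp = index / W;
    int64_t h = tmp % H;
    tmp /= H;
    int64_t d = tmp % D;
    int64_t n = tmp / D;
    return (((n * C + c) * D + d) * H + h) * W + w;
}
```

For a fixed channel, neighbouring output indices map to neighbouring input addresses (stride 1) everywhere inside a sample, including across h and d boundaries. The only jumps occur between samples n and n + 1. Different channels of the same location are D*H*W floats apart.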

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>

// Fused kernel: for each sample n and spatial location (d, h, w) in the input tensor,
// iterate over the channel dimension (C) and compute:
//    val = clamp(x[n,c,d,h,w], clamp_min, clamp_max)
//    prod = val * multiplier[c]
// and take the maximum over channels, writing the result to the output at (n, d, h, w).
// Input x: shape (N, C, D, H, W). Multiplier is a 1D array of length C (flattened from shape (C,1,1,1)).
// Output: shape (N, D, H, W).
__global__ void fused_chnmx_kernel(const float* __restrict__ x, 
                                   const float* __restrict__ multiplier,
                                   float clamp_min, float clamp_max,
                                   float* __restrict__ output,
                                   int N, int C, int D, int H, int W) {
    // Compute a linear index over the output spatial locations.
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = N * D * H * W;
    if (index < total) {
        // Decode index into (n, d, h, w).
        int w = index % W;
        int tmp = index / W;
        int h = tmp % H;
        tmp = tmp / H;
        int d = tmp % D;
        int n = tmp / D;
        
        float max_val = -INFINITY;
        // Load clamp bounds into registers.
        float cl_min = clamp_min;
        float cl_max = clamp_max;
        
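        // C is a runtime argument, so nvcc cannot fully unroll this loop;
        // a bare #pragma unroll only fully unrolls constant trip counts.
        #pragma unroll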
        for (int c = 0; c < C; c++) {
            int x_idx = (((n * C + c) * D + d) * H + h) * W + w;
            float val = x[x_idx];
            // Clamp the value.
            val = fminf(fmaxf(val, cl_min), cl_max);
            // Multiply by multiplier (using __ldg to load multiplier from read-only cache).
            float prod = val * __ldg(&multiplier[c]);
            max_val = fmaxf(max_val, prod);
        }
        output[index] = max_val;
    }
}

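torch::Tensor forward(torch::Tensor x, torch::Tensor multiplier, float clamp_min, float clamp_max) {
    TORCH_CHECK(x.is_cuda() && multiplier.is_cuda(), "x and multiplier must be CUDA tensors");
    TORCH_CHECK(x.dim() == 5, "x must have shape (N, C, D, H, W)");
    TORCH_CHECK(x.scalar_type() == torch::kFloat && multiplier.scalar_type() == torch::kFloat,
                "only float32 tensors are supported");
    TORCH_CHECK(x.numel() <= std::numeric_limits<int>::max(),
                "x is too large for the kernel's 32-bit indexing");
    c10::cuda::CUDAGuard device_guard(x.device());
    // The kernel assumes contiguous NCDHW memory and a flat multiplier of length C.
    x = x.contiguous();
    multiplier = multiplier.contiguous();
    int N = x.size(0);
    int C = x.size(1);
    int D = x.size(2);
    int H = x.size(3);
    int W = x.size(4);
    TORCH_CHECK(multiplier.numel() == C, "multiplier must have C elements");
    
    // Output shape: (N, D, H, W)
    auto output = torch::empty({N, D, H, W}, x.options());
    
    int total = N * D * H * W;
    if (total == 0) {
        return output;  // nothing to compute; a zero-block launch would be invalid
    }
    int threads = 256;
    int blocks = (total + threads - 1) / threads;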
    
    fused_chnmx_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        multiplier.data_ptr<float>(),
        clamp_min,
        clamp_max,
        output.data_ptr<float>(),
        N, C, D, H, W
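    );
    // Surface launch-configuration errors immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_chnmx_kernel launch failed: ", cudaGetErrorString(err));
    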
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused clamp, multiplication, and channel max reduction for 3D conv output");
}
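A host-side reference is handy for validating the extension, for example by comparing it with `output.cpu()` in a test. The sketch below is hypothetical (`fused_chnmx_reference` is not part of the original code). It assumes the same layout as the kernel: x contiguous with shape (N, C, D, H, W) and a multiplier with C entries.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for fused_chnmx_kernel:
// out[n, d, h, w] = max over c of clamp(x[n, c, d, h, w], lo, hi) * multiplier[c].
std::vector<float> fused_chnmx_reference(const std::vector<float>& x,
                                         const std::vector<float>& multiplier,
                                         float clamp_min, float clamp_max,
                                         int N, int C, int D, int H, int W) {
    const int64_t spatial = static_cast<int64_t>(D) * H * W;
    std::vector<float> out(static_cast<size_t>(N * spatial), -INFINITY);
    for (int64_t n = 0; n < N; ++n)
        for (int64_t c = 0; c < C; ++c)
            for (int64_t s = 0; s < spatial; ++s) {
                // Clamp first, then scale: with a negative multiplier the
                // order changes the result.
                float v = x[(n * C + c) * spatial + s];
                v = std::min(std::max(v, clamp_min), clamp_max);
                float& o = out[n * spatial + s];
                o = std::max(o, v * multiplier[c]);
            }
    return out;
}
```

The loops visit channels in the outer position rather than per output element as the kernel does. Taking a maximum does not depend on order, so for finite inputs we would expect an exact match: each output involves only a min, a max, and a single multiplication.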