2.26. 3D Transposed Convolution with Addition and HardSwish

Performs a 3D transposed convolution on an input tensor using adjustable stride, padding, and output padding parameters, adds a secondary tensor element-wise, and finally multiplies that sum by its HardSwish activation to produce the output.

Progress: Step 1 of 1
Reward: 1.53x

Optimization Timeline

Step 1 Details

Reasoning

In this problem we want to accelerate a model that performs a 3D transposed convolution followed by an element-wise addition and a HardSwish activation. The transposed convolution is already well optimized in PyTorch, so we focus on fusing the two subsequent operations into one custom CUDA kernel. The operations we want to fuse are:

• Let z = conv_transpose(x) + add_input.
• Compute HardSwish via torch.nn.functional.hardswish(z), which is defined as (z * relu6(z + 3)) / 6.
• Multiply element-wise: output = z * hardswish(z) = z * (z * relu6(z + 3) / 6) = (z² * relu6(z + 3)) / 6.

By fusing these steps into one kernel we avoid materializing intermediate tensors and reduce reads and writes to global memory. Our CUDA kernel (named "fused_add_hardswish_kernel") will:

– Launch enough threads to cover the entire output tensor, using a grid-stride loop.
– For each element i, compute z = conv_out[i] + add_input[i], then relu6 = min(max(z + 3, 0), 6), and finally write output[i] = (z * z * relu6) / 6.

We use c10::cuda::CUDAGuard and c10::cuda::getCurrentCUDAStream() for proper CUDA device management and kernel launch. The Python module wraps a standard ConvTranspose3d layer and then calls our fused CUDA operator. We load the extension with torch.utils.cpp_extension.load, passing the optimization flags -O3 and --use_fast_math.
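For reference, a minimal eager-mode sketch of the unfused computation the kernel replaces is shown below; the helper name reference_forward and its argument names are illustrative assumptions, not part of the original code.

import torch
import torch.nn.functional as F

def reference_forward(conv_transpose, x, add_input):
    # Unfused version: transposed conv, element-wise add, then multiply by HardSwish.
    z = conv_transpose(x) + add_input
    return z * F.hardswish(z)  # equals (z * z * relu6(z + 3)) / 6

The CUDA source for the fused operator follows.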

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>

// CUDA kernel to compute fused addition and HardSwish activation.
// For each element, compute:
//    z = conv_out[i] + add_input[i]
//    relu6 = min(max(z+3, 0), 6)
//    output[i] = (z*z*relu6)/6.
__global__ void fused_add_hardswish_kernel(const float* __restrict__ conv_out,
                                           const float* __restrict__ add_input,
                                           float* __restrict__ output,
                                           int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Grid-stride loop to cover all elements.
    for (int i = idx; i < total; i += blockDim.x * gridDim.x) {
        float z = conv_out[i] + add_input[i];
        float relu6 = fminf(fmaxf(z + 3.0f, 0.0f), 6.0f);
        output[i] = (z * z * relu6) / 6.0f;
    }
}

torch::Tensor fused_forward(torch::Tensor conv_out, torch::Tensor add_input) {
    // Use CUDAGuard to ensure we are on the right device.
    c10::cuda::CUDAGuard device_guard(conv_out.device());
    
    // Validate inputs: CUDA, contiguous float32 tensors with identical 5D shapes.
    TORCH_CHECK(conv_out.is_cuda() && add_input.is_cuda(), "Inputs must be CUDA tensors");
    TORCH_CHECK(conv_out.is_contiguous() && add_input.is_contiguous(), "Inputs must be contiguous");
    TORCH_CHECK(conv_out.scalar_type() == torch::kFloat32, "Inputs must be float32");
    TORCH_CHECK(conv_out.dim() == 5, "conv_out must be a 5D tensor");
    TORCH_CHECK(add_input.dim() == 5, "add_input must be a 5D tensor");
    TORCH_CHECK(conv_out.sizes() == add_input.sizes(), "Input tensors must have the same shape");
    
    // Allocate output tensor with the same shape and options as conv_out.
    auto output = torch::empty_like(conv_out);
    int total = output.numel();
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    // Launch the kernel on the current CUDA stream.
    fused_add_hardswish_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        conv_out.data_ptr<float>(),
        add_input.data_ptr<float>(),
        output.data_ptr<float>(),
        total);
        
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_forward", &fused_forward, "Fused elementwise addition and HardSwish activation");
}