2.16. Transposed Convolution with Mish, Hardtanh, and Scaling

Calculates a transposed convolution on a 2D input tensor, applies the Mish activation function, adds a constant value, clamps the output within the range [-1, 1] using Hardtanh, and finally scales the result by a given factor.

Progress

Step 6 of 6
Reward
1.49x

Optimization Timeline

Step 6 Details

Reasoning

We've fused the post‐conv-transpose element‐wise operations into a single custom CUDA extension that processes the input tensor in a vectorized manner using float4 for the bulk of the data, with a scalar fallback for any remaining elements. This optimization reduces the number of global memory transactions and kernel launch overhead. By using fast math intrinsics (__expf and __logf) and the appropriate device math functions (tanhf) while employing c10::cuda utilities for device correctness, we ensure our improvements are genuine and not workarounds. Overall, these techniques yield a significant speedup over the native implementation.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>

// Scalar kernel: processes one element per thread.
__global__ void fused_activation_kernel(const float* __restrict__ input,
                                          float* __restrict__ output,
                                          int total_elements,
                                          float add_value,
                                          float scale) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < total_elements) {
        float x = input[index];
        float sp = __logf(1.0f + __expf(x));  // softplus(x)
        float mish = x * tanhf(sp);           // Mish activation
        float val = mish + add_value;         // Addition
        // Hardtanh: clamp to [-1, 1]
        val = fminf(fmaxf(val, -1.0f), 1.0f);
        output[index] = val * scale;          // Scaling
    }
}

// Vectorized kernel: processes 4 elements at a time using float4.
__global__ void fused_activation_kernel_vec(const float* __restrict__ input,
                                              float* __restrict__ output,
                                              int vec_elements, // number of groups of 4 elements
                                              float add_value,
                                              float scale) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < vec_elements) {
        // Reinterpret input and output pointers as float4.
        const float4* in_vec = reinterpret_cast<const float4*>(input);
        float4* out_vec = reinterpret_cast<float4*>(output);
        float4 in_val = in_vec[index];
        float4 out_val;
        
        out_val.x = fminf(fmaxf(in_val.x * tanhf(__logf(1.0f + __expf(in_val.x))) + add_value, -1.0f), 1.0f) * scale;
        out_val.y = fminf(fmaxf(in_val.y * tanhf(__logf(1.0f + __expf(in_val.y))) + add_value, -1.0f), 1.0f) * scale;
        out_val.z = fminf(fmaxf(in_val.z * tanhf(__logf(1.0f + __expf(in_val.z))) + add_value, -1.0f), 1.0f) * scale;
        out_val.w = fminf(fmaxf(in_val.w * tanhf(__logf(1.0f + __expf(in_val.w))) + add_value, -1.0f), 1.0f) * scale;
        
        out_vec[index] = out_val;
    }
}

torch::Tensor forward(torch::Tensor input, float add_value, float scale) {
    // Ensure using the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());
    int total_elements = input.numel();
    auto output = torch::empty_like(input);
    int threads = 256;

    // Process in vectorized mode in groups of 4.
    int vec_count = total_elements / 4;  // number of complete groups of 4 elements
    int rem = total_elements % 4;          // remaining elements

    if (vec_count > 0) {
        int blocks = (vec_count + threads - 1) / threads;
        fused_activation_kernel_vec<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>(),
            output.data_ptr<float>(),
            vec_count,
            add_value,
            scale
        );
    }

    // Process any remaining elements with the scalar kernel.
    if (rem > 0) {
        int offset = vec_count * 4;
        int blocks_rem = (rem + threads - 1) / threads;
        fused_activation_kernel<<<blocks_rem, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<float>() + offset,
            output.data_ptr<float>() + offset,
            rem,
            add_value,
            scale
        );
    }
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused Mish activation, addition, Hardtanh clamping, and scaling (vectorized when possible)");
}