2.29. Matrix Multiplication and Double Mish Activation

Applies a linear transformation (a matrix multiplication) to the input data, then applies the Mish activation function twice in sequence to the result.
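
A plain CPU loop gives a reference for checking the kernel's output against the whole operator. This is a minimal sketch, not the benchmark's own reference implementation. It makes these assumptions:

- Shapes follow nn.Linear: x is [batch, in_features], a weight W is [out_features, in_features] and a bias b is [out_features].
- All buffers are row-major float.
- The helper names mish_ref and linear_double_mish_ref are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the whole operator:
//   y = mish(mish(x * W^T + b))
// Assumed shapes (nn.Linear convention), all row-major float buffers:
//   x: [batch, in_features], W: [out_features, in_features], b: [out_features]

// mish(x) = x * tanh(softplus(x)), with softplus(x) = log(1 + exp(x)).
float mish_ref(float x) {
    return x * std::tanh(std::log1p(std::exp(x)));
}

std::vector<float> linear_double_mish_ref(const std::vector<float>& x,
                                          const std::vector<float>& W,
                                          const std::vector<float>& b,
                                          std::size_t batch,
                                          std::size_t in_features,
                                          std::size_t out_features) {
    std::vector<float> y(batch * out_features);
    for (std::size_t n = 0; n < batch; ++n) {
        for (std::size_t o = 0; o < out_features; ++o) {
            // Linear step: dot product of row n of x with row o of W, plus bias.
            float acc = b[o];
            for (std::size_t i = 0; i < in_features; ++i) {
                acc += x[n * in_features + i] * W[o * in_features + i];
            }
            // Activation step: Mish applied twice in sequence.
            y[n * out_features + o] = mish_ref(mish_ref(acc));
        }
    }
    return y;
}
```

Consider batch = 1, in_features = 2, x = {1, 2}, W = {1, 1} and b = {-3}. The linear step yields 0, and since mish(0) = 0, the final output is 0.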

Step 2 of 2
Reward: 1.30x

Step 2 Details

Reasoning

We can further optimize the fused double-Mish kernel by factoring the Mish computation out into an inline device function. This lets the compiler optimize both applications of Mish within a single kernel body. In the updated kernel, the __device__ __forceinline__ function mish_activation computes

    mish(x) = x * tanh(softplus(x)) = x * tanh(log1pf(__expf(x)))

It uses the fast intrinsic __expf together with the standard single-precision math functions log1pf and tanhf. The kernel then calls mish_activation twice per element. This removes duplicated code and may improve instruction-level parallelism, which could reduce runtime compared to our previous kernel.

The Python module remains essentially the same: it loads the extension and calls the fused operator in the forward method.

Have my improvements genuinely helped? Probably, but most of the benefit comes from fusion rather than from inlining as such. Compared with two separate Mish kernel calls, the fused kernel:

- reads and writes each element in global memory once instead of twice;
- launches one kernel instead of two.

Force-inlining mish_activation guarantees that both Mish evaluations are compiled into one body, where the compiler can optimize them together. However, nvcc usually inlines small __device__ functions anyway, so any additional speedup from this step should be confirmed by timing.
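
The formula can be checked on the host before running anything on the GPU. The sketch below is a stand-in for mish_activation and rests on two assumptions:

- std::exp replaces __expf, so it does not reproduce the intrinsic's reduced precision.
- The names mish_naive, mish_stable and double_mish_scalar are hypothetical.

The sketch compares the direct formula with an overflow-free softplus, max(x, 0) + log1p(exp(-|x|)). This shows that the overflow of exp for large x is harmless here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Host-side stand-in for the device function mish_activation.
// std::exp is used in place of the CUDA intrinsic __expf (an assumption:
// __expf is faster but less accurate on the GPU).
float mish_naive(float x) {
    float sp = std::log1p(std::exp(x));  // softplus(x); inf once exp(x) overflows
    return x * std::tanh(sp);            // tanh(inf) == 1, so the result is still x
}

// Overflow-free softplus: log(1 + e^x) = max(x, 0) + log1p(e^{-|x|}).
float mish_stable(float x) {
    float sp = std::max(x, 0.0f) + std::log1p(std::exp(-std::fabs(x)));
    return x * std::tanh(sp);
}

// Double Mish, as the fused kernel computes it for each element.
float double_mish_scalar(float x) {
    return mish_naive(mish_naive(x));
}
```

For x = 100, exp(x) overflows float to inf, yet both variants return exactly 100 because tanh(inf) = 1. For moderate inputs the two variants agree to within float rounding.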

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>

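// Inline device function computing the Mish activation:
//   mish(x) = x * tanh(softplus(x)) = x * tanh(log1pf(exp(x)))
// __expf is the fast, lower-precision CUDA intrinsic for exp. For large x it
// overflows to inf, but log1pf(inf) = inf and tanhf(inf) = 1, so the result
// is still x, which matches the true limit of Mish.
__device__ __forceinline__ float mish_activation(float x) {
    float sp = log1pf(__expf(x));  // softplus(x)
    return x * tanhf(sp);
}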

// Fused kernel that applies Mish activation twice to each element in the input tensor.
// For each element x, compute:
//    m1 = mish_activation(x)
//    m2 = mish_activation(m1)
// and store m2 in the output.
__global__ void double_mish_kernel(const float* __restrict__ input,
                                   float* __restrict__ output,
                                   int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = idx; i < total; i += stride) {
        float x = input[i];
        float m1 = mish_activation(x);
        float m2 = mish_activation(m1);
        output[i] = m2;
    }
}

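torch::Tensor double_mish(torch::Tensor input) {
    // The kernel reads raw float32 device memory, so validate the input first.
    TORCH_CHECK(input.is_cuda(), "double_mish: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "double_mish: input must be float32");

    // Ensure the operations run on the correct CUDA device.
    c10::cuda::CUDAGuard guard(input.device());

    // The kernel indexes a flat buffer, so force a contiguous layout.
    input = input.contiguous();
    auto output = torch::empty_like(input);

    const int64_t total = input.numel();
    // Launching zero blocks is an invalid configuration, so skip empty tensors.
    if (total == 0) {
        return output;
    }
    // The kernel uses 32-bit indices; keep headroom for i += stride.
    TORCH_CHECK(total <= std::numeric_limits<int>::max() / 2,
                "double_mish: tensor has too many elements for 32-bit indexing");

    const int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);

    double_mish_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(total));
    // Report launch errors here rather than at a later synchronization point.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "double_mish_kernel launch failed: ", cudaGetErrorString(err));

    return output;
}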

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("double_mish", &double_mish, "Fused CUDA kernel applying Mish activation twice with inline optimization");
}
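
The host wrapper launches ceil(total / threads) blocks, and the kernel uses a grid-stride loop. As a result, every element should be visited exactly once, even if fewer blocks were launched. A small host-side simulation makes that indexing concrete. The helpers simulate_grid_stride, blocks_for and covers_exactly_once are hypothetical. The capped-grid case is an illustrative variant, not what double_mish itself does.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side simulation of the indexing in double_mish_kernel: each
// (block, thread) pair starts at blockIdx * blockDim + threadIdx and advances
// by blockDim * gridDim. Returns how many times each element index is visited.
std::vector<int> simulate_grid_stride(int total, int blocks, int threads) {
    std::vector<int> visits(static_cast<std::size_t>(total), 0);
    const int stride = threads * blocks;
    for (int b = 0; b < blocks; ++b) {
        for (int t = 0; t < threads; ++t) {
            for (int i = b * threads + t; i < total; i += stride) {
                ++visits[static_cast<std::size_t>(i)];
            }
        }
    }
    return visits;
}

// Launch size used by the host wrapper: one thread per element, rounded up.
int blocks_for(int total, int threads) {
    return (total + threads - 1) / threads;
}

// True if every element is processed exactly once.
bool covers_exactly_once(const std::vector<int>& visits) {
    for (int v : visits) {
        if (v != 1) return false;
    }
    return true;
}
```

With total = 1000 and 256 threads, blocks_for returns 4. A grid capped at 2 blocks (stride 512) still covers every index once. For total = 0, blocks_for returns 0, and CUDA rejects a launch with zero blocks as an invalid configuration, so an empty tensor needs an early return.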