2.35. Convolution with Subtraction and Compound Activation

Performs a 2D convolution on multi-channel input data, subtracts a constant value from the resulting feature map, applies a HardSwish non-linearity, downsamples the data using max pooling, and then refines the output with a Mish activation.
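For reference, the operation chain being fused can be written directly with standard PyTorch ops. The sketch below is illustrative only; the channel counts, kernel size, and constants are assumptions, not values taken from this task.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ReferenceModel(nn.Module):
    """Unfused reference chain: conv -> subtract -> HardSwish -> MaxPool -> Mish."""
    def __init__(self, in_channels=3, out_channels=16, kernel_size=3,
                 subtract_value=0.5, pool_kernel_size=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.subtract_value = subtract_value
        self.pool_kernel_size = pool_kernel_size

    def forward(self, x):
        x = self.conv(x)                            # [B, C, H_in, W_in]
        x = x - self.subtract_value                 # subtract a constant
        x = F.hardswish(x)                          # x * clamp(x + 3, 0, 6) / 6
        x = F.max_pool2d(x, self.pool_kernel_size)  # downsample
        x = F.mish(x)                               # x * tanh(softplus(x))
        return x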

Progress: Step 3 of 3
Reward: 2.00x

Optimization Timeline

Step 3 Details

Reasoning

Our kernel already fuses the subtraction, HardSwish, MaxPool, and Mish activation, and the specialized branch for a pool_kernel_size of 2 is unrolled. One further improvement is to route all reads from the input tensor through the __ldg intrinsic so they are served by the read-only data cache, which can reduce global memory traffic and latency. We update both the specialized (pool_kernel_size == 2) branch and the generic branch to use __ldg for every access to input. The change does not alter the mathematical results, and the Python module remains unchanged.
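The Python module itself is not reproduced here. As a rough sketch only, a wrapper might keep the convolution in PyTorch and hand the rest of the chain to the compiled extension; the names FusedModel and fused_ops below are placeholders, not names from the original code.

import torch
import torch.nn as nn
import fused_ops  # placeholder name for the compiled extension defined below

class FusedModel(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 subtract_value, pool_kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.subtract_value = subtract_value
        self.pool_kernel_size = pool_kernel_size

    def forward(self, x):
        x = self.conv(x)  # convolution stays in PyTorch/cuDNN
        # Single fused kernel for subtract + HardSwish + MaxPool + Mish.
        return fused_ops.forward(x.contiguous(), self.subtract_value,
                                 self.pool_kernel_size)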

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>

// Device helper: HardSwish activation function.
__device__ inline float hard_swish(float x) {
    float tmp = fminf(fmaxf(x + 3.0f, 0.0f), 6.0f);
    return x * tmp / 6.0f;
}

// Device helper: Mish activation function.
__device__ inline float mish(float x) {
    float softplus = logf(1.0f + expf(x));
    return x * tanhf(softplus);
}

// Fused kernel for subtraction, HardSwish, MaxPool, and Mish.
// Input: tensor from convolution of shape [B, C, H_in, W_in].
// Output: tensor of shape [B, C, H_out, W_out] where H_out = H_in / pool_kernel_size and W_out = W_in / pool_kernel_size.
__global__ void fused_chain_kernel(const float* __restrict__ input,
                                   float* __restrict__ output,
                                   float subtract_value,
                                   int pool_kernel_size,
                                   int B, int C, int H_in, int W_in) {
    // Compute output dimensions.
    int H_out = H_in / pool_kernel_size;
    int W_out = W_in / pool_kernel_size;
    // Total number of output elements.
    int total = B * C * H_out * W_out;
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= total) return;

    // Map linear index to 4D indices: b, c, ph, pw.
    int pw = tid % W_out;
    int tmp = tid / W_out;
    int ph = tmp % H_out;
    tmp = tmp / H_out;
    int c = tmp % C;
    int b = tmp / C;

    int row_start = ph * pool_kernel_size;
    int col_start = pw * pool_kernel_size;
    float max_val = -FLT_MAX;
    
    if (pool_kernel_size == 2) {
        // Specialized unrolled branch for pool_kernel_size == 2.
        int row0 = row_start;
        int row1 = row_start + 1;
        int col0 = col_start;
        int col1 = col_start + 1;
        int base = ((b * C + c) * H_in);
        int idx0 = (base + row0) * W_in + col0;
        int idx1 = (base + row0) * W_in + col1;
        int idx2 = (base + row1) * W_in + col0;
        int idx3 = (base + row1) * W_in + col1;
        float a0 = hard_swish(__ldg(&input[idx0]) - subtract_value);
        float a1 = hard_swish(__ldg(&input[idx1]) - subtract_value);
        float a2 = hard_swish(__ldg(&input[idx2]) - subtract_value);
        float a3 = hard_swish(__ldg(&input[idx3]) - subtract_value);
        max_val = fmaxf(fmaxf(a0, a1), fmaxf(a2, a3));
    } else {
        // Generic branch for any pool_kernel_size.
        for (int i = 0; i < pool_kernel_size; i++) {
            for (int j = 0; j < pool_kernel_size; j++) {
                int row = row_start + i;
                int col = col_start + j;
                int index = ((b * C + c) * H_in + row) * W_in + col;
                float val = __ldg(&input[index]) - subtract_value;
                float activated = hard_swish(val);
                max_val = fmaxf(max_val, activated);
            }
        }
    }
    
    // Apply Mish activation on the pooled maximum.
    float final_val = mish(max_val);

    // Write result to the output tensor at [b, c, ph, pw].
    int out_index = ((b * C + c) * H_out + ph) * W_out + pw;
    output[out_index] = final_val;
}

torch::Tensor forward(torch::Tensor input, float subtract_value, int pool_kernel_size) {
    // Ensure operations occur on the proper CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel reads raw float data, so require a contiguous float32 CUDA tensor.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    input = input.contiguous();
    
    int B = input.size(0);
    int C = input.size(1);
    int H_in = input.size(2);
    int W_in = input.size(3);
    
    int H_out = H_in / pool_kernel_size;
    int W_out = W_in / pool_kernel_size;
    
    auto output = torch::empty({B, C, H_out, W_out}, input.options());
    
    int total = B * C * H_out * W_out;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    fused_chain_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        subtract_value,
        pool_kernel_size,
        B, C, H_in, W_in
    );
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused conv chain: subtract, hard_swish, max_pool, mish");
}
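One way to try the extension is to compile the source above on the fly with torch.utils.cpp_extension.load. The file name fused_chain.cu and the sample shapes below are assumptions for illustration.

import torch
from torch.utils.cpp_extension import load

# Build the CUDA source above as an extension (file name is an assumption).
fused_ops = load(name="fused_ops", sources=["fused_chain.cu"], verbose=True)

conv = torch.nn.Conv2d(3, 16, 3).cuda()
x = torch.randn(8, 3, 64, 64, device="cuda")
y = conv(x)                                      # [8, 16, 62, 62]
out = fused_ops.forward(y.contiguous(), 0.5, 2)  # fused subtract/HardSwish/pool/Mish
print(out.shape)                                 # torch.Size([8, 16, 31, 31])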