2.46. Convolution Subtraction Tanh Average Pooling

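Computes a 2D convolution on the input, subtracts a constant, applies a tanh activation, subtracts a second constant, and then downsamples the result with average pooling over non-overlapping pool_size × pool_size windows. The CUDA extension below fuses every step after the convolution into a single kernel; the convolution output is passed to it as input.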

Progress

Step 3 of 3
Reward
1.67x

Optimization Timeline

Step 3 Details

Reasoning

We can further optimize the fused CUDA kernel by loading the convolution output through the `__ldg` intrinsic. This routes these read-only loads through the read-only data cache, which can reduce load latency. In the generic pooling branch, the averaging factor 1/(pool_size²) is now computed once per thread, outside the pooling loop. The unrolled fast path for the common case pool_size == 2 also loads its four inputs with `__ldg`. These changes reduce memory latency and loop overhead in the pooling step. The result is mathematically equivalent, although multiplying by a precomputed reciprocal can differ in the last bit from dividing by the window area when that area is not a power of two.

// Include necessary headers.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <limits>
#include <math.h>

// Fused kernel over a conv2d output laid out as dense row-major NCHW
// [batch, channels, H_conv, W_conv]. Each output element covers one non-overlapping
// pool_size x pool_size window (stride == pool_size, no padding; trailing rows/columns
// that do not fill a whole window are skipped) and computes:
//   1. subtract subtract1 from each element,
//   2. apply tanhf,
//   3. subtract subtract2,
//   4. average over the window.
// Output is dense [batch, channels, H_conv / pool_size, W_conv / pool_size].
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Threads walk the output with a
// grid-stride loop, so any grid size (even a single thread) covers every element.
// Offsets are computed in 64-bit so tensors with more than 2^31 elements index correctly.
// Input reads go through __ldg (read-only data cache load, SM35+).
__global__ void fused_subtanh_pool_kernel(const float* __restrict__ input,
                                          float* __restrict__ output,
                                          int batch,
                                          int channels,
                                          int H_conv,
                                          int W_conv,
                                          int pool_size,
                                          float subtract1,
                                          float subtract2) {
    // A non-positive window would divide by zero below; there is nothing to compute.
    if (pool_size <= 0)
        return;

    // Pooled output dimensions (floor division drops partial windows).
    const int64_t H_pool = H_conv / pool_size;
    const int64_t W_pool = W_conv / pool_size;
    const int64_t out_total = static_cast<int64_t>(batch) * channels * H_pool * W_pool;

    // 1 / window area, hoisted out of the element loop. Computed in float so a large
    // pool_size cannot overflow an int product.
    const float window = static_cast<float>(pool_size);
    const float factor = 1.0f / (window * window);

    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < out_total;
         index += stride) {
        // Decode the flat output index into (plane = b * channels + c, h_pool, w_pool).
        const int64_t w_pool = index % W_pool;
        const int64_t rest = index / W_pool;
        const int64_t h_pool = rest % H_pool;
        const int64_t plane = rest / H_pool;

        // Offset of the window's top-left element in the input.
        const int64_t base = (plane * H_conv + h_pool * pool_size) * W_conv
                             + w_pool * pool_size;

        if (pool_size == 2) {
            // Fully unrolled 2x2 window (common case).
            const float v0 = tanhf(__ldg(&input[base]) - subtract1) - subtract2;
            const float v1 = tanhf(__ldg(&input[base + 1]) - subtract1) - subtract2;
            const float v2 = tanhf(__ldg(&input[base + W_conv]) - subtract1) - subtract2;
            const float v3 = tanhf(__ldg(&input[base + W_conv + 1]) - subtract1) - subtract2;
            output[index] = (v0 + v1 + v2 + v3) * 0.25f;
        } else {
            // Generic window: walk pool_size rows of pool_size contiguous elements.
            float sum = 0.0f;
            for (int i = 0; i < pool_size; ++i) {
                const float* row = input + base + static_cast<int64_t>(i) * W_conv;
                for (int j = 0; j < pool_size; ++j)
                    sum += tanhf(__ldg(&row[j]) - subtract1) - subtract2;
            }
            output[index] = sum * factor;
        }
    }
}

// Host entry point for the fused subtract -> tanh -> subtract -> average-pool op.
//
// Args:
//   input:     rank-4 float32 CUDA tensor [batch, channels, H_conv, W_conv]
//              (intended to be a conv2d output). Non-contiguous inputs such as
//              channels_last are copied to contiguous NCHW first, because the
//              kernel indexes the buffer as dense row-major NCHW.
//   subtract1: constant subtracted from every element before tanh.
//   subtract2: constant subtracted after tanh.
//   pool_size: edge length (and stride) of the square pooling window; must be > 0.
//              Trailing rows/columns that do not fill a whole window are ignored.
// Returns:
//   float32 tensor [batch, channels, H_conv / pool_size, W_conv / pool_size] on
//   the input's device.
// Throws c10::Error (via TORCH_CHECK) on invalid arguments or a failed kernel launch.
torch::Tensor forward(torch::Tensor input,
                      float subtract1,
                      float subtract2,
                      int pool_size) {
    TORCH_CHECK(input.is_cuda(), "forward: input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 4,
                "forward: input must be 4-D [batch, channels, H, W], got ", input.dim(), "-D");
    TORCH_CHECK(input.scalar_type() == torch::kFloat,
                "forward: input must be float32, got ", input.scalar_type());
    TORCH_CHECK(pool_size > 0, "forward: pool_size must be positive, got ", pool_size);

    // Enforce operations on the input's CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel assumes dense NCHW strides; this is a no-op for contiguous inputs.
    input = input.contiguous();

    const int64_t batch = input.size(0);
    const int64_t channels = input.size(1);
    const int64_t H_conv = input.size(2);
    const int64_t W_conv = input.size(3);

    // The kernel receives each dimension as a 32-bit int; refuse silent narrowing.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(batch <= kIntMax && channels <= kIntMax && H_conv <= kIntMax && W_conv <= kIntMax,
                "forward: every input dimension must fit in a 32-bit int");

    // Pooling output dimensions (floor division).
    const int64_t H_pool = H_conv / pool_size;
    const int64_t W_pool = W_conv / pool_size;
    auto output = torch::empty({batch, channels, H_pool, W_pool}, input.options());

    const int64_t total = batch * channels * H_pool * W_pool;
    // A zero-block launch is an invalid configuration; there is nothing to compute.
    if (total == 0)
        return output;

    constexpr int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    TORCH_CHECK(blocks <= kIntMax, "forward: output too large for a 1-D grid");

    fused_subtanh_pool_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                                c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(batch),
        static_cast<int>(channels),
        static_cast<int>(H_conv),
        static_cast<int>(W_conv),
        pool_size,
        subtract1,
        subtract2);
    // Surface launch errors here instead of at some later, unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

// Python binding: registers `forward` in the extension module whose name is
// supplied at build time through TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("forward",
            &forward,
            "Fused subtraction, tanh, subtraction and average pooling (optimized with __ldg and unrolling for pool_size==2)");
}