2.82. Conv2d-Tanh-Scaling with Bias Addition and Max-Pooling
Applies a sequence of operations to an input tensor: it first computes a 2D convolution, applies the tanh activation to the result, scales the activated output by a constant factor, adds a broadcastable bias term, and finally performs max-pooling to aggregate local features.
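The sequence above can be written as a plain CPU reference, which is also useful for validating a fused kernel. This is a hypothetical sketch, not code from the original submission. The name reference_forward is illustrative, and the sketch assumes float NCHW data in flat std::vector buffers, a stride-1 convolution without padding that includes a per-channel conv bias, a bias with one value per output channel, and non-overlapping max-pooling whose output size is floored.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cmath>
#include <vector>

// Hypothetical CPU reference for conv2d -> tanh -> scale -> +bias -> max-pool.
// Assumptions: NCHW layout, stride-1 "valid" convolution (no padding),
// per-output-channel conv bias, non-overlapping pool x pool windows.
std::vector<float> reference_forward(const std::vector<float>& x, int B, int Cin, int H, int W,
                                     const std::vector<float>& weight,     // [Cout, Cin, K, K]
                                     const std::vector<float>& conv_bias,  // [Cout]
                                     int Cout, int K, float scaling_factor,
                                     const std::vector<float>& bias,       // [Cout], broadcast over H, W
                                     int pool, int& Hp, int& Wp) {
    const int Hc = H - K + 1, Wc = W - K + 1;  // convolution output size
    Hp = Hc / pool;                            // pooled output size (floored)
    Wp = Wc / pool;
    // Step 1: convolution followed by the elementwise epilogue tanh -> scale -> +bias.
    std::vector<float> act(static_cast<size_t>(B) * Cout * Hc * Wc);
    for (int b = 0; b < B; ++b)
        for (int co = 0; co < Cout; ++co)
            for (int i = 0; i < Hc; ++i)
                for (int j = 0; j < Wc; ++j) {
                    float acc = conv_bias[co];
                    for (int ci = 0; ci < Cin; ++ci)
                        for (int ki = 0; ki < K; ++ki)
                            for (int kj = 0; kj < K; ++kj)
                                acc += x[((b * Cin + ci) * H + i + ki) * W + j + kj] *
                                       weight[((co * Cin + ci) * K + ki) * K + kj];
                    act[((b * Cout + co) * Hc + i) * Wc + j] =
                        std::tanh(acc) * scaling_factor + bias[co];
                }
    // Step 2: non-overlapping max-pooling over pool x pool windows.
    std::vector<float> out(static_cast<size_t>(B) * Cout * Hp * Wp, -FLT_MAX);
    for (int b = 0; b < B; ++b)
        for (int co = 0; co < Cout; ++co)
            for (int ph = 0; ph < Hp; ++ph)
                for (int pw = 0; pw < Wp; ++pw) {
                    float& m = out[((b * Cout + co) * Hp + ph) * Wp + pw];
                    for (int i = 0; i < pool; ++i)
                        for (int j = 0; j < pool; ++j)
                            m = std::max(m, act[((b * Cout + co) * Hc + ph * pool + i) * Wc +
                                                pw * pool + j]);
                }
    return out;
}
```

For example, a 3×3 input of zeros with a 2×2 all-ones kernel, zero conv bias, scaling factor 2, bias 0.5 and pool 2 yields a single output of 0.5, since tanh(0) = 0.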
Step 2 Details
Reasoning
We can further optimize the fused kernel by specializing it for a pooling window of size 2, the most common configuration. The generic implementation iterates over the pooling window with two nested loops; for pool == 2 we unroll these loops by hand, removing the loop control and the repeated index arithmetic. The unrolled path computes the four input indices of the 2×2 window, loads the four values, applies tanhf, the scaling factor, and the per-channel bias (loaded once per thread), and then takes the maximum of the four results with fmaxf. Other pool sizes fall back to the original nested-loop path.

We use c10::cuda::CUDAGuard to make the input tensor's device current and launch the kernel on the current CUDA stream. The CUDA code is compiled with the -O3 and --use_fast_math flags, so tanhf may be evaluated with faster, lower-precision math. The Python module remains largely the same: it loads the extension and calls the kernel. These changes should reduce per-thread overhead in the common case and may yield a faster runtime, although the actual speedup should be confirmed by profiling.
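The claim that the unrolled path computes the same result as the generic loop can be checked on the host. The sketch below mirrors the kernel's per-thread work in plain C++; the function names window_max_generic and window_max_pool2 are hypothetical, and std::tanh and std::max stand in for tanhf and fmaxf. For pool == 2 both paths should agree up to floating-point rounding.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cmath>
#include <vector>

// Generic path: loop over the pool x pool window, as in the kernel's else-branch.
float window_max_generic(const std::vector<float>& in, int C, int H, int W,
                         int b, int c, int ph, int pw, int pool, float scale, float bias_val) {
    float max_val = -FLT_MAX;
    for (int i = 0; i < pool; ++i)
        for (int j = 0; j < pool; ++j) {
            int idx = ((b * C + c) * H + ph * pool + i) * W + pw * pool + j;
            max_val = std::max(max_val, std::tanh(in[idx]) * scale + bias_val);
        }
    return max_val;
}

// Unrolled path for pool == 2: four explicit loads, no loop control.
float window_max_pool2(const std::vector<float>& in, int C, int H, int W,
                       int b, int c, int ph, int pw, float scale, float bias_val) {
    int row0 = ((b * C + c) * H + ph * 2) * W;  // offset of the window's top row
    int row1 = row0 + W;                        // offset of the window's bottom row
    int col0 = pw * 2;                          // left column of the window
    float a = std::tanh(in[row0 + col0]) * scale + bias_val;
    float b_val = std::tanh(in[row0 + col0 + 1]) * scale + bias_val;
    float c_val = std::tanh(in[row1 + col0]) * scale + bias_val;
    float d = std::tanh(in[row1 + col0 + 1]) * scale + bias_val;
    return std::max(std::max(a, b_val), std::max(c_val, d));
}
```

Both functions evaluate the same four activated values; the unrolled version only replaces loop control with fixed offsets, which is the saving the reasoning above describes.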
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>
#include <math.h>
#include <float.h>
// Fused kernel that applies tanh activation, scaling, bias addition, and max pooling.
// Input 'input' is of shape [B, C, H, W].
// 'bias' is a per-channel bias of shape [C] (broadcasted across spatial dims).
// Pooling window is of size 'pool' x 'pool' (non-overlapping).
// Output is [B, C, H_out, W_out] where H_out = H / pool and W_out = W / pool.
__global__ void fused_kernel(const float* __restrict__ input,
                             float* __restrict__ output,
                             const float* __restrict__ bias,
                             int B, int C, int H, int W,
                             int pool,
                             float scaling_factor) {
    int H_out = H / pool;
    int W_out = W / pool;
    int total = B * C * H_out * W_out;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= total) return;
    // Decode the flat output index into (b, c, ph, pw).
    int pw = index % W_out;
    int tmp = index / W_out;
    int ph = tmp % H_out;
    tmp = tmp / H_out;
    int c = tmp % C;
    int b = tmp / C;
    // Calculate starting indices for pooling window in input.
    int start_row = ph * pool;
    int start_col = pw * pool;
    float bias_val = bias[c];
    float max_val = -FLT_MAX;
    // Specialize for pool size 2 (common case) with loop unrolling.
    if (pool == 2) {
        int in_row0 = start_row;
        int in_row1 = start_row + 1;
        int in_col0 = start_col;
        int in_col1 = start_col + 1;
        int base_idx = ((b * C + c) * H);
        int idx00 = (base_idx + in_row0) * W + in_col0;
        int idx01 = (base_idx + in_row0) * W + in_col1;
        int idx10 = (base_idx + in_row1) * W + in_col0;
        int idx11 = (base_idx + in_row1) * W + in_col1;
        float a = tanhf(input[idx00]) * scaling_factor + bias_val;
        float b_val = tanhf(input[idx01]) * scaling_factor + bias_val;
        float c_val = tanhf(input[idx10]) * scaling_factor + bias_val;
        float d = tanhf(input[idx11]) * scaling_factor + bias_val;
        max_val = fmaxf(fmaxf(a, b_val), fmaxf(c_val, d));
    } else {
        // Generic case for other pool sizes.
        for (int i = 0; i < pool; ++i) {
            for (int j = 0; j < pool; ++j) {
                int in_row = start_row + i;
                int in_col = start_col + j;
                int input_idx = ((b * C + c) * H + in_row) * W + in_col;
                float val = tanhf(input[input_idx]) * scaling_factor + bias_val;
                max_val = fmaxf(max_val, val);
            }
        }
    }
    int output_idx = ((b * C + c) * H_out + ph) * W_out + pw;
    output[output_idx] = max_val;
}
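torch::Tensor forward(torch::Tensor input, int pool, float scaling_factor, torch::Tensor bias_tensor) {
    // The kernel indexes raw float pointers, so require float32 CUDA tensors in NCHW layout.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 4, "input must have shape [B, C, H, W]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(pool > 0, "pool must be positive");
    c10::cuda::CUDAGuard device_guard(input.device());
    input = input.contiguous();
    int B = input.size(0);
    int C = input.size(1);
    int H = input.size(2);
    int W = input.size(3);
    // The bias holds one value per channel (e.g. shape [C] or [C, 1, 1]) and is read as bias[c].
    TORCH_CHECK(bias_tensor.device() == input.device() && bias_tensor.scalar_type() == torch::kFloat32,
                "bias must be a float32 tensor on the same device as input");
    TORCH_CHECK(bias_tensor.numel() == C, "bias must have one element per channel");
    bias_tensor = bias_tensor.contiguous();
    // The kernel uses 32-bit indices.
    TORCH_CHECK(input.numel() <= std::numeric_limits<int>::max(),
                "input is too large for 32-bit indexing");
    int H_out = H / pool;
    int W_out = W / pool;
    auto output = torch::empty({B, C, H_out, W_out}, input.options());
    int total = B * C * H_out * W_out;
    if (total == 0) return output;  // e.g. H < pool: a zero-block launch would be an error
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    fused_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        bias_tensor.data_ptr<float>(),
        B, C, H, W,
        pool,
        scaling_factor
    );
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_kernel launch failed: ", cudaGetErrorString(err));
    return output;
}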
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused tanh, scaling, bias addition, and max pooling");
}
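The kernel assigns one thread per output element and decodes the thread's flat index into (b, c, ph, pw), while forward() sizes the grid with a ceiling division. The hypothetical host-side helpers below (decode_output_index, encode_output_index, num_blocks and check_round_trip are illustrative names, not part of the extension) reproduce that arithmetic so it can be checked without a GPU.

```cpp
#include <cassert>

// Output coordinates of one pooled element.
struct OutCoord {
    int b, c, ph, pw;
};

// Decode a flat output index into (b, c, ph, pw), using the same steps as fused_kernel.
OutCoord decode_output_index(int index, int C, int H_out, int W_out) {
    OutCoord o;
    o.pw = index % W_out;
    int tmp = index / W_out;
    o.ph = tmp % H_out;
    tmp /= H_out;
    o.c = tmp % C;
    o.b = tmp / C;
    return o;
}

// Re-encode the coordinates; this matches output_idx in the kernel.
int encode_output_index(const OutCoord& o, int C, int H_out, int W_out) {
    return ((o.b * C + o.c) * H_out + o.ph) * W_out + o.pw;
}

// Ceiling division used to size the grid so every output element gets a thread.
int num_blocks(int total, int threads) {
    return (total + threads - 1) / threads;
}

// Assert that decoding and re-encoding round-trips every valid output index.
void check_round_trip(int B, int C, int H_out, int W_out) {
    int total = B * C * H_out * W_out;
    for (int index = 0; index < total; ++index) {
        OutCoord o = decode_output_index(index, C, H_out, W_out);
        assert(o.b >= 0 && o.b < B);
        assert(encode_output_index(o, C, H_out, W_out) == index);
    }
}
```

For example, with B = 128, C = 16 and H_out = W_out = 15 (illustrative values), there are 460,800 output elements and num_blocks(460800, 256) returns 1800, so every launched thread maps to exactly one output element.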