2.25. Conv-Min-Tanh Operation
Calculates a 2D convolution on input data using a specified kernel, computes the minimum value across the channel dimension, and then applies the hyperbolic tangent activation function twice.
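For reference, the operation in plain PyTorch might be sketched roughly as follows; the layer sizes (in_channels, out_channels, kernel_size) are illustrative assumptions, not values taken from this step.

import torch
import torch.nn as nn

# Rough reference sketch of the Conv-Min-Tanh operation (sizes are illustrative).
class ConvMinTanh(nn.Module):
    def __init__(self, in_channels=3, out_channels=16, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        x = self.conv(x)                           # [B, C, H, W]
        x = torch.min(x, dim=1, keepdim=True)[0]   # min over channels -> [B, 1, H, W]
        return torch.tanh(torch.tanh(x))           # tanh applied twice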
Step 4 Details
Reasoning
We can optimize further by fine-tuning the inner loop of our fused kernel. The revised version adds const qualifiers for loop invariants (such as the spatial size) and uses the #pragma unroll directive to encourage unrolling of the loop over the channel dimension. Since the channel count is typically small, unrolling can reduce loop overhead and improve instruction throughput. These modifications do not alter the mathematical result; they should reduce per-thread overhead and improve overall latency. The Python module remains largely unchanged, serving mainly to load the CUDA extension with the proper optimization flags.
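As a rough illustration of that Python side, the wrapper could load the extension and pair it with the convolution along the lines below; the file name, extension name, layer sizes, and compiler flags shown here are assumptions rather than details confirmed by this step.

import torch
import torch.nn as nn
from torch.utils.cpp_extension import load

# Hypothetical wrapper: compile the CUDA source at import time.
# The source path and flags are assumptions, not taken from the original step.
fused = load(
    name="fused_min_tanh",
    sources=["fused_min_tanh.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=False,
)

class ConvMinTanhFused(nn.Module):
    def __init__(self, in_channels=3, out_channels=16, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        x = self.conv(x).contiguous()  # convolution stays in PyTorch/cuDNN
        return fused.forward(x)        # fused channel-min + double tanh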
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
// Fused kernel: reduces input tensor x of shape [batch, channels, H, W]
// along the channel dimension by computing the minimum value, then applies two successive tanh operations.
// The output y has shape [batch, 1, H, W].
__global__ void fused_min_tanh_kernel(const float* __restrict__ x,
                                      float* __restrict__ y,
                                      const int batch,
                                      const int channels,
                                      const int H,
                                      const int W) {
    const int spatial = H * W;
    const int total = batch * spatial;
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        const int b = idx / spatial;
        const int pixel = idx % spatial;          // pixel == h * W + w
        const int base = b * channels * spatial;
        float min_val = x[base + pixel];          // channel 0
        #pragma unroll
        for (int c = 1; c < channels; c++) {
            min_val = fminf(min_val, x[base + c * spatial + pixel]);
        }
        // Two successive tanh activations applied to the channel-wise minimum.
        y[idx] = tanhf(tanhf(min_val));
    }
}
torch::Tensor forward(torch::Tensor x) {
    TORCH_CHECK(x.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 4, "input must have shape [batch, channels, H, W]");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "input must be float32");
    // Ensure the computation uses the proper CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());
    // The kernel assumes a contiguous NCHW layout.
    x = x.contiguous();
    const int batch = x.size(0);
    const int channels = x.size(1);
    const int H = x.size(2);
    const int W = x.size(3);
    const int spatial = H * W;
    // Allocate output tensor with shape [batch, 1, H, W].
    auto y = torch::empty({batch, 1, H, W}, x.options());
    const int total = batch * spatial;
    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;
    fused_min_tanh_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        y.data_ptr<float>(),
        batch,
        channels,
        H,
        W
    );
    return y;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused min reduction across channels with double tanh activation");
}
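One quick way to sanity-check the fused path is to compare it against the equivalent unfused PyTorch ops on random data. The sketch below assumes the fused extension handle from the loading example above and a CUDA device; the tensor sizes are arbitrary.

import torch

# Hypothetical smoke test: compare the fused kernel with plain PyTorch ops.
x = torch.randn(4, 16, 32, 32, device="cuda", dtype=torch.float32)

ref = torch.tanh(torch.tanh(torch.min(x, dim=1, keepdim=True)[0]))
out = fused.forward(x)  # `fused` is the extension handle loaded earlier

assert out.shape == ref.shape
assert torch.allclose(out, ref, atol=1e-5), "fused kernel disagrees with reference"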