2.47. 3D Convolution with Mish and Tanh Operation
Performs a 3D convolution on volumetric input data, then applies the Mish activation function to introduce non-linearity, and finally applies the Tanh transformation to the output. The operation supports configurable kernel size, stride, and padding.
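For reference, the activation stage applied to each element x of the convolution output can be written as follows. The values for x = 1 are a hand calculation included only as a worked example.

```latex
\begin{aligned}
\operatorname{softplus}(x) &= \ln\bigl(1 + e^{x}\bigr) \\
\operatorname{mish}(x) &= x \,\tanh\bigl(\operatorname{softplus}(x)\bigr) \\
y &= \tanh\bigl(\operatorname{mish}(x)\bigr)
\end{aligned}
\qquad
x = 1:\quad \operatorname{softplus}(1) \approx 1.3133,\quad \operatorname{mish}(1) \approx 0.8651,\quad y \approx 0.6989
```

Because the last step is a tanh, every output lies strictly between -1 and 1, and x = 0 maps to y = 0.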
Step 5 Details
Reasoning
We can further refine the CUDA kernel by moving the fused activation into an inline device function, fused_activation_value, which computes softplus, Mish, and the final tanh in a single step. Both the elementwise kernel and the vectorized float4 kernel call this function, so the activation logic is written only once. Because nvcc typically inlines small __device__ functions even without an explicit hint, this refactoring mainly removes code duplication and is unlikely by itself to change the generated code much. Any throughput gain is expected to come chiefly from the float4 kernel, which moves four values per thread with one 16-byte load and one 16-byte store instead of four separate 4-byte accesses.
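As a host-side illustration only (plain C++, not the CUDA code itself), the sketch below mirrors the structure described above: one inline helper called both by a one-element-per-iteration loop and by a four-element-per-iteration loop that stands in for the float4 kernel. The names `Vec4`, `fused_activation_value_host`, `apply_scalar`, and `apply_vec4` are hypothetical, and `std::exp`/`std::log`/`std::tanh` replace the device functions `__expf`/`__logf`/`tanhf`. Unlike the kernel's dispatch, which falls back to the scalar kernel whenever the size is not a multiple of 4, this sketch finishes the leftover elements with the scalar helper; that tail handling is an assumed alternative, not something the original code does.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Host stand-in for fused_activation_value: tanh(mish(v)), mish(v) = v * tanh(softplus(v)).
inline float fused_activation_value_host(float v) {
    float sp = std::log(1.0f + std::exp(v));  // softplus, same form as the kernel
    float mish = v * std::tanh(sp);
    return std::tanh(mish);
}

// Hypothetical 4-wide value type standing in for CUDA's float4.
struct Vec4 { float x, y, z, w; };

// Scalar path: one element per iteration (mirrors fused_activation_kernel).
std::vector<float> apply_scalar(const std::vector<float>& in) {
    std::vector<float> out(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) {
        out[i] = fused_activation_value_host(in[i]);
    }
    return out;
}

// Four-wide path: four elements per iteration (mirrors fused_activation_vec_kernel),
// followed by a scalar tail for the remaining size % 4 elements.
std::vector<float> apply_vec4(const std::vector<float>& in) {
    std::vector<float> out(in.size());
    const std::size_t n_vec = in.size() / 4;
    for (std::size_t i = 0; i < n_vec; ++i) {
        Vec4 v{in[4 * i], in[4 * i + 1], in[4 * i + 2], in[4 * i + 3]};
        Vec4 r{fused_activation_value_host(v.x), fused_activation_value_host(v.y),
               fused_activation_value_host(v.z), fused_activation_value_host(v.w)};
        out[4 * i] = r.x;
        out[4 * i + 1] = r.y;
        out[4 * i + 2] = r.z;
        out[4 * i + 3] = r.w;
    }
    for (std::size_t i = 4 * n_vec; i < in.size(); ++i) {  // tail elements
        out[i] = fused_activation_value_host(in[i]);
    }
    return out;
}
```

Since both paths call the same helper, they produce the same values; the four-wide loop only changes how many elements each iteration (or, on the GPU, each thread) handles.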
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <cstdint>

// Inline device function for the fused activation computation.
// For each element v, it computes:
//   sp   = __logf(1.0f + __expf(v))  (softplus, using fast-math intrinsics)
//   mish = v * tanhf(sp)
//   return tanhf(mish)
__device__ inline float fused_activation_value(float v) {
    float sp = __logf(1.0f + __expf(v));
    float mish = v * tanhf(sp);
    return tanhf(mish);
}

// Elementwise fused activation kernel: processes one element per thread.
// 64-bit indices avoid overflow for tensors with more than 2^31 - 1 elements.
__global__ void fused_activation_kernel(const float* __restrict__ input,
                                        float* __restrict__ output,
                                        int64_t total) {
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < total) {
        output[idx] = fused_activation_value(input[idx]);
    }
}

// Vectorized fused activation kernel: processes 4 elements per thread using float4.
__global__ void fused_activation_vec_kernel(const float4* __restrict__ input,
                                            float4* __restrict__ output,
                                            int64_t total_vec) {
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < total_vec) {
        float4 in_val = input[idx];
        float4 out_val;
        out_val.x = fused_activation_value(in_val.x);
        out_val.y = fused_activation_value(in_val.y);
        out_val.z = fused_activation_value(in_val.z);
        out_val.w = fused_activation_value(in_val.w);
        output[idx] = out_val;
    }
}

torch::Tensor fused_activation(torch::Tensor x) {
    TORCH_CHECK(x.is_cuda(), "fused_activation: input must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32,
                "fused_activation: input must be float32");
    c10::cuda::CUDAGuard device_guard(x.device());

    // Both kernels index memory linearly, so the input must be densely packed.
    x = x.contiguous();
    auto output = torch::empty_like(x);
    const int64_t total = x.numel();
    if (total == 0) {
        return output;  // launching a grid with 0 blocks is an invalid configuration
    }

    const int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    const float* in_ptr = x.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    // float4 accesses require 16-byte alignment; a view with a storage offset may lack it.
    const bool aligned = reinterpret_cast<std::uintptr_t>(in_ptr) % 16 == 0 &&
                         reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0;

    if (total % 4 == 0 && aligned) {
        // Vectorized path: one float4 (4 floats) per thread.
        const int64_t total_vec = total / 4;
        const int blocks = static_cast<int>((total_vec + threads - 1) / threads);
        fused_activation_vec_kernel<<<blocks, threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr), total_vec);
    } else {
        // Fallback path: one float per thread.
        const int blocks = static_cast<int>((total + threads - 1) / threads);
        fused_activation_kernel<<<blocks, threads, 0, stream>>>(in_ptr, out_ptr, total);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_activation", &fused_activation,
          "Fused Mish and Tanh activation (vectorized if possible) using fast math intrinsics");
}
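When validating the extension against a CPU result, it helps to know where the kernel's softplus, log(1 + exp(v)), is least accurate: for strongly negative v the sum 1 + exp(v) rounds to 1 in single precision, and for large positive v it computes exp(v) only to take its logarithm again. The sketch below is an assumed host-side reference, not part of the original code. It computes softplus as `log1p(exp(v))` and returns v directly above a threshold of 20 (the default `threshold` of PyTorch's `softplus`), then measures the largest difference from the naive form over a sampled range. The function names, threshold, and range are illustrative choices.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Naive form, matching the kernel's formula: softplus = log(1 + exp(v)).
inline float activation_naive(float v) {
    float sp = std::log(1.0f + std::exp(v));
    return std::tanh(v * std::tanh(sp));
}

// Reference form: log1p keeps precision when exp(v) is tiny, and above the
// threshold softplus(v) is taken as v, so exp(v) is never evaluated there.
inline float activation_reference(float v) {
    float sp = (v > 20.0f) ? v : std::log1p(std::exp(v));
    return std::tanh(v * std::tanh(sp));
}

// Largest absolute difference between the two forms, sampled on [lo, hi].
float max_abs_diff(float lo, float hi, float step) {
    const int n = static_cast<int>((hi - lo) / step);
    float worst = 0.0f;
    for (int i = 0; i <= n; ++i) {
        const float v = lo + static_cast<float>(i) * step;
        worst = std::max(worst, std::fabs(activation_naive(v) - activation_reference(v)));
    }
    return worst;
}
```

Because the final tanh compresses small errors, the two forms stay close over this range, so a tolerance on the order of 1e-5 is a plausible starting point for comparing the kernel output to a reference (the fast `__expf`/`__logf` intrinsics may add some error of their own). If accuracy for strongly negative inputs ever matters on the device, CUDA's math library also provides `log1pf`, likely at some cost in speed relative to the `__logf` intrinsic.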