2.93. Transposed Convolution with Offset, Clamp, GELU, and Scaling
Performs a 2D transposed convolution on the input to upsample spatial features, adds a constant offset, clamps each element from above by taking the minimum of the value and zero (so every clamped value is non-positive), applies the Gaussian Error Linear Unit (GELU) activation, and finally scales the activated output by a fixed multiplier.
Progress
Optimization Timeline
Step 3 Details
Reasoning
We can further optimize our fused CUDA kernel by ensuring that critical device functions are inlined. Adding __forceinline__ to our fast_gelu() and fused_process() functions can reduce function call overhead during kernel execution. Additionally, we can add a #pragma unroll directive in our vectorized kernel loop for the four elements to encourage loop unrolling. These improvements, although subtle, help maximize throughput and reduce overhead, especially when running on many elements.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>
#include <limits>
// Tanh-based approximation of GELU:
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Declared __forceinline__ so the body is expanded at every call site.
__device__ __forceinline__ float fast_gelu(float x) {
    const float kSqrtTwoOverPi = 0.7978845608028654f;  // sqrt(2/pi)
    const float kCubicCoeff = 0.044715f;               // weight of the x^3 term

    // One fused multiply-add evaluates x + kCubicCoeff * x^3.
    const float poly = fmaf(kCubicCoeff, x * x * x, x);
    // tanhf is the single-precision tanh from the device math library.
    const float tanh_term = tanhf(kSqrtTwoOverPi * poly);
    return 0.5f * x * (1.0f + tanh_term);
}
// Post-op for one value: add the offset, take min(v, 0), run fast_gelu on the
// result, then multiply by the scale factor.
// Declared __forceinline__ so no function call is emitted per element.
__device__ __forceinline__ float fused_process(float in_val, float add_value, float multiply_value) {
    // fminf keeps non-positive values and maps positive ones to 0, branch-free.
    const float clamped = fminf(in_val + add_value, 0.0f);
    return fast_gelu(clamped) * multiply_value;
}
// Scalar kernel: one thread handles one element.
//
// Launch layout: 1D grid of 1D blocks with gridDim.x * blockDim.x >= total.
// Threads whose global index is >= total do nothing. No shared memory is used.
//
// input / output : arrays of at least `total` floats (declared __restrict__,
//                  so they must not alias).
// total          : number of elements to process.
// add_value      : offset added to every element before the clamp.
// multiply_value : scale applied after the activation.
__global__ void fused_post_ops_kernel(const float* __restrict__ input,
                                      float* __restrict__ output,
                                      int total,
                                      float add_value,
                                      float multiply_value) {
    // Build the global index in 64 bits. When total is close to INT_MAX the
    // rounded-up thread count exceeds the int range; a 32-bit index would wrap
    // to a negative value and slip past the bounds check below.
    const long long idx =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < total) {
        // Same per-element pipeline as the vectorized kernel: add, min with 0,
        // GELU, scale.
        output[idx] = fused_process(input[idx], add_value, multiply_value);
    }
}
// Vectorized kernel: each thread loads one float4, runs fused_process on its
// four components, and stores one float4.
//
// total4 is the number of float4 items (the float element count divided by 4).
// Threads whose global index is >= total4 exit immediately.
__global__ void fused_post_ops_vec_kernel(const float4* __restrict__ input,
                                          float4* __restrict__ output,
                                          int total4,
                                          float add_value,
                                          float multiply_value) {
    const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (vec_idx >= total4) {
        return;
    }

    const float4 packed = input[vec_idx];

    // Components are written out explicitly rather than looping over a
    // float* view of the vector.
    float4 result;
    result.x = fused_process(packed.x, add_value, multiply_value);
    result.y = fused_process(packed.y, add_value, multiply_value);
    result.z = fused_process(packed.z, add_value, multiply_value);
    result.w = fused_process(packed.w, add_value, multiply_value);

    output[vec_idx] = result;
}
// Host entry point for the fused post-ops.
//
// Launches one of the two kernels on the current CUDA stream of the input's
// device and returns a new tensor created with torch::empty_like. The launch
// is asynchronous; no device synchronization is performed here.
//
// input          : CUDA float32 tensor.
// add_value      : offset added to every element before the clamp.
// multiply_value : scale applied after the activation.
//
// Raises (via TORCH_CHECK) if the input is not a CUDA float32 tensor or has
// more elements than the kernels' int element count can represent, and (via
// C10_CUDA_KERNEL_LAUNCH_CHECK) if the kernel launch itself fails.
torch::Tensor forward(torch::Tensor input, float add_value, float multiply_value) {
    TORCH_CHECK(input.is_cuda(), "fused post-ops: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "fused post-ops: input must be a float32 tensor");

    // Enforce device correctness with CUDAGuard.
    c10::cuda::CUDAGuard device_guard(input.device());

    // The kernels walk the underlying memory linearly. That matches the
    // logical elements only when the tensor is non-overlapping and dense, in
    // which case empty_like reproduces the same strides for the output.
    // Anything else (e.g. a strided slice) is materialized first.
    if (!input.is_non_overlapping_and_dense()) {
        input = input.contiguous();
    }

    auto output = torch::empty_like(input);

    const int64_t numel = input.numel();
    if (numel == 0) {
        // Nothing to compute; a launch with zero blocks is an invalid
        // configuration, so skip it entirely.
        return output;
    }
    TORCH_CHECK(numel <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "fused post-ops: input has too many elements (", numel,
                ") for the 32-bit element count used by the kernels");
    const int total = static_cast<int>(numel);

    const int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    const float* in_ptr = input.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();

    // float4 loads/stores require 16-byte aligned addresses. A view with a
    // storage offset can start at an address that is only 4-byte aligned.
    const bool vec_aligned =
        reinterpret_cast<std::uintptr_t>(in_ptr) % alignof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % alignof(float4) == 0;

    if (total % 4 == 0 && vec_aligned) {
        // Vectorized path: four floats per thread.
        const int total4 = total / 4;
        const int blocks = (total4 + threads - 1) / threads;
        fused_post_ops_vec_kernel<<<blocks, threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr),
            total4, add_value, multiply_value);
    } else {
        // Scalar fallback: one float per thread. The ceil-div is written so
        // that it cannot overflow int when total is near INT_MAX.
        const int blocks = (total - 1) / threads + 1;
        fused_post_ops_kernel<<<blocks, threads, 0, stream>>>(
            in_ptr, out_ptr, total, add_value, multiply_value);
    }
    // Surface launch-configuration errors here instead of leaving them to be
    // reported by some later, unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}
// Python binding: registers the C++ forward() function under the name
// "forward" on the extension module named by TORCH_EXTENSION_NAME. The string
// argument is the docstring attached to the bound function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Fused elementwise post-operations (add, min, GELU, scale) with vectorized optimization and inlined functions");
}