2.4. 2D Convolution with Dual Mish Activation
Applies a 2D convolution to the input using the specified input/output channel counts and kernel size, then applies the Mish activation twice in sequence to the resulting feature map.
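As a correctness baseline for the whole operation, here is a naive host-side C++ sketch: a valid convolution (cross-correlation, as in PyTorch's conv2d) followed by Mish applied twice. It assumes a single image in CHW layout, stride 1, no padding, and a bias term; the names `mish_ref` and `conv2d_mish2_ref` are hypothetical and not part of the extension.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Mish activation: mish(x) = x * tanh(softplus(x)), softplus(x) = log(1 + exp(x)).
inline float mish_ref(float x) { return x * std::tanh(std::log1p(std::exp(x))); }

// Naive reference (assumptions: one image, CHW layout, stride 1, no padding):
//   in:     [c_in][h][w]
//   weight: [c_out][c_in][k][k]
//   bias:   [c_out]
//   out:    [c_out][h - k + 1][w - k + 1]
std::vector<float> conv2d_mish2_ref(const std::vector<float>& in, int c_in, int h, int w,
                                    const std::vector<float>& weight,
                                    const std::vector<float>& bias, int c_out, int k) {
    const int oh = h - k + 1, ow = w - k + 1;
    std::vector<float> out(static_cast<std::size_t>(c_out) * oh * ow);
    for (int co = 0; co < c_out; ++co)
        for (int y = 0; y < oh; ++y)
            for (int x = 0; x < ow; ++x) {
                float acc = bias[co];
                // Cross-correlation: the kernel is not flipped.
                for (int ci = 0; ci < c_in; ++ci)
                    for (int ky = 0; ky < k; ++ky)
                        for (int kx = 0; kx < k; ++kx)
                            acc += in[(ci * h + y + ky) * w + x + kx] *
                                   weight[((co * c_in + ci) * k + ky) * k + kx];
                // Dual Mish: the activation is applied twice in sequence.
                out[(co * oh + y) * ow + x] = mish_ref(mish_ref(acc));
            }
    return out;
}
```

For example, a 3x3 all-ones image convolved with a 2x2 all-ones kernel and zero bias gives a 2x2 map of 4.0 before activation, so every output equals `mish_ref(mish_ref(4.0f))`, slightly below 4.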
Step 3 Details
Reasoning
We can further reduce memory latency by loading inputs with __ldg, which routes global-memory reads through the read-only data cache. In the vectorized kernel we use the float4 overload of __ldg, so each thread still issues a single 128-bit load per four elements. Because both kernels already declare their input as const __restrict__, the compiler may select read-only loads on its own, so the explicit __ldg may bring only a small gain. We also mark the Mish helper __forceinline__, which guarantees it is inlined and rules out call overhead; nvcc usually inlines small __device__ functions anyway, so this mainly makes the intent explicit. The grid-stride loop and the split between the vectorized (float4) path and the element-wise fallback are unchanged from the previous vectorized version. The extension fuses only the two Mish applications: forward takes the convolution output as its input.
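To make the grid-stride loop and the two code paths concrete, the following host-side C++ sketch emulates them sequentially: every (block, thread) pair starts at its global index and advances by the total thread count, and the vectorized variant handles four consecutive floats per iteration, mirroring float4. __ldg and __forceinline__ only affect code generation and do not change the computed values here, so they are not modeled. The helpers `grid_stride_emulate`, `mish2_scalar`, and `mish2_vec4` are hypothetical illustrations, not part of the extension.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Same formula as the device function: mish(x) = x * tanh(log1p(exp(x))).
inline float mish(float x) { return x * std::tanh(std::log1p(std::exp(x))); }

// Sequential emulation of a grid-stride loop over n work items.
template <typename Body>
void grid_stride_emulate(int blocks, int threads, int n, Body body) {
    const int stride = blocks * threads;  // total number of "threads" in the grid
    for (int b = 0; b < blocks; ++b)
        for (int t = 0; t < threads; ++t)
            for (int i = b * threads + t; i < n; i += stride)
                body(i);
}

// Element-wise path: one float per loop iteration.
std::vector<float> mish2_scalar(const std::vector<float>& in, int blocks, int threads) {
    std::vector<float> out(in.size());
    grid_stride_emulate(blocks, threads, static_cast<int>(in.size()),
                        [&](int i) { out[i] = mish(mish(in[i])); });
    return out;
}

// Vectorized path: one group of four floats per loop iteration (like float4).
// Assumes in.size() % 4 == 0, the same precondition the extension checks.
std::vector<float> mish2_vec4(const std::vector<float>& in, int blocks, int threads) {
    std::vector<float> out(in.size());
    const int n_vec = static_cast<int>(in.size() / 4);
    grid_stride_emulate(blocks, threads, n_vec, [&](int v) {
        for (int k = 0; k < 4; ++k) out[4 * v + k] = mish(mish(in[4 * v + k]));
    });
    return out;
}
```

With a deliberately small grid (for example 2 blocks of 3 threads over 16 elements), each emulated thread runs several loop iterations, yet every element is still visited exactly once and both paths produce the same output.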
```cpp
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <math.h>
#include <algorithm>
#include <cstdint>

// Inline device function for the Mish activation.
// mish(x) = x * tanh(softplus(x)) with softplus(x) = log(1 + exp(x)).
// For large x, expf overflows to inf, log1pf(inf) = inf and tanhf(inf) = 1,
// so the result saturates to x as expected.
__forceinline__ __device__ float mish(float x) {
    float sp = log1pf(expf(x));
    return x * tanhf(sp);
}

// Element-wise kernel: output[i] = mish(mish(input[i])) for every element.
// __ldg routes the read through the read-only data cache.
__global__ void fused_mish2_kernel(const float* __restrict__ input,
                                   float* __restrict__ output,
                                   int64_t total) {
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t i = idx; i < total; i += stride) {
        float x = __ldg(&input[i]);
        output[i] = mish(mish(x));
    }
}

// Vectorized kernel: each loop iteration handles one float4 (4 floats).
// Only valid when the element count is divisible by 4 and both buffers
// are 16-byte aligned.
__global__ void fused_mish2_vec_kernel(const float4* __restrict__ input,
                                       float4* __restrict__ output,
                                       int64_t total_vec) {
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t i = idx; i < total_vec; i += stride) {
        float4 in_val = __ldg(&input[i]);  // single 128-bit read-only load
        float4 out_val;
        out_val.x = mish(mish(in_val.x));
        out_val.y = mish(mish(in_val.y));
        out_val.z = mish(mish(in_val.z));
        out_val.w = mish(mish(in_val.w));
        output[i] = out_val;
    }
}

torch::Tensor forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    // The kernels index the buffer linearly, so require a contiguous layout.
    input = input.contiguous();

    // Ensure we launch on the input's device.
    c10::cuda::CUDAGuard device_guard(input.device());
    auto output = torch::empty_like(input);
    const int64_t total = input.numel();
    if (total == 0) {
        return output;  // a zero-block launch is an invalid configuration
    }

    const int threads = 256;
    // One block per 256 work items, capped: the grid-stride loops cover any
    // remainder, and the cap keeps the block count within int range.
    auto grid_for = [threads](int64_t work) {
        return static_cast<int>(std::min<int64_t>((work + threads - 1) / threads, 65535));
    };
    auto stream = c10::cuda::getCurrentCUDAStream();
    const float* in_ptr = input.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    // float4 accesses need 16-byte alignment; a view with a storage offset
    // may violate this even when total is divisible by 4.
    const bool aligned = reinterpret_cast<std::uintptr_t>(in_ptr) % 16 == 0 &&
                         reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0;

    if (total % 4 == 0 && aligned) {
        const int64_t total_vec = total / 4;
        fused_mish2_vec_kernel<<<grid_for(total_vec), threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr), total_vec);
    } else {
        fused_mish2_kernel<<<grid_for(total), threads, 0, stream>>>(
            in_ptr, out_ptr, total);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward,
          "Fused double Mish activation kernel with vectorized path using __ldg and __forceinline__");
}
```
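The forward function has to make two launch decisions before calling either kernel: whether the float4 path is safe (element count divisible by 4 and both buffers 16-byte aligned), and how many blocks to launch, including the empty-tensor case where no launch should happen. The host-side C++ sketch below isolates those two checks so they can be tested without a GPU. `can_use_float4` and `grid_size` are hypothetical helper names, and the 65535 cap is an assumed choice that is safe only because the kernels use grid-stride loops.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// True when float4 loads/stores are safe: a multiple of 4 elements and
// 16-byte-aligned input and output pointers.
inline bool can_use_float4(const float* in, const float* out, std::int64_t n) {
    auto is16 = [](const float* p) {
        return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
    };
    return n % 4 == 0 && is16(in) && is16(out);
}

// Blocks for a grid-stride loop: ceil(work_items / threads), capped because the
// loop picks up any remaining work. Returns 0 when there is nothing to launch,
// signalling that the caller should skip the launch entirely.
inline int grid_size(std::int64_t work_items, int threads,
                     std::int64_t max_blocks = 65535) {
    if (work_items <= 0) return 0;
    return static_cast<int>(
        std::min<std::int64_t>((work_items + threads - 1) / threads, max_blocks));
}
```

A pointer offset by a single float from an aligned base (for example a tensor view starting at element 1) fails the alignment test, which is exactly the case in which a plain `total % 4 == 0` check would allow a misaligned float4 access.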