1.25. Swish Activation Computation
Calculates the Swish activation function by multiplying the input with its sigmoid value. This operation transforms the input data using the formula x * sigmoid(x).
Progress
Step 4 of 4
Reward
1.58x
Optimization Timeline
Step 4 Details
Reasoning
We can further optimize the Swish kernel by keeping the vectorized path even when the total number of elements is not divisible by 4. Instead of switching entirely to the element‑wise kernel whenever numel isn’t divisible by 4, we process the bulk (the first floor(numel/4)*4 elements) with the optimized float4 kernel and then handle the remaining 1–3 elements with the fallback element‑wise kernel. This hybrid approach issues fewer, wider memory transactions and uses vectorized computation for almost all of the data, which is a genuine improvement in throughput.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
// Element-wise kernel for Swish activation: output[i] = input[i] * sigmoid(input[i])
// for every i in [0, numel).
//
// Launch layout: 1-D grid of 1-D blocks. Any grid size is valid because each
// thread starts at its global index and advances by gridDim.x * blockDim.x
// until it passes numel. No shared memory is used.
//
// Parameters:
//   input  - device pointer to at least numel readable floats.
//   output - device pointer to at least numel writable floats; must not alias
//            input (both are declared __restrict__).
//   numel  - number of elements to process.
//
// The sigmoid uses the __expf fast-math intrinsic, which trades some accuracy
// for speed compared with expf.
__global__ void swish_kernel(const float* __restrict__ input,
                             float* __restrict__ output,
                             int numel) {
    // Index math is done in 64 bits: blockIdx.x * blockDim.x and
    // gridDim.x * blockDim.x are 32-bit unsigned products, and the running
    // index can exceed INT_MAX after "idx += stride" when numel is large.
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (; idx < numel; idx += stride) {
        const float x = input[idx];
        const float sigmoid_x = 1.f / (1.f + __expf(-x));  // fast exponential
        output[idx] = x * sigmoid_x;
    }
}
// Swish of a single value: v * sigmoid(v), using the __expf fast-math
// intrinsic (reduced accuracy compared with expf).
static __device__ __forceinline__ float swish_scalar(float v) {
    return v * (1.f / (1.f + __expf(-v)));
}

// Vectorized kernel using float4: processes 4 elements per loop iteration.
//
// Launch layout: 1-D grid of 1-D blocks. Any grid size is valid because each
// thread starts at its global index and advances by gridDim.x * blockDim.x
// until it passes vec_size. No shared memory is used.
//
// Parameters:
//   input    - device pointer to at least vec_size readable float4 values.
//   output   - device pointer to at least vec_size writable float4 values;
//              must not alias input (both are declared __restrict__).
//   vec_size - number of float4 elements (i.e. number of floats / 4).
//
// Precondition: both pointers must satisfy float4's 16-byte alignment;
// a misaligned float4 access is an error on the device.
__global__ void swish_vec_kernel(const float4* __restrict__ input,
                                 float4* __restrict__ output,
                                 int vec_size) {
    // Index math is done in 64 bits: blockIdx.x * blockDim.x and
    // gridDim.x * blockDim.x are 32-bit unsigned products, and the running
    // index can exceed INT_MAX after "idx += stride" when vec_size is large.
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    long long idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (; idx < vec_size; idx += stride) {
        const float4 in_val = input[idx];  // one 16-byte load
        float4 out_val;
        out_val.x = swish_scalar(in_val.x);
        out_val.y = swish_scalar(in_val.y);
        out_val.z = swish_scalar(in_val.z);
        out_val.w = swish_scalar(in_val.w);
        output[idx] = out_val;             // one 16-byte store
    }
}
// Computes Swish (x * sigmoid(x)) element-wise on a float32 CUDA tensor.
//
// The bulk of the data is processed four floats at a time by
// swish_vec_kernel; whatever cannot be vectorized (the last numel % 4
// elements, or everything when the buffers are not 16-byte aligned) goes
// through the scalar swish_kernel. All launches are enqueued on the current
// CUDA stream of the input's device and are not synchronized here.
//
// Parameters:
//   input - CUDA tensor of dtype float32, any shape. Inputs whose memory is
//           not dense are copied to a contiguous buffer first, because the
//           kernels index the storage as a flat array.
// Returns:
//   A new tensor with the same shape as input holding the activation.
// Raises:
//   c10::Error (via TORCH_CHECK) if input is not a CUDA float32 tensor, or
//   if a kernel launch is rejected.
torch::Tensor forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "swish forward: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat,
                "swish forward: input must be float32, got ", input.scalar_type());

    // Ensure computation happens on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());

    // Flat indexing over data_ptr is only valid when the elements occupy one
    // gap-free stretch of memory; otherwise materialize a contiguous copy.
    const torch::Tensor src =
        input.is_non_overlapping_and_dense() ? input : input.contiguous();
    auto output = torch::empty_like(src);

    // Keep the element count in 64 bits; numel() returns int64_t and must not
    // be truncated to int.
    const int64_t numel = src.numel();
    if (numel == 0) {
        return output;
    }

    const float* in_ptr = src.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    const auto stream = c10::cuda::getCurrentCUDAStream();

    constexpr int threads = 256;
    // The kernels take an int count, so large tensors are split into several
    // launches. 2^30 items per launch keeps "index + grid stride" below 2^31
    // with the ceil-div grid used here, and is a multiple of the float4 width
    // so chunk boundaries stay 16-byte aligned.
    constexpr int64_t max_per_launch = int64_t{1} << 30;

    // float4 loads/stores require 16-byte aligned addresses. A tensor view
    // with a storage offset (e.g. x[1:]) may not satisfy that, in which case
    // everything is routed through the scalar kernel.
    const auto is_vec_aligned = [](const void* p) {
        return reinterpret_cast<std::uintptr_t>(p) % alignof(float4) == 0;
    };
    const bool can_vectorize = is_vec_aligned(in_ptr) && is_vec_aligned(out_ptr);

    const int64_t vec_num = can_vectorize ? numel / 4 : 0;  // number of float4 elements
    const int64_t tail_begin = vec_num * 4;                 // first element left for the scalar kernel

    // Process the bulk using the vectorized kernel.
    if (vec_num > 0) {
        const float4* in_vec = reinterpret_cast<const float4*>(in_ptr);
        float4* out_vec = reinterpret_cast<float4*>(out_ptr);
        for (int64_t done = 0; done < vec_num; done += max_per_launch) {
            const int count = static_cast<int>(std::min(max_per_launch, vec_num - done));
            const int blocks = (count + threads - 1) / threads;
            swish_vec_kernel<<<blocks, threads, 0, stream>>>(
                in_vec + done, out_vec + done, count);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
    }

    // Process the remaining tail (or the whole tensor when unaligned) using
    // the element-wise kernel.
    for (int64_t done = tail_begin; done < numel; done += max_per_launch) {
        const int count = static_cast<int>(std::min(max_per_launch, numel - done));
        const int blocks = (count + threads - 1) / threads;
        swish_kernel<<<blocks, threads, 0, stream>>>(
            in_ptr + done, out_ptr + done, count);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    return output;
}
// Python binding: defines the extension module named by TORCH_EXTENSION_NAME
// and registers the C++ function `forward` under the Python name "forward",
// with the given string as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "Swish activation using custom CUDA kernel with hybrid vectorized processing");
}