1.27. SELU Activation
Computes the Scaled Exponential Linear Unit (SELU) activation element-wise on the input tensor, producing an output tensor of the same shape in which each element $x$ becomes $\lambda x$ if $x > 0$ and $\lambda\alpha(e^{x} - 1)$ otherwise, with $\alpha \approx 1.6733$ and $\lambda \approx 1.0507$.
Progress
Optimization Timeline
Step 2 Details
Reasoning
The original SELU kernel applies the activation element-wise using a simple loop over all elements. The input tensor is large (e.g. 16×16384 = 262,144 elements) and its element count is usually divisible by 4. We can therefore optimize the kernel with vectorized loads and stores using float4. Instead of processing one float per thread, we reinterpret the tensor as an array of float4 and process four floats at a time. Each thread then issues one 16-byte load and one 16-byte store instead of four 4-byte ones, which makes better use of memory bandwidth. We implement a vectorized SELU kernel and use it only when the total number of elements is divisible by 4 (float4 accesses additionally require 16-byte-aligned pointers). Otherwise we fall back to the scalar kernel. As before, we use c10::cuda::CUDAGuard and c10::cuda::getCurrentCUDAStream(). The per-element arithmetic is unchanged. The gain comes from fewer, wider memory transactions, which is what typically limits a simple element-wise kernel like this one.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
#include <cstdint>
// SELU coefficients. ALPHA scales the (e^x - 1) term of the negative branch;
// LAMBDA scales the whole output on both branches. Both are scalar constexpr
// floats, so their values can be used directly inside device code.
constexpr float SELU_ALPHA{1.6732632423543772f};
constexpr float SELU_LAMBDA{1.0507009873554805f};
// Scalar SELU kernel: applies SELU activation element-wise.
__global__ void selu_kernel(const float* __restrict__ input,
float* __restrict__ output,
int total) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < total) {
float x = input[index];
output[index] = (x > 0.f) ? SELU_LAMBDA * x : SELU_LAMBDA * SELU_ALPHA * (expf(x) - 1.f);
}
}
// Vectorized SELU kernel: applies SELU to each of the four components of every
// float4 in [0, vec_total), i.e. to 4 * vec_total consecutive floats, using one
// 16-byte load and one 16-byte store per float4.
//   SELU(x) = LAMBDA * x                    if x > 0
//             LAMBDA * ALPHA * (e^x - 1)    otherwise
//
// Launch: 1-D grid of 1-D blocks of any size. A grid-stride loop covers all
// float4 elements regardless of the number of blocks. No shared memory is used.
// Preconditions: `input` and `output` are device pointers to at least
// `vec_total` float4 values, both 16-byte aligned (required for float4
// accesses), and non-overlapping (__restrict__).
__global__ void selu_kernel_vec(const float4* __restrict__ input,
                                float4* __restrict__ output,
                                int64_t vec_total) {
    // A lambda defined inside device code is itself a device function.
    // expm1f avoids the cancellation of expf(x) - 1 for x close to 0.
    const auto selu = [](float x) {
        return (x > 0.f) ? SELU_LAMBDA * x
                         : SELU_LAMBDA * SELU_ALPHA * expm1f(x);
    };

    // 64-bit index arithmetic so very large tensors do not wrap.
    const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < vec_total;
         i += stride) {
        const float4 v = input[i];
        output[i] = make_float4(selu(v.x), selu(v.y), selu(v.z), selu(v.w));
    }
}
// Host entry point: returns SELU(input) as a newly allocated tensor of the same
// shape.
//
// Requirements (checked): `input` is a float32 CUDA tensor.
//
// Layout handling:
// - Both kernels walk memory linearly, so the tensor must occupy exactly numel()
//   consecutive floats.
// - Non-overlapping, dense inputs (contiguous, or e.g. a transpose) are used as
//   is, and the output keeps their strides.
// - Other layouts (strided slices, expanded views) are copied to a contiguous
//   buffer first.
//
// Dispatch:
// - The float4 kernel is used when numel() is a multiple of 4 AND both the input
//   and output pointers are 16-byte aligned. A view with a storage offset may
//   not be aligned.
// - Otherwise the scalar kernel is used.
//
// Work is enqueued on the current CUDA stream of input's device. Launch errors
// are raised through C10_CUDA_KERNEL_LAUNCH_CHECK.
torch::Tensor selu_forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "selu_forward: expected a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "selu_forward: expected a float32 tensor, got ", input.scalar_type());

    // Make input's device current so allocation and launch happen on it.
    c10::cuda::CUDAGuard device_guard(input.device());

    // Linear indexing is only valid for non-overlapping, dense memory.
    const torch::Tensor src =
        input.is_non_overlapping_and_dense() ? input : input.contiguous();
    // empty_like preserves a dense tensor's strides, so position k in the output
    // buffer holds SELU of position k in the input buffer.
    torch::Tensor output = torch::empty_like(src);

    const int64_t total = src.numel();
    if (total == 0) {
        // A <<<0, ...>>> launch is an invalid configuration; nothing to compute.
        return output;
    }

    const float* in_ptr = src.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    constexpr int threads = 256;

    // float4 loads/stores fault on addresses that are not 16-byte aligned.
    const bool vec_aligned =
        reinterpret_cast<std::uintptr_t>(in_ptr) % alignof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(out_ptr) % alignof(float4) == 0;

    if (total % 4 == 0 && vec_aligned) {
        const int64_t vec_total = total / 4;
        const int64_t blocks = (vec_total + threads - 1) / threads;
        selu_kernel_vec<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            reinterpret_cast<const float4*>(in_ptr),
            reinterpret_cast<float4*>(out_ptr),
            vec_total);
    } else {
        const int64_t blocks = (total + threads - 1) / threads;
        selu_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            in_ptr, out_ptr, total);
    }
    // Surface bad launch configurations here instead of at some later CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Python bindings: registers selu_forward in the extension module under the
// name "selu_forward", with a short description as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("selu_forward",
            &selu_forward,
            "Custom CUDA SELU activation forward (vectorized when possible)");
}