1.19. ReLU Activation
Computes the element-wise Rectified Linear Unit (ReLU) activation function on an input tensor by setting negative values to zero while leaving positive values unchanged.
Step 2 Details
Reasoning
We can further optimize the ReLU kernel by using vectorized memory accesses when the total number of elements in the input tensor is divisible by 4. In that case we reinterpret the input and output as float4 arrays and process four floats per thread. The total bytes moved are unchanged, but each thread issues a quarter as many load/store instructions and loop iterations, which helps a memory-bound kernel like ReLU reach higher effective throughput. Note that reinterpreting the pointers as float4 also requires them to be 16-byte aligned. In our CUDA extension, we check whether x.numel() is divisible by 4 (and the pointers are suitably aligned); if so, we launch the vectorized kernel, otherwise we fall back to the standard element-wise kernel. We continue using the c10::cuda utilities for device correctness and launch on c10::cuda::getCurrentCUDAStream(). The Python module remains similar but automatically dispatches to the optimized kernel. Before finalizing, I ask: have my improvements genuinely helped? For a memory-bound element-wise operation like ReLU, vectorized accesses cut the per-element instruction count and improve achieved bandwidth, so yes.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream()
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>                // std::uintptr_t for the alignment check
// Standard element-wise ReLU kernel: one float per thread.
__global__ void relu_kernel(const float* __restrict__ input,
                            float* __restrict__ output,
                            int numel) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numel) {
        float val = input[idx];
        output[idx] = val > 0.0f ? val : 0.0f;
    }
}
// Vectorized ReLU kernel: each thread processes one float4 (four floats).
__global__ void relu_kernel_vec(const float4* __restrict__ input,
                                float4* __restrict__ output,
                                int total_vec) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total_vec) {
        float4 v = input[idx];
        // Apply ReLU to each component of the vector.
        v.x = (v.x > 0.0f) ? v.x : 0.0f;
        v.y = (v.y > 0.0f) ? v.y : 0.0f;
        v.z = (v.z > 0.0f) ? v.z : 0.0f;
        v.w = (v.w > 0.0f) ? v.w : 0.0f;
        output[idx] = v;
    }
}
torch::Tensor forward(torch::Tensor x) {
    // The kernels assume a contiguous float32 CUDA tensor.
    TORCH_CHECK(x.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(x.is_contiguous(), "input must be contiguous");
    // Ensure the kernel runs on the input's device.
    c10::cuda::CUDAGuard device_guard(x.device());
    // Allocate an output tensor with the same shape and dtype as the input.
    auto output = torch::empty_like(x);
    int numel = x.numel();
    int threads = 256;
    // Use the vectorized kernel when the element count is divisible by 4 and the
    // data pointers are 16-byte aligned, as required for float4 accesses.
    bool aligned =
        reinterpret_cast<std::uintptr_t>(x.data_ptr<float>()) % sizeof(float4) == 0 &&
        reinterpret_cast<std::uintptr_t>(output.data_ptr<float>()) % sizeof(float4) == 0;
    if (numel % 4 == 0 && aligned) {
        int total_vec = numel / 4;
        int blocks = (total_vec + threads - 1) / threads;
        const float4* in_vec = reinterpret_cast<const float4*>(x.data_ptr<float>());
        float4* out_vec = reinterpret_cast<float4*>(output.data_ptr<float>());
        relu_kernel_vec<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            in_vec, out_vec, total_vec);
    } else {
        // Fall back to the standard element-wise kernel.
        int blocks = (numel + threads - 1) / threads;
        relu_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(), output.data_ptr<float>(), numel);
    }
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward,
          "Custom ReLU activation using a vectorized CUDA kernel where possible");
}