3.10. Residual Bottleneck Network Computation
Computes a deep convolutional network built from a chain of bottleneck blocks. Each block applies a sequence of convolution, batch normalization, and ReLU layers, adds a residual (skip) connection, and downsamples where needed. The network ends with adaptive pooling and a fully connected layer for classification.
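The residual addition only works when the identity tensor has the same shape as the main-branch output, which is why some blocks downsample the identity path. A minimal C++ sketch of that shape bookkeeping follows. It assumes the standard ResNet-50 bottleneck layout (1×1 reduce, 3×3 carrying the stride with padding 1, 1×1 expand by a factor of 4, as in torchvision's ResNet-50). The `Shape` struct, `bottleneck_output`, and `needs_downsample` are hypothetical names for illustration, not part of the original code.

```cpp
#include <cassert>

// Hypothetical per-sample CHW shape (batch dimension omitted) for this sketch only.
struct Shape {
  int channels;
  int height;
  int width;
};

// Assumed expansion factor of the standard ResNet-50 bottleneck.
constexpr int kExpansion = 4;

// Output shape of one bottleneck block under the stated assumptions:
// the 1x1 convolutions keep the spatial size, the 3x3 convolution (padding 1)
// applies the stride, and the final 1x1 expands channels to width * kExpansion.
Shape bottleneck_output(const Shape& in, int width, int stride) {
  const int out_h = (in.height + 2 * 1 - 3) / stride + 1;
  const int out_w = (in.width + 2 * 1 - 3) / stride + 1;
  return Shape{width * kExpansion, out_h, out_w};
}

// The skip connection can add the block input directly only when shapes match;
// otherwise the identity path needs a downsampling projection
// (for example a strided 1x1 convolution followed by batch normalization).
bool needs_downsample(const Shape& in, int width, int stride) {
  const Shape out = bottleneck_output(in, width, stride);
  return out.channels != in.channels || out.height != in.height ||
         out.width != in.width;
}
```

For example, a 64×56×56 input entering a block with width 64 and stride 1 produces 256×56×56, so its identity path must be projected, while the next block at the same width can add its input directly.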
Step 3 Details
Reasoning
We can further optimize our fused residual-addition-plus-ReLU kernel by removing the conditional from the ReLU. The previous kernel implemented ReLU with a conditional operator; this version uses the CUDA math function fmaxf(), which returns the larger of two floats and typically compiles to a single max instruction with no branch. Because nvcc often already turns a simple conditional like this into a predicated select, the gain may be small, and any speedup should be confirmed by profiling. One semantic detail: when exactly one argument is NaN, fmaxf() returns the other argument, so NaN sums become 0 rather than propagating. We apply the change to both the scalar and the vectorized kernel, calling fmaxf() once per float4 component in the latter. The rest of the design is unchanged: if the total number of elements is divisible by 4, the float4 kernel is used; otherwise the scalar kernel handles every element. Note that float4 loads and stores also require 16-byte-aligned base pointers, which freshly allocated PyTorch CUDA tensors satisfy but views with a storage offset may not. Below is the updated CUDA extension code.
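#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>

// Scalar kernel: out[i] = max(out[i] + identity[i], 0) using the branchless fmaxf.
// The grid-stride loop lets any grid size cover all elements; 64-bit indices
// avoid overflow on very large tensors.
__global__ void fused_add_relu_kernel(float* __restrict__ out,
                                      const float* __restrict__ identity,
                                      int64_t numel) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = idx; i < numel; i += stride) {
    float sum = out[i] + identity[i];
    out[i] = fmaxf(sum, 0.0f);
  }
}

// Vectorized kernel using float4 (one 16-byte access per operand per iteration).
// Assumes numel is divisible by 4 and both base pointers are 16-byte aligned;
// the host code below checks both conditions before choosing this kernel.
__global__ void fused_add_relu_vec_kernel(float4* __restrict__ out,
                                          const float4* __restrict__ identity,
                                          int64_t num_vec) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = idx; i < num_vec; i += stride) {
    float4 o = out[i];
    float4 id = identity[i];
    float4 res;
    res.x = fmaxf(o.x + id.x, 0.0f);
    res.y = fmaxf(o.y + id.y, 0.0f);
    res.z = fmaxf(o.z + id.z, 0.0f);
    res.w = fmaxf(o.w + id.w, 0.0f);
    out[i] = res;
  }
}

// In-place out = relu(out + identity); returns out.
torch::Tensor forward(torch::Tensor out, torch::Tensor identity) {
  TORCH_CHECK(out.is_cuda() && identity.is_cuda(), "Both tensors must be CUDA tensors.");
  TORCH_CHECK(out.device() == identity.device(), "Tensors must be on the same device.");
  TORCH_CHECK(out.scalar_type() == torch::kFloat32 && identity.scalar_type() == torch::kFloat32,
              "Both tensors must be float32.");
  TORCH_CHECK(out.is_contiguous() && identity.is_contiguous(), "Both tensors must be contiguous.");
  TORCH_CHECK(identity.numel() == out.numel(), "Tensors must have the same number of elements.");

  c10::cuda::CUDAGuard device_guard(out.device());
  const int64_t numel = out.numel();
  if (numel == 0) {
    return out;  // a launch with zero blocks is an invalid configuration
  }

  const int threads = 256;
  const int64_t max_blocks = 65535;  // the grid-stride loops cover any remainder
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  float* out_ptr = out.data_ptr<float>();
  const float* id_ptr = identity.data_ptr<float>();

  // float4 accesses need 16-byte alignment, which numel % 4 == 0 alone does not guarantee.
  const bool aligned = reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0 &&
                       reinterpret_cast<std::uintptr_t>(id_ptr) % 16 == 0;

  // Use the vectorized kernel when possible.
  if (numel % 4 == 0 && aligned) {
    const int64_t num_vec = numel / 4;
    const int blocks = static_cast<int>(
        std::min<int64_t>((num_vec + threads - 1) / threads, max_blocks));
    fused_add_relu_vec_kernel<<<blocks, threads, 0, stream>>>(
        reinterpret_cast<float4*>(out_ptr), reinterpret_cast<const float4*>(id_ptr), num_vec);
  } else {
    const int blocks = static_cast<int>(
        std::min<int64_t>((numel + threads - 1) / threads, max_blocks));
    fused_add_relu_kernel<<<blocks, threads, 0, stream>>>(out_ptr, id_ptr, numel);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forward, "Fused elementwise add and ReLU with vectorization");
}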
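The claim that fmaxf() computes the same ReLU as a conditional operator can be checked on the host. The sketch below uses std::fmax, the host counterpart of CUDA's fmaxf (both return the non-NaN argument when exactly one input is NaN). The conditional form `x > 0.0f ? x : 0.0f` is an assumption about the previous kernel, whose code is not shown here, and all function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// ReLU with a conditional operator (assumed form of the previous kernel version).
float relu_conditional(float x) {
  return x > 0.0f ? x : 0.0f;
}

// ReLU via fmax: branch-free max; a NaN argument yields the other argument (0).
float relu_fmax(float x) {
  return std::fmax(x, 0.0f);
}

// Host reference for one element of the fused kernel: max(out + identity, 0).
float fused_add_relu_ref(float out, float identity) {
  return relu_fmax(out + identity);
}
```

On ordinary inputs the two ReLU forms agree, and both map a NaN input to 0, so the switch to fmaxf() changes the generated instructions rather than the results for these cases.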
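The host-side dispatch can also be modeled without a GPU. The sketch below is a hypothetical CPU emulation, not part of the extension: it selects the float4 path only when the element count is divisible by 4 and both base addresses are 16-byte aligned, computes the block count for 256 threads per block, and replays the grid-stride loop on the CPU. For an assumed activation of shape 1×256×56×56 (802,816 elements), the vectorized path processes 200,704 float4 groups in 784 blocks.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical summary of which kernel would be launched and with how many blocks.
struct LaunchPlan {
  bool vectorized;          // true -> float4 kernel, false -> scalar kernel
  std::int64_t work_items;  // number of float4 groups or scalar elements
  std::int64_t blocks;      // ceil(work_items / threads)
};

LaunchPlan plan_launch(std::int64_t numel, const float* out, const float* identity,
                       int threads = 256) {
  const bool aligned = reinterpret_cast<std::uintptr_t>(out) % 16 == 0 &&
                       reinterpret_cast<std::uintptr_t>(identity) % 16 == 0;
  const bool vectorized = numel % 4 == 0 && aligned;
  const std::int64_t items = vectorized ? numel / 4 : numel;
  return LaunchPlan{vectorized, items, (items + threads - 1) / threads};
}

// CPU emulation of the grid-stride loop: each simulated thread handles work items
// tid, tid + total_threads, ..., and each item covers 4 elements (float4) or 1.
void emulate_fused_add_relu(std::vector<float>& out, const std::vector<float>& identity,
                            int threads = 256) {
  const LaunchPlan plan = plan_launch(static_cast<std::int64_t>(out.size()),
                                      out.data(), identity.data(), threads);
  const std::int64_t group = plan.vectorized ? 4 : 1;
  const std::int64_t total_threads = plan.blocks * threads;
  for (std::int64_t tid = 0; tid < total_threads; ++tid) {
    for (std::int64_t i = tid; i < plan.work_items; i += total_threads) {
      for (std::int64_t k = 0; k < group; ++k) {
        const std::int64_t e = i * group + k;
        out[e] = std::fmax(out[e] + identity[e], 0.0f);
      }
    }
  }
}
```

Because every simulated thread strides by the total thread count, the result does not depend on how many blocks are launched; this is what makes the grid-stride pattern robust to the block-count choice.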