3.8. Residual Block Computation
Calculates output features by performing two sequential convolution operations, each followed by batch normalization (with a ReLU activation after the first), then adds an identity shortcut connection (optionally downsampled) and applies a final ReLU to form the residual block transformation.
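For reference, a minimal PyTorch sketch of this transformation, assuming the standard ResNet basic-block layout; the class and argument names here are illustrative, not taken from the original module:

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Illustrative basic block: conv-BN-ReLU, conv-BN, shortcut add, final ReLU."""
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample  # optional projection for the shortcut

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity   # shortcut addition
        return self.relu(out)  # final ReLU (the pair of ops fused below)

The trailing addition and ReLU are the two operations that the fused kernel described below replaces with a single pass over memory.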
Optimization Timeline
Step 2 Details
Reasoning
We want to further optimize our fused addition + ReLU operator in the ResNet basic block. The original fused kernel performs an element-wise addition and then applies ReLU. We can boost performance by vectorizing the computation with float4 when the total number of elements in the output tensor is divisible by 4 and the data pointers are suitably aligned; this reduces the number of global memory transactions and improves memory throughput. The optimized CUDA extension therefore contains two kernels: a scalar element-wise kernel and a vectorized float4 kernel. At runtime, if the element count is divisible by 4 and the pointers are 16-byte aligned, we dispatch the vectorized kernel; otherwise we fall back to the element-wise kernel. We use c10::cuda::CUDAGuard to enforce device correctness and launch on the current CUDA stream. The Python module remains mostly unchanged, aside from loading this optimized kernel, and the CUDAModel class calls our fused add_relu operator. This improvement genuinely helps because it leverages vectorized memory operations to reduce memory traffic and speed up the fused operation.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
// Non-vectorized fused kernel: for each element compute out = max(out + identity, 0)
__global__ void fused_add_relu_kernel(float* __restrict__ out,
                                      const float* __restrict__ identity,
                                      int numel) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < numel) {
    float sum = out[index] + identity[index];
    out[index] = sum > 0.0f ? sum : 0.0f;
  }
}

// Vectorized fused kernel using float4: processes 4 elements per thread.
__global__ void fused_add_relu_vec_kernel(float4* __restrict__ out,
                                          const float4* __restrict__ identity,
                                          int numel_vec) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < numel_vec) {
    // Load vectorized data
    float4 o_val = out[index];
    float4 i_val = identity[index];
    // Compute component-wise addition and ReLU
    float4 sum;
    sum.x = o_val.x + i_val.x; sum.x = sum.x > 0.0f ? sum.x : 0.0f;
    sum.y = o_val.y + i_val.y; sum.y = sum.y > 0.0f ? sum.y : 0.0f;
    sum.z = o_val.z + i_val.z; sum.z = sum.z > 0.0f ? sum.z : 0.0f;
    sum.w = o_val.w + i_val.w; sum.w = sum.w > 0.0f ? sum.w : 0.0f;
    out[index] = sum;
  }
}
torch::Tensor add_relu(torch::Tensor out, torch::Tensor identity) {
  c10::cuda::CUDAGuard device_guard(out.device());
  TORCH_CHECK(out.is_cuda() && identity.is_cuda(), "Tensors must be CUDA tensors.");
  TORCH_CHECK(out.sizes() == identity.sizes(), "Tensors must have the same shape.");
  TORCH_CHECK(out.is_contiguous() && identity.is_contiguous(), "Tensors must be contiguous.");
  TORCH_CHECK(out.scalar_type() == torch::kFloat32 && identity.scalar_type() == torch::kFloat32,
              "Tensors must be float32.");
  const int numel = static_cast<int>(out.numel());
  const int threads = 256;
  float* out_ptr = out.data_ptr<float>();
  const float* identity_ptr = identity.data_ptr<float>();
  // Use vectorized loads only if the element count is divisible by 4
  // and both data pointers are 16-byte aligned for float4 access.
  const bool aligned =
      reinterpret_cast<std::uintptr_t>(out_ptr) % sizeof(float4) == 0 &&
      reinterpret_cast<std::uintptr_t>(identity_ptr) % sizeof(float4) == 0;
  if (numel % 4 == 0 && aligned) {
    const int numel_vec = numel / 4;
    const int blocks = (numel_vec + threads - 1) / threads;
    fused_add_relu_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<float4*>(out_ptr),
        reinterpret_cast<const float4*>(identity_ptr),
        numel_vec);
  } else {
    const int blocks = (numel + threads - 1) / threads;
    fused_add_relu_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        out_ptr, identity_ptr, numel);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add_relu", &add_relu, "Fused addition and ReLU activation (optimized with vectorization when possible)");
}
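As a usage illustration, the extension can be built with torch.utils.cpp_extension.load and the fused operator then replaces the separate addition and ReLU at the end of the block's forward pass. This is a minimal sketch; the source file name fused_add_relu.cu and the module name fused_ops are assumptions, not part of the original setup:

import torch
from torch.utils.cpp_extension import load

# Compile the CUDA source above into a loadable extension (path is illustrative).
fused_ops = load(name="fused_add_relu_ext", sources=["fused_add_relu.cu"], verbose=True)

def fused_residual_tail(out: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    # Replaces `out = torch.relu(out + identity)`; add_relu updates `out` in place.
    return fused_ops.add_relu(out, identity)

if __name__ == "__main__":
    out = torch.randn(2, 64, 56, 56, device="cuda")
    identity = torch.randn_like(out)
    expected = torch.relu(out + identity)    # reference, computed before the in-place call
    result = fused_residual_tail(out, identity)
    print(torch.allclose(result, expected))  # expected: True

Because both tensors come from fresh contiguous allocations and their element count is divisible by 4, this example exercises the vectorized float4 path; oddly sized or misaligned inputs fall back to the scalar kernel.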