3.9. Deep Residual Network Computation

Performs a forward pass through a deep convolutional network by computing residual blocks sequentially. Each block applies convolutions, normalization, and a non-linear activation, then adds a shortcut connection that combines the block's input with the transformed features. The network progressively downsamples the input before aggregating features with adaptive pooling and emitting class scores.
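
A residual block of the kind described above can be sketched in PyTorch as follows; this is an illustrative layout (the class name and layer hyperparameters are not taken from the benchmark definition), and the final shortcut addition followed by ReLU is the step the custom kernel below fuses.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Illustrative residual block: conv -> BN -> ReLU -> conv -> BN,
    followed by the shortcut addition and a final ReLU (the step fused below)."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # 1x1 projection so the shortcut matches the block output shape when it changes.
        self.downsample = None
        if stride != 1 or in_ch != out_ch:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Baseline epilogue; the extension replaces this with fused_add_relu(out, identity).
        return F.relu(out + identity)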

Progress: Step 3 of 3
Reward: 1.08x

Optimization Timeline

Step 3 Details

Reasoning

We can further improve the fused addition+ReLU kernel by replacing the ternary operator with the fmaxf intrinsic, which maps to the hardware's fast floating-point max instruction and is readily inlined by the compiler. In the vectorized kernel we apply fmaxf component-wise, so each float4 element is processed with minimal branching overhead. This reduces divergent execution and the overall kernel runtime. The rest of the design is unchanged: we dispatch the vectorized kernel when the number of elements is divisible by 4 and fall back to the scalar kernel otherwise. These changes are genuine improvements because they reduce instruction complexity and exploit the hardware's fast max operation. The Python module remains largely unchanged.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>

// Scalar fused kernel: computes output[i] = fmaxf(x[i] + identity[i], 0) for each element.
__global__ void fused_add_relu_kernel(const float* __restrict__ x,
                                      const float* __restrict__ identity,
                                      float* __restrict__ output,
                                      int numel) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    for (int i = idx; i < numel; i += stride) {
        float sum = x[i] + identity[i];
        output[i] = fmaxf(sum, 0.f);
    }
}

// Vectorized fused kernel: processes 4 floats at a time.
__global__ void fused_add_relu_vec_kernel(const float4* __restrict__ x,
                                          const float4* __restrict__ identity,
                                          float4* __restrict__ output,
                                          int num_vec) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    for (int i = idx; i < num_vec; i += stride) {
        float4 xv = x[i];
        float4 idv = identity[i];
        float4 sum;
        sum.x = xv.x + idv.x;
        sum.y = xv.y + idv.y;
        sum.z = xv.z + idv.z;
        sum.w = xv.w + idv.w;
        // Apply ReLU to each component using fmaxf.
        sum.x = fmaxf(sum.x, 0.f);
        sum.y = fmaxf(sum.y, 0.f);
        sum.z = fmaxf(sum.z, 0.f);
        sum.w = fmaxf(sum.w, 0.f);
        output[i] = sum;
    }
}

torch::Tensor fused_add_relu(torch::Tensor x, torch::Tensor identity) {
    // Validate inputs: both tensors must be contiguous float32 CUDA tensors of equal size.
    TORCH_CHECK(x.is_cuda() && identity.is_cuda(), "fused_add_relu: inputs must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && identity.scalar_type() == torch::kFloat32,
                "fused_add_relu: inputs must be float32");
    TORCH_CHECK(x.is_contiguous() && identity.is_contiguous(),
                "fused_add_relu: inputs must be contiguous");
    TORCH_CHECK(x.numel() == identity.numel(),
                "fused_add_relu: inputs must have the same number of elements");
    // Ensure all work is issued on the device that owns x.
    c10::cuda::CUDAGuard device_guard(x.device());
    int numel = static_cast<int>(x.numel());
    auto output = torch::empty_like(x);
    int threads = 256;
    // Dispatch the vectorized kernel when the element count is a multiple of 4; contiguous
    // tensors from PyTorch's CUDA allocator are sufficiently aligned for float4 accesses.
    if (numel % 4 == 0) {
        int num_vec = numel / 4;
        int blocks = (num_vec + threads - 1) / threads;
        fused_add_relu_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            reinterpret_cast<const float4*>(x.data_ptr<float>()),
            reinterpret_cast<const float4*>(identity.data_ptr<float>()),
            reinterpret_cast<float4*>(output.data_ptr<float>()),
            num_vec
        );
    } else {
        int blocks = (numel + threads - 1) / threads;
        fused_add_relu_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x.data_ptr<float>(),
            identity.data_ptr<float>(),
            output.data_ptr<float>(),
            numel
        );
    }
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_add_relu", &fused_add_relu,
          "Fused element-wise addition and ReLU (vectorized float4 path when the element count is a multiple of 4)");
}
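
As a usage sketch, the extension can be built with torch.utils.cpp_extension.load and called from the Python module's forward pass; the module name and source file path below are illustrative.

import torch
from torch.utils.cpp_extension import load

# Compile the CUDA source above into a loadable extension (name and path are illustrative).
fused = load(name="fused_add_relu_ext", sources=["fused_add_relu.cu"], verbose=True)

x = torch.randn(16, 64, 56, 56, device="cuda")
identity = torch.randn_like(x)

out = fused.fused_add_relu(x, identity)
# The fused result should match the unfused reference epilogue.
assert torch.allclose(out, torch.relu(x + identity))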