3.15. DenseNet121 Feature Computation

Computes image feature representations using densely connected blocks. Each layer within a dense block applies batch normalization, ReLU activation, convolution, and dropout, then concatenates its input with the newly generated features. Transition layers then downsample these features, and a final global average pooling and linear layer produce the classification output.

Progress

Step 2 of 2
Reward
1.07x

Optimization Timeline

Step 2 Details

Reasoning

The current fused_cat kernel for concatenation works correctly, but we can gain further speedup by taking advantage of vectorized memory accesses on CUDA.

The concatenation occurs along the channel dimension, and our tensors are stored contiguously in [N, C, H, W] layout. If the width (W) is divisible by 4, each row therefore consists of whole groups of 4 contiguous float values. In that case we can reinterpret the tensors as arrays of float4 and launch a vectorized kernel that processes 4 elements at a time. This reinterpretation is only valid for float32 data (a float4 holds exactly four 4-byte floats) whose base pointers are 16-byte aligned; misaligned float4 accesses fault. Vectorizing reduces the number of global memory transactions and improves memory bandwidth utilization.

In our updated CUDA code, we introduce a new kernel (fused_cat_vec_kernel) that is launched when W % 4 == 0; otherwise, we fall back to the original scalar kernel. In the vectorized kernel we compute output indices using a reduced width of W/4:
- Tensor A (shape [N, C1, H, W]) is reinterpreted as [N, C1, H, W/4] float4 elements.
- Tensor B (shape [N, C2, H, W]) is reinterpreted as [N, C2, H, W/4].
- The output is reinterpreted as [N, C1 + C2, H, W/4].

All operations then occur on float4 values. This approach should provide a real speedup when the dimensions allow vectorization, without “cheating”: it is a genuine memory-access optimization.

The Python code remains largely unchanged; it simply loads the CUDA extension with the proper flags. We now update only the CUDA code.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>

// Scalar channel concatenation: one output element per thread.
// Writes C [N, C1+C2, H, W], taking channels [0, C1) from A [N, C1, H, W] and
// channels [C1, C1+C2) from B [N, C2, H, W]. All three buffers are indexed as
// dense NCHW (contiguous) memory.
// Launch: 1-D grid with at least N*(C1+C2)*H*W threads in total; threads whose
// flat index is past the end return without doing anything.
template <typename scalar_t>
__global__ void fused_cat_kernel(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    scalar_t* __restrict__ C,
    int N, int C1, int C2, int H, int W) {
  const int out_channels = C1 + C2;
  const int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (out_idx >= N * out_channels * H * W) {
    return;
  }

  // Peel coordinates off the flat index, fastest-varying dimension first.
  int rest = out_idx;
  const int w = rest % W;
  rest /= W;
  const int h = rest % H;
  rest /= H;
  const int c = rest % out_channels;
  const int n = rest / out_channels;

  // Channels below C1 live in A; the rest live in B, re-based to start at 0.
  const bool from_a = c < C1;
  const scalar_t* src = from_a ? A : B;
  const int src_channels = from_a ? C1 : C2;
  const int src_c = from_a ? c : c - C1;

  C[out_idx] = src[((n * src_channels + src_c) * H + h) * W + w];
}

// Vectorized channel concatenation: the 16-byte-vector counterpart of
// fused_cat_kernel. Produces C [N, C1+C2, H, W] from A [N, C1, H, W] and
// B [N, C2, H, W], moving one float4 (16 bytes) per loop iteration.
//
// Preconditions the caller must guarantee (not checked here):
//   * A, B and C are dense NCHW buffers whose base pointers are 16-byte aligned;
//   * W is a multiple of kVec = 16 / sizeof(scalar_t) (4 for float, 2 for double),
//     so every row splits into whole 16-byte words.
// Launch: any 1-D grid. The grid-stride loop visits every 16-byte word regardless
// of how many threads are launched. No shared memory is used.
template <typename scalar_t>
__global__ void fused_cat_vec_kernel(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    scalar_t* __restrict__ C,
    int N, int C1, int C2, int H, int W) {
  // Elements of scalar_t per float4. Hard-coding 4 is only right for 4-byte
  // types; for double it computed wrong offsets and left half the output unwritten.
  static_assert(sizeof(float4) % sizeof(scalar_t) == 0,
                "scalar_t size must evenly divide 16 bytes");
  constexpr int kVec = static_cast<int>(sizeof(float4) / sizeof(scalar_t));

  // Sizes measured in 16-byte words; 64-bit so products and the loop increment
  // cannot overflow.
  const int64_t row_vecs = W / kVec;
  const int64_t plane_vecs = static_cast<int64_t>(H) * row_vecs;
  const int64_t channels_total = static_cast<int64_t>(C1) + C2;
  const int64_t batch_vecs = channels_total * plane_vecs;
  const int64_t total_vec = static_cast<int64_t>(N) * batch_vecs;

  const float4* __restrict__ A4 = reinterpret_cast<const float4*>(A);
  const float4* __restrict__ B4 = reinterpret_cast<const float4*>(B);
  float4* __restrict__ C4 = reinterpret_cast<float4*>(C);

  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < total_vec; index += stride) {
    const int64_t n = index / batch_vecs;
    const int64_t rem = index % batch_vecs;
    const int64_t c = rem / plane_vecs;
    // Position (h * row_vecs + w4) within one channel plane; identical in source and output.
    const int64_t offset = rem % plane_vecs;
    if (c < C1) {
      C4[index] = A4[(n * C1 + c) * plane_vecs + offset];
    } else {
      C4[index] = B4[(n * C2 + (c - C1)) * plane_vecs + offset];
    }
  }
}

// Concatenates two 4-D CUDA tensors along dim 1 (channels):
// A [N, C1, H, W] and B [N, C2, H, W] -> output [N, C1 + C2, H, W],
// matching torch::cat({A, B}, 1).
//
// Requirements: both inputs are CUDA tensors on the same device, with the same
// dtype (one of the types covered by AT_DISPATCH_FLOATING_TYPES) and matching
// N, H, W. Non-contiguous inputs are first copied to a contiguous layout because
// both kernels index dense NCHW memory.
//
// The float4 kernel is used only when the reinterpretation is valid: float32
// data, W % 4 == 0, and 16-byte-aligned base pointers. A contiguous view with a
// storage offset can be misaligned. Everything else goes through the scalar
// kernel. Kernels run on the current CUDA stream of A's device.
torch::Tensor cat(torch::Tensor A, torch::Tensor B) {
    // Validate that A and B are compatible 4D CUDA tensors.
    TORCH_CHECK(A.is_cuda(), "Tensor A must be a CUDA tensor");
    TORCH_CHECK(B.is_cuda(), "Tensor B must be a CUDA tensor");
    TORCH_CHECK(A.device() == B.device(), "Tensors A and B must be on the same CUDA device");
    TORCH_CHECK(A.scalar_type() == B.scalar_type(), "Tensors A and B must have the same dtype");
    TORCH_CHECK(A.dim() == 4, "Tensor A must be 4-dimensional");
    TORCH_CHECK(B.dim() == 4, "Tensor B must be 4-dimensional");
    TORCH_CHECK(A.size(0) == B.size(0), "Batch dimensions of A and B must match");
    TORCH_CHECK(A.size(2) == B.size(2), "Heights of A and B must match");
    TORCH_CHECK(A.size(3) == B.size(3), "Widths of A and B must match");

    // Make A's device current before any copy, allocation, or launch below.
    c10::cuda::CUDAGuard device_guard(A.device());

    // Both kernels compute offsets assuming dense NCHW strides; no-op if already contiguous.
    A = A.contiguous();
    B = B.contiguous();

    auto output = torch::empty(
        {A.size(0), A.size(1) + B.size(1), A.size(2), A.size(3)}, A.options());

    // Nothing to copy, and a zero-block launch would be an invalid configuration.
    const int64_t total = output.numel();
    if (total == 0) {
        return output;
    }
    // The kernels take int dimensions and the scalar kernel uses int flat indices.
    TORCH_CHECK(total <= std::numeric_limits<int>::max(),
                "fused cat: output has too many elements for 32-bit indexing");

    const int N = static_cast<int>(A.size(0));
    const int C1 = static_cast<int>(A.size(1));
    const int C2 = static_cast<int>(B.size(1));
    const int H = static_cast<int>(A.size(2));
    const int W = static_cast<int>(A.size(3));
    const int threads = 256;
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    auto aligned16 = [](const void* p) {
        return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
    };
    // A float4 holds exactly four float32 values, so only float32 rows with
    // W % 4 == 0 map onto whole float4 words; misaligned float4 accesses fault.
    const bool use_vec = A.scalar_type() == at::kFloat && W % 4 == 0 &&
                         aligned16(A.data_ptr()) && aligned16(B.data_ptr()) &&
                         aligned16(output.data_ptr());

    if (use_vec) {
        const int64_t total_vec = total / 4;
        const int64_t blocks = (total_vec + threads - 1) / threads;
        fused_cat_vec_kernel<float><<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
            A.data_ptr<float>(),
            B.data_ptr<float>(),
            output.data_ptr<float>(),
            N, C1, C2, H, W);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
        const int64_t blocks = (total + threads - 1) / threads;
        AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "fused_cat_scalar", ([&] {
            fused_cat_kernel<scalar_t><<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
                A.data_ptr<scalar_t>(),
                B.data_ptr<scalar_t>(),
                output.data_ptr<scalar_t>(),
                N, C1, C2, H, W);
        }));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
    return output;
}

// Python binding: exposes the host wrapper `cat(A, B)` as `<extension>.cat`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def(
      "cat",
      &cat,
      "Fused concatenation of two 4D tensors along the channel dimension (vectorized when possible)");
}