2.74. 3D Transposed Convolution with LeakyReLU and Max Pooling
Performs a 3D transposed convolution to upsample the input, applies a LeakyReLU activation, multiplies the activated output by a learnable per-channel parameter, applies a second LeakyReLU activation, and finally applies max pooling to downsample the features.
Step 2 Details
Reasoning
We can further optimize the fused element-wise kernel by vectorizing its memory accesses. The fused operation (two LeakyReLU activations with a per-channel multiplication in between) works on a 5D tensor (N, C, D, H, W) stored in contiguous memory, with the last dimension (W) varying fastest.

If the per-channel spatial volume (D*H*W) is divisible by 4, then the total number of elements is divisible by 4 as well, and every aligned group of four consecutive elements lies within a single channel. In that case we can cast the input and output pointers to float4 and process four elements per thread. For each vector we compute the scalar index of its first element (v_idx*4) and derive the channel index as:

    ch = ((v_idx*4) / spatial_size) % C

Each component of the float4 is then processed in the same way: apply the first LeakyReLU, multiply by the channel multiplier, then apply the second LeakyReLU. Note that the float4 cast also assumes that both buffers are contiguous and 16-byte aligned.

Vectorization does not change the number of bytes moved. It replaces four 32-bit loads and four 32-bit stores with one 128-bit load and one 128-bit store per thread, and the channel index is computed once per four elements instead of once per element. Fewer memory instructions and less index arithmetic per scalar element can improve throughput for this bandwidth-bound kernel.

In the host function we check that total_elements % 4 == 0 and spatial_size (i.e., D*H*W) % 4 == 0. If both hold, we dispatch the vectorized kernel; otherwise we fall back to the element-wise kernel.
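The channel argument above can be checked on the host with a small brute-force sketch. The helper names `channel_of`, `groups_share_channel`, and `can_use_float4` are hypothetical and are not part of the extension; the alignment test reflects the general requirement that float4 accesses use 16-byte-aligned addresses, which is an assumption the reasoning above does not spell out.

```cpp
#include <cassert>
#include <cstdint>

// Channel of the element at flat index idx in a contiguous (N, C, D, H, W) tensor.
inline int channel_of(std::int64_t idx, std::int64_t spatial_size, int C) {
    assert(spatial_size > 0 && C > 0);
    return static_cast<int>((idx / spatial_size) % C);
}

// Returns true when every aligned group of four consecutive elements lies in
// one channel. Brute force, meant only to illustrate the argument.
inline bool groups_share_channel(std::int64_t total, std::int64_t spatial_size, int C) {
    for (std::int64_t base = 0; base + 3 < total; base += 4) {
        const int ch = channel_of(base, spatial_size, C);
        for (int k = 1; k < 4; ++k) {
            if (channel_of(base + k, spatial_size, C) != ch) return false;
        }
    }
    return true;
}

// Hypothetical dispatch predicate: both pointers 16-byte aligned and the
// per-channel volume a multiple of 4 (which also makes the total a multiple of 4).
inline bool can_use_float4(const void* in, const void* out, std::int64_t spatial_size) {
    const bool aligned = reinterpret_cast<std::uintptr_t>(in) % 16 == 0 &&
                         reinterpret_cast<std::uintptr_t>(out) % 16 == 0;
    return aligned && spatial_size % 4 == 0;
}
```

With N=2, C=2 and a spatial volume of 6, the total (24) is divisible by 4, yet the group starting at index 4 straddles two channels. This is why the spatial-volume condition matters and the total-count condition alone is not sufficient.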
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
// Element-wise fused activation kernel: applies LeakyReLU, multiplication, then second LeakyReLU.
// x shape: (N, C, D, H, W)
__global__ void fused_activation_kernel(const float* __restrict__ x,
                                        const float* __restrict__ multiplier,
                                        float* __restrict__ out,
                                        int total_elements,
                                        int C,
                                        int spatial_size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total_elements) {
        // Compute channel index: each channel has 'spatial_size' elements.
        int ch = (idx / spatial_size) % C;
        float val = x[idx];
        float a1 = (val >= 0.0f) ? val : 0.2f * val;
        float prod = a1 * multiplier[ch];
        float a2 = (prod >= 0.0f) ? prod : 0.2f * prod;
        out[idx] = a2;
    }
}
// Vectorized kernel using float4. Assumes that total_elements and spatial_size are
// multiples of 4 and that the x and out pointers are 16-byte aligned.
__global__ void fused_activation_vec_kernel(const float4* __restrict__ x,
                                            const float* __restrict__ multiplier,
                                            float4* __restrict__ out,
                                            int vec_total,     // total_elements / 4
                                            int spatial_size,  // D*H*W (per channel)
                                            int C) {
    int v_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (v_idx < vec_total) {
        int scalar_index = v_idx * 4;
        // Compute channel index based on first element of the vector. Because
        // spatial_size % 4 == 0, all four components share this channel, so the
        // multiplier is loaded once per vector.
        int ch = (scalar_index / spatial_size) % C;
        float m = multiplier[ch];
        float4 in_val = x[v_idx];
        // Process each component.
        float vals[4] = { in_val.x, in_val.y, in_val.z, in_val.w };
        #pragma unroll
        for (int k = 0; k < 4; k++) {
            float a1 = (vals[k] >= 0.0f) ? vals[k] : 0.2f * vals[k];
            float prod = a1 * m;
            vals[k] = (prod >= 0.0f) ? prod : 0.2f * prod;
        }
        out[v_idx] = make_float4(vals[0], vals[1], vals[2], vals[3]);
    }
}
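torch::Tensor fused_activation(torch::Tensor x, torch::Tensor multiplier) {
    // Validate the layout the kernels assume: contiguous float32 (N, C, D, H, W)
    // on CUDA, with one multiplier value per channel.
    TORCH_CHECK(x.is_cuda() && multiplier.is_cuda(), "x and multiplier must be CUDA tensors");
    TORCH_CHECK(x.dim() == 5, "x must have shape (N, C, D, H, W)");
    TORCH_CHECK(x.is_contiguous() && multiplier.is_contiguous(), "x and multiplier must be contiguous");
    TORCH_CHECK(multiplier.numel() == x.size(1), "multiplier must hold one value per channel");
    // The kernels use 32-bit indexing.
    TORCH_CHECK(x.numel() <= std::numeric_limits<int>::max(), "x is too large for 32-bit indexing");

    // Enforce correct CUDA device.
    const c10::cuda::CUDAGuard device_guard(x.device());

    int total = static_cast<int>(x.numel());
    int C_val = static_cast<int>(x.size(1));
    int spatial_size = static_cast<int>(x.size(2) * x.size(3) * x.size(4));  // D*H*W, elements per channel

    auto output = torch::empty_like(x);
    if (total == 0) return output;  // nothing to launch

    const float* x_ptr = x.data_ptr<float>();
    float* out_ptr = output.data_ptr<float>();
    // float4 accesses need 16-byte-aligned pointers; a view with a storage offset may not be aligned.
    bool aligned = reinterpret_cast<std::uintptr_t>(x_ptr) % 16 == 0 &&
                   reinterpret_cast<std::uintptr_t>(out_ptr) % 16 == 0;

    int threads = 256;
    // If conditions for vectorization are met, launch the vectorized kernel.
    if (aligned && (total % 4 == 0) && (spatial_size % 4 == 0)) {
        int vec_total = total / 4;
        int blocks = (vec_total + threads - 1) / threads;
        // Cast x and output pointers to float4.
        const float4* x_vec = reinterpret_cast<const float4*>(x_ptr);
        float4* out_vec = reinterpret_cast<float4*>(out_ptr);
        fused_activation_vec_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x_vec,
            multiplier.data_ptr<float>(),
            out_vec,
            vec_total,
            spatial_size,
            C_val
        );
    } else {
        int blocks = (total + threads - 1) / threads;
        fused_activation_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
            x_ptr,
            multiplier.data_ptr<float>(),
            out_ptr,
            total,
            C_val,
            spatial_size
        );
    }
    return output;
}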
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_activation", &fused_activation, "Fused LeakyReLU, multiplier, and LeakyReLU activation (CUDA)");
}
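To sanity-check either kernel on small inputs, its output can be compared against a plain CPU reference. The following is a minimal sketch, assuming a flat contiguous (N, C, D, H, W) buffer and one multiplier value per channel; `leaky_relu` and `fused_activation_reference` are hypothetical names introduced here for illustration, and the slope of 0.2 mirrors the constant hard-coded in the kernels above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// LeakyReLU with the same negative slope (0.2) that the CUDA kernels hard-code.
inline float leaky_relu(float v, float slope = 0.2f) {
    return v >= 0.0f ? v : slope * v;
}

// CPU reference: LeakyReLU, per-channel multiply, then LeakyReLU again.
// x is the flattened tensor, multiplier holds C values, and spatial_size is D*H*W.
std::vector<float> fused_activation_reference(const std::vector<float>& x,
                                              const std::vector<float>& multiplier,
                                              std::size_t spatial_size) {
    const std::size_t C = multiplier.size();
    assert(C > 0 && spatial_size > 0);
    std::vector<float> out(x.size());
    for (std::size_t idx = 0; idx < x.size(); ++idx) {
        // Same channel formula as the kernels: (idx / spatial_size) % C.
        const std::size_t ch = (idx / spatial_size) % C;
        out[idx] = leaky_relu(leaky_relu(x[idx]) * multiplier[ch]);
    }
    return out;
}
```

Because the vectorized and scalar kernels apply the same per-element arithmetic, both should agree with this reference up to ordinary floating-point rounding.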