2.24. 3D Convolution Minimum Softmax
Performs a 3D convolution on volumetric input, calculates the minimum values along a specified dimension to reduce the depth, and then computes the softmax activation across channels.
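Per output location, the operation can be written as follows. Here x is taken to be the convolution output of shape [B, C, D, H, W], and the specified reduction dimension is assumed to be the depth D, as in the kernel below. The symbols m, M, and y are notation introduced for this sketch.

```latex
m_{b,c,h,w} = \min_{d} x_{b,c,d,h,w}, \qquad
M_{b,h,w} = \max_{c} m_{b,c,h,w}, \qquad
y_{b,c,h,w} = \frac{\exp\left(m_{b,c,h,w} - M_{b,h,w}\right)}{\sum_{c'} \exp\left(m_{b,c',h,w} - M_{b,h,w}\right)}
```

Subtracting M leaves y unchanged, because the factor exp(-M) cancels between numerator and denominator. It does keep every exponent at or below zero, so expf cannot overflow.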
Step 3 Details
Reasoning
We can further optimize our fused kernel by reducing the number of passes over the channels. In our original design, after computing the per-channel minimum values, we made one pass to compute the maximum (needed for a numerically stable softmax), another pass to compute the sum of exponentials, and a final pass to write out the normalized softmax values. That made four loops over C per output location. We can eliminate one of them by updating the running maximum while computing the minimum for each channel. In the second pass, we compute the exponentials, storing each one back into the temporary array in place of the minimum, and accumulate their sum. The third pass writes the normalized values. Merging the max computation into the min reduction removes one loop over C and one full read of the temporary array per output location. The temporary array itself is still needed, because the exponentials can only be normalized once the complete sum is known. We also add a "#pragma unroll" hint to the inner loop over the D dimension. According to the CUDA C++ Programming Guide, a bare "#pragma unroll" fully unrolls a loop only when its trip count is a compile-time constant, and otherwise does not unroll it. Since D is a runtime argument here, the hint only takes effect if D becomes known at compile time. The Python module remains nearly the same; only the CUDA kernel changes.
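The three-pass structure can be mirrored on the CPU, which is handy for checking the kernel's index arithmetic and numerics on small inputs. This is a hedged sketch, not the extension's code. `fused_min_softmax_at` and `fused_min_softmax_cpu` are hypothetical helpers, `kMaxChannels` stands in for `MAX_CHANNELS`, and the input is assumed to be a contiguous [B, C, D, H, W] float buffer (for example, the convolution output copied to the host).

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for MAX_CHANNELS in the CUDA source.
constexpr int kMaxChannels = 64;

// Hypothetical CPU mirror of the work one CUDA thread does for location (b, h, w).
// input: contiguous [B, C, D, H, W]; output: contiguous [B, C, H, W].
void fused_min_softmax_at(const std::vector<float>& input, std::vector<float>& output,
                          int C, int D, int H, int W, int b, int h, int w) {
    assert(C <= kMaxChannels && D > 0);
    const int64_t HW = static_cast<int64_t>(H) * W;
    const int64_t hw = static_cast<int64_t>(h) * W + w;
    float vals[kMaxChannels];
    float max_val = -FLT_MAX;
    // Pass 1: min over D for each channel, updating the channel max in the same loop.
    for (int c = 0; c < C; ++c) {
        const int64_t base = (static_cast<int64_t>(b) * C + c) * D * HW + hw;
        float local_min = FLT_MAX;
        for (int d = 0; d < D; ++d) {
            const float v = input[static_cast<std::size_t>(base + d * HW)];
            if (v < local_min) local_min = v;
        }
        vals[c] = local_min;
        if (local_min > max_val) max_val = local_min;
    }
    // Pass 2: shifted exponentials, stored in place of the minima, and their sum.
    float sum_exp = 0.0f;
    for (int c = 0; c < C; ++c) {
        vals[c] = std::exp(vals[c] - max_val);
        sum_exp += vals[c];
    }
    // Pass 3: normalize and write output[b, c, h, w].
    for (int c = 0; c < C; ++c) {
        const int64_t out_index = (static_cast<int64_t>(b) * C + c) * HW + hw;
        output[static_cast<std::size_t>(out_index)] = vals[c] / sum_exp;
    }
}

// Runs the per-location routine over every (b, h, w), as the CUDA grid would.
std::vector<float> fused_min_softmax_cpu(const std::vector<float>& input,
                                         int B, int C, int D, int H, int W) {
    std::vector<float> output(static_cast<std::size_t>(B) * C * H * W);
    for (int b = 0; b < B; ++b)
        for (int h = 0; h < H; ++h)
            for (int w = 0; w < W; ++w)
                fused_min_softmax_at(input, output, C, D, H, W, b, h, w);
    return output;
}
```

For example, with B = C = D = 2... more precisely, with one location, two channels, and depth values {0, 5} and {ln 3, 7}, the minima are 0 and ln 3, so the softmax is 0.25 and 0.75. Comparing this helper's output against `torch.amin(x, dim=2).softmax(dim=1)` on the same data is one way to validate the CUDA kernel.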
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <float.h>
#include <math.h>
#define MAX_CHANNELS 64
// Fused kernel that, for each (b, h, w):
// For each channel c, computes min_{d} input[b, c, d, h, w].
// Simultaneously computes the maximum over channels for normalization.
// Then, in a second pass, computes the exponential values and reduces their sum.
// Finally, writes: softmax[b, c, h, w] = exp(min_val - max) / sum_exp.
__global__ void fused_min_softmax_kernel(const float* __restrict__ input,
                                         float* __restrict__ output,
                                         int B, int C, int D, int H, int W) {
    // One thread per (b, h, w). Offsets are 64-bit so that tensors with more than
    // 2^31 - 1 elements do not overflow int arithmetic.
    const int64_t HW = static_cast<int64_t>(H) * W;
    const int64_t total = static_cast<int64_t>(B) * HW;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= total) return;
    // Compute the batch index and the flattened spatial offset hw = h * W + w.
    const int64_t b = idx / HW;
    const int64_t hw = idx % HW;
    // Offsets of input[b, 0, 0, h, w] and output[b, 0, h, w] (contiguous layout).
    const int64_t in_base = b * C * D * HW + hw;
    const int64_t out_base = b * C * HW + hw;
    // Temporary storage for the per-channel minima (later reused for exponentials).
    float min_vals[MAX_CHANNELS];
    float max_val = -FLT_MAX;
    // First pass: for each channel, compute the minimum along D and update max_val.
    for (int c = 0; c < C; c++) {
        const float* in_c = input + in_base + static_cast<int64_t>(c) * D * HW;
        float local_min = FLT_MAX;
        // A bare #pragma unroll only fully unrolls loops with a constant trip count;
        // D is a runtime value here, so this has no effect unless D becomes constant.
        #pragma unroll
        for (int d = 0; d < D; d++) {
            float val = __ldg(&in_c[d * HW]);
            if (val < local_min)
                local_min = val;
        }
        min_vals[c] = local_min;
        if (local_min > max_val)
            max_val = local_min;
    }
    // Second pass: compute exponentials (stored back into min_vals) and their sum.
    float sum_exp = 0.0f;
    for (int c = 0; c < C; c++) {
        float exp_val = expf(min_vals[c] - max_val);
        min_vals[c] = exp_val;  // Reuse storage for the exponential value.
        sum_exp += exp_val;
    }
    // Third pass: normalize and write softmax[b, c, h, w].
    for (int c = 0; c < C; c++) {
        output[out_base + static_cast<int64_t>(c) * HW] = min_vals[c] / sum_exp;
    }
}
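torch::Tensor forward(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 5, "input must have shape [B, C, D, H, W]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat, "input must be float32");
    c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes a contiguous [B, C, D, H, W] layout.
    input = input.contiguous();
    // Input shape: [B, C, D, H, W]
    const int B = input.size(0);
    const int C = input.size(1);
    const int D = input.size(2);
    const int H = input.size(3);
    const int W = input.size(4);
    TORCH_CHECK(C <= MAX_CHANNELS, "Channel dimension exceeds maximum supported value (", MAX_CHANNELS, ")");
    TORCH_CHECK(D > 0, "Depth dimension D must be non-empty for the min reduction");
    // Output shape: [B, C, H, W] (D is reduced).
    auto output = torch::empty({B, C, H, W}, input.options());
    const int64_t total_threads = static_cast<int64_t>(B) * H * W;
    if (total_threads == 0) return output;  // A zero-block launch would be an error.
    const int threads = 256;
    const int blocks = static_cast<int>((total_threads + threads - 1) / threads);
    fused_min_softmax_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(), output.data_ptr<float>(), B, C, D, H, W);
    // Surface launch errors (e.g. an invalid configuration) immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_min_softmax_kernel launch failed: ", cudaGetErrorString(err));
    return output;
}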
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused min reduction along D and softmax over channel dimension");
}
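The launch configuration in forward assigns one thread per (b, h, w), with 256 threads per block and a ceil-divided grid. The kernel's index decomposition should therefore visit every output location exactly once. The sketch below checks this on the host. `simulate_grid_coverage` is a hypothetical helper, not part of the extension. It replays the kernel's index arithmetic for every launched thread and counts how often each location is visited.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side replay of the grid used by forward():
// blocks = ceil(total / threads), one thread per (b, h, w).
// Returns the visit count for each output location; every entry should be 1.
std::vector<int> simulate_grid_coverage(int B, int H, int W, int threads = 256) {
    const int64_t HW = static_cast<int64_t>(H) * W;
    const int64_t total = static_cast<int64_t>(B) * HW;
    const int64_t blocks = (total + threads - 1) / threads;
    std::vector<int> visits(static_cast<std::size_t>(total), 0);
    for (int64_t block = 0; block < blocks; ++block) {
        for (int t = 0; t < threads; ++t) {
            const int64_t idx = block * threads + t;
            if (idx >= total) continue;  // Mirrors the kernel's early return.
            // Same decomposition as the kernel.
            const int64_t b = idx / HW;
            const int64_t rem = idx % HW;
            const int64_t h = rem / W;
            const int64_t w = rem % W;
            assert(b < B && h < H && w < W);
            ++visits[static_cast<std::size_t>((b * H + h) * W + w)];
        }
    }
    return visits;
}
```

Because hw = h * W + w increases with idx, consecutive threads in a warp read consecutive addresses of input for a fixed (b, c, d). This keeps the `__ldg` loads coalesced, except where a warp straddles a batch boundary.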