2.58. 3D Transposed Conv with Nonlinear Reductions

Calculates output feature maps by first performing a 3D transposed convolution to upsample the input, then computing a LogSumExp reduction over the channel dimension. It then applies a HardSwish-like activation (x * sigmoid(x + 3) / 6), subtracts a learnable bias, clamps the values to [-1, 1], and finally takes the maximum along the channel dimension.
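For reference, the unfused computation can be written per output location (b, d, h, w). Here σ is the logistic sigmoid and β_c denotes the bias entry for channel c (bias[c, 0, 0, 0]); this notation is ours, not the original model's:

```latex
\begin{aligned}
y &= \mathrm{ConvTranspose3d}(x), \qquad y \in \mathbb{R}^{B \times C \times D \times H \times W} \\
L_{b,d,h,w} &= \log \sum_{c=1}^{C} \exp\left(y_{b,c,d,h,w}\right) \\
f_{b,d,h,w} &= \frac{L_{b,d,h,w}\,\sigma\left(L_{b,d,h,w} + 3\right)}{6} \\
z_{b,c,d,h,w} &= \operatorname{clamp}\left(f_{b,d,h,w} - \beta_c,\,-1,\,1\right) \\
\mathrm{out}_{b,0,d,h,w} &= \max_{c}\, z_{b,c,d,h,w}
\end{aligned}
```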

Reward: 1.40x

Reasoning

The original model performs several operations in sequence. It first applies a 3D transposed convolution to produce a tensor Y of shape [B, C, D, H, W]. Next, it computes the log-sum-exp (logsumexp) across the channel dimension (dim=1, keepdim=True), so the result L has shape [B, 1, D, H, W]. Then a "HardSwish-like" nonlinearity is applied: f = L * sigmoid(L + 3) / 6. Next, a bias of shape [C, 1, 1, 1] is subtracted; because f has a single channel, broadcasting expands the result to [B, C, D, H, W], where channel c holds f - bias[c]. Finally, the result is clamped to [-1, 1] and the maximum along the channel dimension is taken.

Because f is identical across channels, the largest value of f - bias[c] is f - min(bias). Clamping is monotone non-decreasing, so taking the maximum before or after the clamp gives the same result: max over c of clamp(f - bias[c], -1, 1) equals clamp(f - min(bias), -1, 1). The fused operations are therefore mathematically equivalent to:

    L = logsumexp(Y, dim=1, keepdim=True)
    f = L * sigmoid(L + 3) / 6
    b_min = min(bias)
    output = clamp(f - b_min, -1, 1)

which yields an output tensor of shape [B, 1, D, H, W].

To optimize, we fuse all of these operations into one custom CUDA kernel. The kernel reads the input tensor Y (the output of the 3D transposed convolution) of shape [B, C, D, H, W] and, for each spatial location (b, d, h, w), loops over the channels (C is small, e.g., 16) to compute the logsumexp. It then computes f = L * sigmoid(L + 3) / 6, subtracts the precomputed b_min (obtained in Python from self.bias), and clamps the result to [-1, 1] to produce the output value at that location.

In our Python module, we load the CUDA extension with -O3 and --use_fast_math. Our CUDAModel retains a conv_transpose module and a bias parameter. In the forward method, we first compute y = conv_transpose(x), then compute b_min = torch.min(self.bias) (a scalar), and finally call the fused CUDA kernel with y and b_min. The kernel returns an output tensor of shape [B, 1, D, H, W].
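The equivalence argument can be checked numerically on the host. The following is a minimal C++ sketch, not part of the extension: it assumes float32 values and a bias with one entry per channel (the [C, 1, 1, 1] parameter flattened to C values), and the helper names (unfused_max, fused, check_equivalence) are hypothetical. It evaluates a single spatial location both ways on random inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// logsumexp over the channel values of one spatial location, shifted by the max for stability.
inline float logsumexp(const std::vector<float>& y) {
    float m = *std::max_element(y.begin(), y.end());
    float s = 0.0f;
    for (float v : y) s += std::exp(v - m);
    return m + std::log(s);
}

// HardSwish-like activation used by the model: x * sigmoid(x + 3) / 6.
inline float hardswish_like(float x) {
    return x * (1.0f / (1.0f + std::exp(-(x + 3.0f)))) / 6.0f;
}

inline float clamp_unit(float x) { return std::min(1.0f, std::max(-1.0f, x)); }

// Unfused path (hypothetical helper): subtract every channel's bias from the
// channel-independent f, clamp each result, then take the max over channels.
float unfused_max(const std::vector<float>& y, const std::vector<float>& bias) {
    float f = hardswish_like(logsumexp(y));
    float best = -INFINITY;
    for (float b : bias) best = std::max(best, clamp_unit(f - b));
    return best;
}

// Fused path (hypothetical helper): a single subtraction of the smallest bias.
float fused(const std::vector<float>& y, float b_min) {
    return clamp_unit(hardswish_like(logsumexp(y)) - b_min);
}

// Compares both paths on random inputs; returns true if they agree within a small tolerance.
bool check_equivalence(int trials, int C, unsigned seed) {
    std::mt19937 rng(seed);
    std::normal_distribution<float> dist(0.0f, 2.0f);
    for (int t = 0; t < trials; ++t) {
        std::vector<float> y(C), bias(C);
        for (float& v : y) v = dist(rng);
        for (float& b : bias) b = dist(rng);
        float b_min = *std::min_element(bias.begin(), bias.end());
        if (std::fabs(unfused_max(y, bias) - fused(y, b_min)) > 1e-6f) return false;
    }
    return true;
}
```

For example, check_equivalence(1000, 16, 0u) compares 1000 random locations with 16 channels. The tolerance only guards against compiler-dependent intermediate precision: floating-point subtraction is monotone in b, so both paths end up clamping the same value f - b_min.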
This fused implementation avoids launching separate kernels for logsumexp, sigmoid, the element-wise arithmetic, and the max reduction, which reduces kernel-launch overhead and avoids writing the intermediate [B, C, D, H, W] tensor produced by the broadcast subtraction to global memory.

--- FIX ---

The error indicated that our CUDAModel constructor expected a scaling_factor argument, but the original problem's __init__ takes only six parameters. Since the original model's constructor has no scaling_factor, we remove that parameter from CUDAModel's constructor so that its signature matches the six values returned by get_init_inputs. We update the Python module to remove scaling_factor from both __init__() and forward. The fused CUDA kernel is unchanged (it still subtracts b_min and clamps the result). With this fix, model instantiation succeeds without a missing-parameter error. We now show the revised Python and CUDA files.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>

// Fused kernel that performs the following steps for each spatial location (b, d, h, w):
// 1. Compute LogSumExp over the channel dimension (L).
// 2. Compute f = L * sigmoid(L + 3) / 6.
// 3. Compute res = clamp(f - b_min, -1, 1), which equals the channel-wise max of clamp(f - bias[c], -1, 1).
//
// The input tensor has shape [B, C, D, H, W] and the output tensor has shape [B, 1, D, H, W].
__global__ void fused_kernel(const float* __restrict__ input,
                             float* __restrict__ output,
                             int B, int C, int D, int H, int W,
                             float b_min) {
    int S = D * H * W;  // total spatial locations per sample
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int total = B * S;
    if (index >= total) return;
    
    int b = index / S;
    int s = index % S;
    int d = s / (H * W);
    int r = s % (H * W);
    int h = r / W;
    int w = r % W;
    
    // Compute strides for input with shape [B, C, D, H, W]:
    // stride_B = C * D * H * W, stride_C = D * H * W, stride_D = H * W, stride_H = W, stride_W = 1.
    int stride_B = C * D * H * W;
    int stride_C = D * H * W;
    int stride_D = H * W;
    int stride_H = W;
    
    // Compute logsumexp over channels for location (b, :, d, h, w).
    float max_val = -INFINITY;
    for (int c = 0; c < C; c++) {
        float val = input[b * stride_B + c * stride_C + d * stride_D + h * stride_H + w];
        if (val > max_val)
            max_val = val;
    }
    float sum_exp = 0.0f;
    for (int c = 0; c < C; c++) {
        float val = input[b * stride_B + c * stride_C + d * stride_D + h * stride_H + w];
        sum_exp += expf(val - max_val);
    }
    float L = max_val + logf(sum_exp);
    
    // Compute f = L * sigmoid(L + 3) / 6.
    float sgm = 1.0f / (1.0f + expf(-(L + 3.0f)));
    float f = L * sgm / 6.0f;
    
    // Subtract b_min and clamp result to [-1, 1].
    float res = f - b_min;
    if (res < -1.0f) res = -1.0f;
    if (res > 1.0f) res = 1.0f;
    
    // Write result to output, which has shape [B, 1, D, H, W].
    // Compute flat index for output:
    int stride_B_out = D * H * W;  // only one channel in output.
    int out_idx = b * stride_B_out + d * (H * W) + h * W + w;
    output[out_idx] = res;
}

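torch::Tensor forward(torch::Tensor input, float b_min) {
    // input: output of the 3D transposed convolution, shape [B, C, D, H, W], float32 on a CUDA device.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 5, "input must have shape [B, C, D, H, W]");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes with contiguous strides, so make the layout explicit.
    input = input.contiguous();
    int B = input.size(0);
    int C = input.size(1);
    int D = input.size(2);
    int H = input.size(3);
    int W = input.size(4);

    // Output shape: [B, 1, D, H, W].
    auto output = torch::empty({B, 1, D, H, W}, input.options());

    int S = D * H * W;
    int total = B * S;
    if (total == 0) return output;  // launching zero blocks is an invalid configuration
    int threads = 256;
    int blocks = (total + threads - 1) / threads;

    fused_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        B, C, D, H, W,
        b_min
    );
    // Report launch errors (e.g., an invalid configuration) instead of failing silently.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused_kernel launch failed: ", cudaGetErrorString(err));

    return output;
}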

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused logsumexp, HardSwish-like activation, bias subtraction, clamp, and channel max (applied to the ConvTranspose3d output)");
}
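To validate the kernel, its output can be compared against a host-side reference that uses the same flat indexing. The sketch below is hypothetical (fused_reference and online_logsumexp are not part of the extension) and assumes a contiguous float32 buffer. It also swaps in a single-pass ("online") logsumexp that rescales the running sum whenever a larger value appears, so each input element is read once rather than twice; this is an alternative to the kernel's two channel loops, not what the kernel above does.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Single-pass logsumexp over C values spaced `stride` apart.
// Keeps a running max m and a running sum s of exp(v - m); when a new max
// arrives, the old sum is rescaled by exp(old_m - new_m).
inline float online_logsumexp(const float* p, int C, int stride) {
    float m = -INFINITY;
    float s = 0.0f;
    for (int c = 0; c < C; ++c) {
        float v = p[c * stride];
        if (v > m) {
            s = s * std::exp(m - v) + 1.0f;  // on the first element exp(-inf) == 0, so s becomes 1
            m = v;
        } else {
            s += std::exp(v - m);
        }
    }
    return m + std::log(s);
}

// Host reference: input [B, C, D, H, W] (contiguous) -> output [B, 1, D, H, W].
std::vector<float> fused_reference(const std::vector<float>& input,
                                   int B, int C, int D, int H, int W, float b_min) {
    const int S = D * H * W;  // spatial locations per sample; also the channel stride
    std::vector<float> output(static_cast<std::size_t>(B) * S);
    for (int b = 0; b < B; ++b) {
        for (int s = 0; s < S; ++s) {
            // Same offset as the kernel: b*C*S + c*S + s, where s = d*H*W + h*W + w.
            const float* p = input.data() + static_cast<std::size_t>(b) * C * S + s;
            float L = online_logsumexp(p, C, S);
            float f = L * (1.0f / (1.0f + std::exp(-(L + 3.0f)))) / 6.0f;
            output[static_cast<std::size_t>(b) * S + s] = std::fmin(1.0f, std::fmax(-1.0f, f - b_min));
        }
    }
    return output;
}
```

Comparing fused_reference against the kernel's output (copied back to the host) within a small tolerance would be a reasonable check; exact equality should not be expected, since the extension is built with --use_fast_math and the summation order differs.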