2.32. 2D Convolution, Scaling, and Minimum Reduction

Performs a 2D convolution on an input tensor, scales the convolution output by a constant factor, and then calculates the minimum value along the channel dimension to produce a reduced output.

Progress

Step 5 of 5
Reward
2.92x

Optimization Timeline

Step 5 Details

Reasoning

We can further reduce overhead in our fused convolution, scaling, and minimum-reduction operator by removing redundant computations. In the general kernel, the output indices (b, h, w) are computed solely from H_out and W_out, and the bias, input, and weight values are loaded with the __ldg intrinsic so that they go through the read-only data cache. The specialized kernel for kernel_size == 3 already unrolls the loops and uses __ldg; in this step we simply remove its redundant index computations. The result is a clear and efficient implementation that already achieves a significant speedup over the native solution. Our Python module remains unchanged apart from adjusted comments and minor code-quality improvements.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <cmath>
#include <limits>

// General fused kernel: stride-1, no-padding 2D convolution plus bias, multiplied by
// `scale_factor`, then reduced with min over the output channels.
//
// Launch layout: 1D grid of 1D blocks, one thread per output element (b, h, w). Threads whose
// flat index is >= B * H_out * W_out return without touching memory. No shared memory is used.
//
// Buffer layout implied by the indexing (all densely packed, row-major):
//   input  : [B, in_channels, H_in, W_in]
//   weight : [out_channels, in_channels, kernel_size, kernel_size]
//   bias   : [out_channels]
//   output : [B, 1, H_out, W_out], H_out = H_in - kernel_size + 1, W_out = W_in - kernel_size + 1
//
// __ldg is used to load read-only data from global memory through the read-only data cache.
__global__ void conv_min_kernel_general(const float* __restrict__ input,
                                         const float* __restrict__ weight,
                                         const float* __restrict__ bias,
                                         float* __restrict__ output,
                                         int B, int in_channels, int out_channels,
                                         int H_in, int W_in, int kernel_size,
                                         float scale_factor) {
    const int H_out = H_in - kernel_size + 1;
    const int W_out = W_in - kernel_size + 1;
    // A kernel larger than the input leaves no valid output position. Return explicitly:
    // two negative extents would otherwise multiply into a positive element count.
    if (H_out <= 0 || W_out <= 0) return;

    // Element counts and offsets are computed in 64 bits; they can exceed INT_MAX even
    // when every individual dimension fits in an int.
    const long long out_plane = (long long)H_out * W_out;
    const long long total = (long long)B * out_plane;
    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) return;

    // Decompose the flat output index into batch and spatial coordinates.
    const long long b = idx / out_plane;
    const long long rem = idx - b * out_plane;
    const int h = (int)(rem / W_out);
    const int w = (int)(rem % W_out);

    const long long in_plane = (long long)H_in * W_in;
    const long long filt_plane = (long long)kernel_size * kernel_size;

    // Top-left corner of this thread's receptive field in channel 0 of batch b.
    // It does not depend on the output channel, so it is computed once.
    const float* in_patch = input + b * in_channels * in_plane + (long long)h * W_in + w;

    // +inf is the identity of min; a finite sentinel would be returned as a bogus minimum
    // whenever every scaled channel value is larger than it.
    float min_val = INFINITY;
    for (int c = 0; c < out_channels; c++) {
        float conv_val = __ldg(&bias[c]);
        const float* filt = weight + (long long)c * in_channels * filt_plane;
        for (int ic = 0; ic < in_channels; ic++) {
            const float* in_ch = in_patch + ic * in_plane;
            const float* filt_ch = filt + ic * filt_plane;
            for (int kh = 0; kh < kernel_size; kh++) {
                const float* in_row = in_ch + (long long)kh * W_in;
                const float* filt_row = filt_ch + (long long)kh * kernel_size;
                for (int kw = 0; kw < kernel_size; kw++) {
                    conv_val += __ldg(&in_row[kw]) * __ldg(&filt_row[kw]);
                }
            }
        }
        conv_val *= scale_factor;
        if (conv_val < min_val)
            min_val = conv_val;
    }
    // The output has a single channel, so its flat index equals the thread's flat index.
    output[idx] = min_val;
}

// Specialised kernel for kernel_size == 3: stride-1, no-padding 3x3 convolution plus bias,
// multiplied by `scale_factor`, then reduced with min over the output channels. The nine taps
// are written out by hand and every load goes through __ldg.
//
// Launch layout: 1D grid of 1D blocks, one thread per output element (b, h, w). Threads whose
// flat index is >= B * H_out * W_out return without touching memory. No shared memory is used.
//
// Buffer layout implied by the indexing (all densely packed, row-major):
//   input  : [B, in_channels, H_in, W_in]
//   weight : [out_channels, in_channels, 3, 3]
//   bias   : [out_channels]
//   output : [B, 1, H_in - 2, W_in - 2]
__global__ void conv_min_kernel_3(const float* __restrict__ input,
                                    const float* __restrict__ weight,
                                    const float* __restrict__ bias,
                                    float* __restrict__ output,
                                    int B, int in_channels, int out_channels,
                                    int H_in, int W_in, float scale_factor) {
    // For kernel size 3, the output dimensions are H_out = H_in - 2, W_out = W_in - 2.
    const int H_out = H_in - 2;
    const int W_out = W_in - 2;
    // Inputs smaller than 3x3 have no valid output position. Return explicitly:
    // two negative extents would otherwise multiply into a positive element count.
    if (H_out <= 0 || W_out <= 0) return;

    // Element counts and offsets are computed in 64 bits; they can exceed INT_MAX even
    // when every individual dimension fits in an int.
    const long long out_plane = (long long)H_out * W_out;
    const long long total = (long long)B * out_plane;
    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) return;

    // Decompose the flat output index into batch and spatial coordinates.
    const long long b = idx / out_plane;
    const long long rem = idx - b * out_plane;
    const int h = (int)(rem / W_out);
    const int w = (int)(rem % W_out);

    const long long in_plane = (long long)H_in * W_in;

    // Top-left corner of this thread's 3x3 window in channel 0 of batch b.
    // It does not depend on the output channel, so it is computed once.
    const float* in_patch = input + b * in_channels * in_plane + (long long)h * W_in + w;

    // +inf is the identity of min; a finite sentinel would be returned as a bogus minimum
    // whenever every scaled channel value is larger than it.
    float min_val = INFINITY;
    for (int c = 0; c < out_channels; c++) {
        float conv_val = __ldg(&bias[c]);
        const float* filt = weight + (long long)c * in_channels * 9;
        for (int ic = 0; ic < in_channels; ic++) {
            // Three consecutive input rows of the window.
            const float* r0 = in_patch + ic * in_plane;
            const float* r1 = r0 + W_in;
            const float* r2 = r1 + W_in;
            // The nine weights of filter (c, ic), stored row by row.
            const float* k = filt + (long long)ic * 9;

            conv_val +=
                __ldg(r0)     * __ldg(k)     +
                __ldg(r0 + 1) * __ldg(k + 1) +
                __ldg(r0 + 2) * __ldg(k + 2) +
                __ldg(r1)     * __ldg(k + 3) +
                __ldg(r1 + 1) * __ldg(k + 4) +
                __ldg(r1 + 2) * __ldg(k + 5) +
                __ldg(r2)     * __ldg(k + 6) +
                __ldg(r2 + 1) * __ldg(k + 7) +
                __ldg(r2 + 2) * __ldg(k + 8);
        }
        conv_val *= scale_factor;
        if (conv_val < min_val)
            min_val = conv_val;
    }
    // The output has a single channel, so its flat index equals the thread's flat index.
    output[idx] = min_val;
}

// Host entry point: validates the tensors, allocates the [B, 1, H_out, W_out] result and
// launches one thread per output element on the current CUDA stream of input's device.
//
// Parameters:
//   input        float32 CUDA tensor, 4-D, read as [B, in_channels, H_in, W_in]
//   weight       float32 CUDA tensor, 4-D, read as [out_channels, in_channels, k, k] (square)
//   bias         float32 CUDA tensor with out_channels elements
//   scale_factor multiplier applied to each convolution result before the channel-wise min
//
// Returns a tensor of shape [B, 1, H_in - k + 1, W_in - k + 1] with input's options.
// Raises (via TORCH_CHECK) on wrong device/dtype/rank, mismatched channel counts, a non-square
// kernel, a kernel larger than the input, or an output too large for the kernels' int indices;
// raises on a failed kernel launch.
torch::Tensor fused_forward(torch::Tensor input,
                              torch::Tensor weight,
                              torch::Tensor bias,
                              float scale_factor) {
    TORCH_CHECK(input.is_cuda() && weight.is_cuda() && bias.is_cuda(),
                "fused_forward: input, weight and bias must be CUDA tensors");
    TORCH_CHECK(weight.device() == input.device() && bias.device() == input.device(),
                "fused_forward: input, weight and bias must be on the same device");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 &&
                weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "fused_forward: input, weight and bias must be float32");
    TORCH_CHECK(input.dim() == 4, "fused_forward: input must be 4-D, got ", input.dim(), "-D");
    TORCH_CHECK(weight.dim() == 4, "fused_forward: weight must be 4-D, got ", weight.dim(), "-D");
    TORCH_CHECK(weight.size(2) == weight.size(3),
                "fused_forward: only square kernels are supported, got ",
                weight.size(2), "x", weight.size(3));
    TORCH_CHECK(weight.size(1) == input.size(1),
                "fused_forward: weight expects ", weight.size(1),
                " input channels but input has ", input.size(1));
    TORCH_CHECK(bias.numel() == weight.size(0),
                "fused_forward: bias must have one element per output channel");

    // The kernels index raw pointers assuming densely packed row-major storage.
    input = input.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    c10::cuda::CUDAGuard device_guard(input.device());

    const int64_t batch = input.size(0);
    const int64_t c_in = input.size(1);
    const int64_t h_in = input.size(2);
    const int64_t w_in = input.size(3);
    const int64_t c_out = weight.size(0);
    const int64_t k = weight.size(2);

    const int64_t h_out = h_in - k + 1;
    const int64_t w_out = w_in - k + 1;
    TORCH_CHECK(k > 0 && h_out > 0 && w_out > 0,
                "fused_forward: kernel size ", k, " does not fit input of size ",
                h_in, "x", w_in);

    auto output = torch::empty({batch, 1, h_out, w_out}, input.options());

    // An empty batch has nothing to compute, and a grid of zero blocks is an invalid launch.
    const int64_t total = batch * h_out * w_out;
    if (total == 0) {
        return output;
    }

    // The kernels take their dimensions as int and the grid size is computed in int.
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(total <= kIntMax && c_in <= kIntMax && c_out <= kIntMax,
                "fused_forward: problem size exceeds the 32-bit limits of the kernels");

    const int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);
    auto stream = c10::cuda::getCurrentCUDAStream();

    if (k == 3) {
        conv_min_kernel_3<<<blocks, threads, 0, stream>>>(
            input.data_ptr<float>(),
            weight.data_ptr<float>(),
            bias.data_ptr<float>(),
            output.data_ptr<float>(),
            static_cast<int>(batch), static_cast<int>(c_in), static_cast<int>(c_out),
            static_cast<int>(h_in), static_cast<int>(w_in), scale_factor);
    } else {
        conv_min_kernel_general<<<blocks, threads, 0, stream>>>(
            input.data_ptr<float>(),
            weight.data_ptr<float>(),
            bias.data_ptr<float>(),
            output.data_ptr<float>(),
            static_cast<int>(batch), static_cast<int>(c_in), static_cast<int>(c_out),
            static_cast<int>(h_in), static_cast<int>(w_in), static_cast<int>(k),
            scale_factor);
    }
    // Launches do not return errors directly; surface bad configurations here.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: registers the host wrapper as `fused_forward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "fused_forward",
        &fused_forward,
        "Fused convolution, scaling, and minimum reduction kernel");
}