1.79. Transposed 1D Convolution Operation

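Performs a transposed convolution on a 1D input of shape (batch, in_channels, length) — an asymmetric input shape — using a 1D kernel of length kernel_size (the "square kernel" of the original task statement). It supports configurable stride, padding, and dilation parameters, which control how the input is upsampled.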

Progress

Step 4 of 4
Reward
2.63x

Optimization Timeline

Step 4 Details

Reasoning

We can further optimize the kernel by unrolling the loop over kernel positions (k) when kernel_size is 3 — a common case — which removes loop overhead and lets the compiler schedule instructions more effectively. The unrolled version computes the shared term (j + padding) once and then handles the three kernel positions separately, reducing redundant arithmetic inside the loops. We continue to use __ldg to take advantage of the read-only data cache. For other kernel sizes, the original generic loop is retained. The Python module needs no changes. Together, these changes reduce per-thread overhead and improve performance when kernel_size == 3 without sacrificing generality.

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>

// Adds the contribution of kernel tap `k` to `sum` for one output element.
//
// In a transposed convolution, output position j receives
// x[b, c, i] * w[c, oc, k] whenever  i * stride + k * dilation - padding == j,
// i.e. when pos = (j + padding) - k * dilation is a non-negative multiple of
// `stride` and i = pos / stride lies in [0, L_in).
//
//   x_b       - pointer to x[b, 0, 0] (x is dense [B, in_channels, L_in]).
//   weight    - dense [in_channels, out_channels, kernel_size].
//   shifted_j - j + padding, precomputed by the caller.
//   sum       - running accumulator; channels are added in ascending order.
__device__ __forceinline__ void conv_transpose1d_accumulate_tap(
    const float* __restrict__ x_b,
    const float* __restrict__ weight,
    int in_channels, int L_in, int out_channels, int kernel_size,
    int oc, int k, int shifted_j, int stride, int dilation,
    float& sum) {
    const int pos = shifted_j - k * dilation;
    // A negative position can never map to a valid input index. Rejecting it
    // first also avoids relying on the sign of `%` for negative operands.
    if (pos < 0 || pos % stride != 0) {
        return;
    }
    const int i = pos / stride;
    if (i >= L_in) {
        return;
    }

    // Walk the in_channels axis: x advances by L_in per channel,
    // weight advances by out_channels * kernel_size per channel.
    const int64_t w_step = static_cast<int64_t>(out_channels) * kernel_size;
    const float* x_col = x_b + i;
    const float* w_col = weight + static_cast<int64_t>(oc) * kernel_size + k;
    for (int c = 0; c < in_channels; ++c) {
        sum += __ldg(x_col + static_cast<int64_t>(c) * L_in) * __ldg(w_col + c * w_step);
    }
}

// CUDA kernel for transposed 1D convolution (no bias, groups = 1).
//
// Layouts (dense, row-major):
//   x      : [B, in_channels, L_in]
//   weight : [in_channels, out_channels, kernel_size]
//   output : [B, out_channels, L_out]
//
// Each output element y[b, oc, j] is produced by one iteration of a
// grid-stride loop over the flattened output, so any 1D grid/block
// configuration covers every element. No shared memory is used.
// Flat indices are computed in 64-bit so large tensors do not overflow `int`.
// Preconditions (not checked here): stride >= 1 and dilation >= 1.
__global__ void conv_transpose1d_kernel(const float* __restrict__ x,
                                          const float* __restrict__ weight,
                                          float* __restrict__ output,
                                          int B, int in_channels, int L_in,
                                          int out_channels, int kernel_size, int L_out,
                                          int stride, int padding, int dilation) {
    const int64_t total = static_cast<int64_t>(B) * out_channels * L_out;
    const int64_t grid_step = static_cast<int64_t>(gridDim.x) * blockDim.x;
    const int64_t x_batch_stride = static_cast<int64_t>(in_channels) * L_in;

    for (int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         index < total;
         index += grid_step) {
        // Map flat index to (b, oc, j); `index` is also the flat output offset.
        const int j = static_cast<int>(index % L_out);
        const int64_t temp = index / L_out;
        const int oc = static_cast<int>(temp % out_channels);
        const int b = static_cast<int>(temp / out_channels);

        const float* x_b = x + b * x_batch_stride;
        const int shifted_j = j + padding;  // common term shared by all taps
        float sum = 0.0f;

        if (kernel_size == 3) {
            // Fast path for the common kernel_size == 3: constant tap indices
            // let the compiler fold the per-tap offsets.
            conv_transpose1d_accumulate_tap(x_b, weight, in_channels, L_in, out_channels, 3,
                                            oc, 0, shifted_j, stride, dilation, sum);
            conv_transpose1d_accumulate_tap(x_b, weight, in_channels, L_in, out_channels, 3,
                                            oc, 1, shifted_j, stride, dilation, sum);
            conv_transpose1d_accumulate_tap(x_b, weight, in_channels, L_in, out_channels, 3,
                                            oc, 2, shifted_j, stride, dilation, sum);
        } else {
            // Generic path for arbitrary kernel_size.
            for (int k = 0; k < kernel_size; ++k) {
                conv_transpose1d_accumulate_tap(x_b, weight, in_channels, L_in, out_channels,
                                                kernel_size, oc, k, shifted_j, stride, dilation,
                                                sum);
            }
        }
        output[index] = sum;
    }
}

// Host entry point: transposed 1D convolution without bias (groups = 1).
//
// Args:
//   x        - input of shape [B, in_channels, L_in]; float32 CUDA tensor.
//   weight   - filter of shape [in_channels, out_channels, kernel_size];
//              float32, on the same CUDA device as x.
//   stride   - must be >= 1.
//   padding  - must be >= 0.
//   dilation - must be >= 1.
//
// Returns a new tensor of shape [B, out_channels, L_out] where
//   L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1.
//
// Raises c10::Error (via TORCH_CHECK) on invalid arguments or a failed launch.
// Non-contiguous inputs are made contiguous first, because the kernel indexes
// x and weight as dense row-major buffers. The kernel runs on the current CUDA
// stream of x's device.
torch::Tensor forward(torch::Tensor x, torch::Tensor weight, int stride, int padding, int dilation) {
    TORCH_CHECK(x.is_cuda(), "conv_transpose1d: x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "conv_transpose1d: weight must be a CUDA tensor");
    TORCH_CHECK(x.device() == weight.device(),
                "conv_transpose1d: x and weight must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32,
                "conv_transpose1d: x and weight must be float32");
    TORCH_CHECK(x.dim() == 3, "conv_transpose1d: x must be 3-D [B, in_channels, L_in], got ",
                x.dim(), "-D");
    TORCH_CHECK(weight.dim() == 3,
                "conv_transpose1d: weight must be 3-D [in_channels, out_channels, kernel_size], got ",
                weight.dim(), "-D");
    TORCH_CHECK(weight.size(0) == x.size(1), "conv_transpose1d: weight.size(0) (", weight.size(0),
                ") must equal in_channels (", x.size(1), ")");
    TORCH_CHECK(stride >= 1, "conv_transpose1d: stride must be >= 1, got ", stride);
    TORCH_CHECK(dilation >= 1, "conv_transpose1d: dilation must be >= 1, got ", dilation);
    TORCH_CHECK(padding >= 0, "conv_transpose1d: padding must be >= 0, got ", padding);

    // Ensure allocations, copies and the launch target x's device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel assumes dense row-major layouts.
    torch::Tensor x_c = x.contiguous();
    torch::Tensor w_c = weight.contiguous();

    const int B = static_cast<int>(x_c.size(0));
    const int in_channels = static_cast<int>(x_c.size(1));
    const int L_in = static_cast<int>(x_c.size(2));
    const int out_channels = static_cast<int>(w_c.size(1));
    const int kernel_size = static_cast<int>(w_c.size(2));

    const int L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1;
    TORCH_CHECK(L_out > 0, "conv_transpose1d: computed output length ", L_out,
                " is not positive");

    auto output = torch::empty({B, out_channels, L_out}, x_c.options());

    // A zero-block launch is an invalid configuration; nothing to compute anyway.
    const int64_t total = static_cast<int64_t>(B) * out_channels * L_out;
    if (total == 0) {
        return output;
    }

    constexpr int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    TORCH_CHECK(blocks <= 2147483647LL,
                "conv_transpose1d: output too large for a single kernel launch");

    conv_transpose1d_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                              c10::cuda::getCurrentCUDAStream()>>>(
        x_c.data_ptr<float>(),
        w_c.data_ptr<float>(),
        output.data_ptr<float>(),
        B, in_channels, L_in,
        out_channels, kernel_size, L_out,
        stride, padding, dilation
    );
    // Surface launch-configuration errors immediately instead of at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

// Python bindings: registers the host entry point `forward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def(
        "forward",
        &forward,
        "Fused transposed 1D convolution (conv_transpose1d) using custom CUDA kernel");
}