1.79. Transposed 1D Convolution Operation
Performs a transposed convolution on a 1D input using an asymmetric input shape with a square kernel. It supports configurable stride, padding, and dilation parameters to control the upsampling process.
Progress
Optimization Timeline
Step 4 Details
Reasoning
We can further optimize our kernel by unrolling the loop over kernel positions when kernel_size is 3—a common case—thus reducing loop overhead and enabling the compiler to better schedule instructions. In the unrolled version we precompute the common base value (j + padding) and then handle the three kernel positions (k = 0, 1, 2) separately, which removes redundant arithmetic from the loops. We also continue to use __ldg to take advantage of the read-only cache. For generic kernel sizes, we retain our original loop. The Python module does not need any changes. These improvements reduce per-thread overhead and enhance performance when kernel_size == 3 without sacrificing generality.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <limits>
// Adds to `acc` the contribution of a single kernel tap to one output element.
//
//   x_batch       : first element of the current batch item, indexed as
//                   x_batch[c * L_in + i]
//   w_tap         : weight element for (c = 0, oc, k); consecutive input
//                   channels are `weight_stride` floats apart
//   pos           : j + padding - k * dilation for the output position j
//
// The tap contributes only when `pos` is an exact multiple of `stride` and the
// resulting input index pos / stride lies in [0, L_in). Accumulation is done
// in place (channel by channel) so the summation order is the same as a
// plain nested loop over (k, c).
__device__ __forceinline__ void conv_transpose1d_accumulate_tap(
    const float* __restrict__ x_batch,
    const float* __restrict__ w_tap,
    int pos, int stride, int in_channels, int L_in, int weight_stride,
    float& acc) {
    if (pos % stride != 0) {
        return;  // this tap does not land on an input sample
    }
    const int i = pos / stride;
    if (i < 0 || i >= L_in) {
        return;  // input index falls outside the signal
    }
    for (int c = 0; c < in_channels; ++c) {
        acc += __ldg(x_batch + c * L_in + i) * __ldg(w_tap + c * weight_stride);
    }
}

// CUDA kernel for transposed 1D convolution (no bias).
//
// Launch layout: 1D grid of 1D blocks; each thread computes one output
// element y[b, oc, j], and threads whose flat index is >= B*out_channels*L_out
// do nothing. No shared memory is used.
//
// Memory layout implied by the indexing below (all buffers densely packed):
//   x      : [B, in_channels, L_in]
//   weight : [in_channels, out_channels, kernel_size]
//   output : [B, out_channels, L_out]
//
// Preconditions: stride must be non-zero (it is used as a divisor), and every
// buffer's element count must fit in a 32-bit int because indices are ints.
__global__ void conv_transpose1d_kernel(const float* __restrict__ x,
                                        const float* __restrict__ weight,
                                        float* __restrict__ output,
                                        int B, int in_channels, int L_in,
                                        int out_channels, int kernel_size, int L_out,
                                        int stride, int padding, int dilation) {
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    const int total = B * out_channels * L_out;
    if (index >= total) {
        return;
    }

    // Map the flat index to (b, oc, j).
    const int j = index % L_out;
    const int rest = index / L_out;
    const int oc = rest % out_channels;
    const int b = rest / out_channels;

    // Per-thread base pointers; offsets are widened before multiplying.
    const float* x_batch = x + static_cast<size_t>(b) * in_channels * L_in;
    const float* w_oc = weight + static_cast<size_t>(oc) * kernel_size;
    const int weight_stride = out_channels * kernel_size;
    const int base = j + padding;  // common term of every tap position

    float sum = 0.0f;
    if (kernel_size == 3) {
        // Fixed trip count so the compiler can fully unroll the three taps.
#pragma unroll
        for (int k = 0; k < 3; ++k) {
            conv_transpose1d_accumulate_tap(x_batch, w_oc + k, base - k * dilation,
                                            stride, in_channels, L_in,
                                            weight_stride, sum);
        }
    } else {
        // Generic path for arbitrary kernel_size.
        for (int k = 0; k < kernel_size; ++k) {
            conv_transpose1d_accumulate_tap(x_batch, w_oc + k, base - k * dilation,
                                            stride, in_channels, L_in,
                                            weight_stride, sum);
        }
    }

    // The flat thread index is exactly the offset of y[b, oc, j].
    output[index] = sum;
}
// Host entry point: transposed 1D convolution without bias.
//
// Arguments:
//   x        : float32 CUDA tensor of shape [B, in_channels, L_in]
//   weight   : float32 CUDA tensor of shape [in_channels, out_channels, kernel_size]
//              on the same device as x
//   stride   : upsampling stride, must be >= 1 (the kernel divides by it)
//   padding  : implicit padding removed from both ends, must be >= 0
//   dilation : spacing between kernel taps, must be >= 1
//
// Returns a new tensor of shape [B, out_channels, L_out] with
//   L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + 1.
//
// Raises (via TORCH_CHECK) on invalid devices, dtypes, shapes or parameters,
// on problem sizes that do not fit the kernel's 32-bit indexing, and on a
// failed kernel launch. The launch is asynchronous on the current CUDA stream.
torch::Tensor forward(torch::Tensor x, torch::Tensor weight, int stride, int padding, int dilation) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(x.device() == weight.device(),
                "x and weight must be on the same device");
    TORCH_CHECK(x.dim() == 3, "x must have shape [B, in_channels, L_in], got ",
                x.dim(), " dimensions");
    TORCH_CHECK(weight.dim() == 3,
                "weight must have shape [in_channels, out_channels, kernel_size], got ",
                weight.dim(), " dimensions");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
    TORCH_CHECK(weight.scalar_type() == torch::kFloat32, "weight must be float32");
    TORCH_CHECK(weight.size(0) == x.size(1),
                "weight.size(0) (", weight.size(0),
                ") must equal the number of input channels (", x.size(1), ")");
    TORCH_CHECK(weight.size(2) >= 1, "kernel_size must be at least 1");
    TORCH_CHECK(stride >= 1, "stride must be >= 1, got ", stride);
    TORCH_CHECK(padding >= 0, "padding must be >= 0, got ", padding);
    TORCH_CHECK(dilation >= 1, "dilation must be >= 1, got ", dilation);

    // Ensure we operate on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(x.device());

    // The kernel indexes both tensors as densely packed row-major buffers.
    x = x.contiguous();
    weight = weight.contiguous();

    // Do the size arithmetic in 64 bits so overflow is detected, not wrapped.
    const int64_t B = x.size(0);
    const int64_t in_channels = x.size(1);
    const int64_t L_in = x.size(2);
    const int64_t out_channels = weight.size(1);
    const int64_t kernel_size = weight.size(2);

    // L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + 1
    const int64_t L_out = (L_in - 1) * stride - 2 * static_cast<int64_t>(padding) +
                          static_cast<int64_t>(dilation) * (kernel_size - 1) + 1;
    TORCH_CHECK(L_out >= 0, "computed output length is negative (", L_out,
                "); check stride/padding/dilation against the input length");

    auto output = torch::empty({B, out_channels, L_out}, x.options());

    const int64_t total = B * out_channels * L_out;
    if (total == 0) {
        // Nothing to compute; a zero-block launch would be an invalid configuration.
        return output;
    }

    // The kernel uses int for every index, so all buffers must fit in 32 bits.
    const int64_t int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(total <= int_max && x.numel() <= int_max && weight.numel() <= int_max,
                "tensor sizes exceed the 32-bit indexing range of this kernel");

    const int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);
    conv_transpose1d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        weight.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(B), static_cast<int>(in_channels), static_cast<int>(L_in),
        static_cast<int>(out_channels), static_cast<int>(kernel_size),
        static_cast<int>(L_out),
        stride, padding, dilation);
    // Surface launch-configuration errors immediately instead of at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Python bindings: exposes `forward` on the extension module. Argument names
// are declared so callers may pass them by keyword as well as positionally.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward,
          "Fused transposed 1D convolution (conv_transpose1d) using custom CUDA kernel",
          pybind11::arg("x"),
          pybind11::arg("weight"),
          pybind11::arg("stride"),
          pybind11::arg("padding"),
          pybind11::arg("dilation"));
}