1.76. Standard 1D Convolution with Dilation and Stride
Performs a 1D convolution on an asymmetric input using a square kernel. The operation calculates the convolution output with adjustable dilation and stride parameters and optionally adds a bias to the result.
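For reference, the operation can be written as a plain C++ loop nest on the CPU. This is a minimal sketch, not code from the original solution. `conv1d_reference` is a hypothetical helper. It assumes contiguous row-major float buffers and no padding. The bias is optional and per output channel; pass an empty vector when there is none.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for a 1D convolution with stride and dilation.
// Assumes contiguous row-major layouts and no padding ("valid" convolution):
//   x: [batch, in_ch, L_in], w: [out_ch, in_ch, k], bias: [out_ch] or empty.
std::vector<float> conv1d_reference(const std::vector<float>& x,
                                    const std::vector<float>& w,
                                    const std::vector<float>& bias,
                                    int batch, int in_ch, int out_ch,
                                    int L_in, int k, int stride, int dilation) {
    const int effective_k = (k - 1) * dilation + 1;  // span covered by one dilated window
    if (L_in < effective_k) return {};               // no valid output position
    const int L_out = (L_in - effective_k) / stride + 1;
    std::vector<float> y(static_cast<std::size_t>(batch) * out_ch * L_out, 0.0f);
    for (int b = 0; b < batch; ++b)
        for (int oc = 0; oc < out_ch; ++oc)
            for (int lo = 0; lo < L_out; ++lo) {
                float sum = bias.empty() ? 0.0f : bias[oc];  // optional bias
                for (int ic = 0; ic < in_ch; ++ic)
                    for (int kk = 0; kk < k; ++kk) {
                        const int li = lo * stride + kk * dilation;  // dilated tap position
                        sum += x[(static_cast<std::size_t>(b) * in_ch + ic) * L_in + li] *
                               w[(static_cast<std::size_t>(oc) * in_ch + ic) * k + kk];
                    }
                y[(static_cast<std::size_t>(b) * out_ch + oc) * L_out + lo] = sum;
            }
    return y;
}
```

Take x = [1, 2, ..., 7] with a kernel of two ones, stride 2 and dilation 2. Output position l_out then reads taps l_out * 2 and l_out * 2 + 2. That gives 4, 8 and 12 before a bias of 0.5 is added.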
Step 1 Details
Reasoning
The task is to replace the 1D convolution in the original Model with a single custom CUDA kernel. The kernel computes the convolution directly from its definition instead of going through nn.Conv1d or PyTorch's native conv1d. It takes an input tensor of shape [batch, in_channels, L_in] and a weight tensor of shape [out_channels, in_channels, kernel_size]. It then evaluates the standard convolution formula (no padding):

output[b, oc, l_out] = Σ_ic Σ_k input[b, ic, l_out * stride + k * dilation] * weight[oc, ic, k]

The output length is

L_out = (L_in - ((kernel_size - 1) * dilation + 1)) / stride + 1

The division here is integer (floor) division. The formula is only meaningful when the dilated kernel span, (kernel_size - 1) * dilation + 1, does not exceed L_in.

We parallelize over output elements. Each (batch index, output channel, output position) triple is one work item, and work items are assigned to threads with a grid-stride loop. c10::cuda::CUDAGuard makes the kernel run on the input's device. torch::empty allocates the output without initializing it, which is safe because the kernel writes every element.

In the Python module, a CUDAModel class initializes the weight parameter and forwards the input through the custom CUDA operator. It creates no bias, since bias is False by default. The extension is loaded with extra_cuda_cflags set to "-O3" and "--use_fast_math".

The aim is to avoid the overhead of dispatching through PyTorch's native conv1d. Benchmarking must confirm whether this naive kernel actually beats the cuDNN-backed conv1d. Each thread reads its input window and weights from global memory, and there is no shared-memory tiling to reuse input values across neighboring outputs.
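The output-length formula and each thread's flat-index mapping can be sketched as two small host-side helpers. Both names, `conv1d_output_length` and `decompose_output_index`, are hypothetical. The sketch assumes no padding. It checks that the kernel fits before dividing, because C++ integer division truncates toward zero and a too-short input could otherwise report one output position.

```cpp
#include <cstdint>

// Number of output positions for a stride/dilation convolution with no padding.
// Returns 0 when the dilated kernel does not fit in the input.
int conv1d_output_length(int L_in, int kernel_size, int stride, int dilation) {
    const int effective = (kernel_size - 1) * dilation + 1;  // dilated kernel span
    if (L_in < effective) return 0;
    return (L_in - effective) / stride + 1;
}

// Inverse of idx = b * (out_channels * L_out) + oc * L_out + l_out,
// the mapping each thread uses to locate its output element.
struct OutputCoord {
    std::int64_t b;
    std::int64_t oc;
    std::int64_t l_out;
};

OutputCoord decompose_output_index(std::int64_t idx, std::int64_t out_channels, std::int64_t L_out) {
    const std::int64_t l_out = idx % L_out;  // fastest-varying dimension
    const std::int64_t rest = idx / L_out;   // = b * out_channels + oc
    return {rest / out_channels, rest % out_channels, l_out};
}
```

As a worked example, take L_in = 256, kernel_size = 3, stride = 3 and dilation = 4. These values are chosen for illustration and are not from the original task. The dilated span is 9, so L_out = (256 - 9) / 3 + 1 = 82 + 1 = 83.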
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
// Kernel to perform 1D convolution.
// Input shape: [batch, in_channels, L_in]
// Weight shape: [out_channels, in_channels, kernel_size]
// Output shape: [batch, out_channels, L_out], where
// L_out = (L_in - ((kernel_size - 1) * dilation + 1)) / stride + 1
__global__ void conv1d_kernel(const float* __restrict__ input,
                              const float* __restrict__ weight,
                              float* __restrict__ output,
                              int batch,
                              int in_channels,
                              int out_channels,
                              int L_in,
                              int kernel_size,
                              int stride,
                              int dilation,
                              int L_out) {
    // 64-bit indexing so tensors with more than 2^31 elements do not overflow.
    int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    int64_t total = static_cast<int64_t>(batch) * out_channels * L_out;
    int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t idx = index; idx < total; idx += grid_stride) {
        // Compute indices:
        // idx = b * (out_channels * L_out) + oc * L_out + l_out
        int64_t temp = idx;
        int l_out = static_cast<int>(temp % L_out);
        temp /= L_out;
        int oc = static_cast<int>(temp % out_channels);
        int64_t b = temp / out_channels;
        float sum = 0.0f;
        // For each input channel and kernel index compute the convolution sum.
        for (int ic = 0; ic < in_channels; ic++) {
            for (int k = 0; k < kernel_size; k++) {
                int l_in = l_out * stride + k * dilation;
                // Access input at [b, ic, l_in] and weight at [oc, ic, k]
                float val = input[(b * in_channels + ic) * L_in + l_in];
                float w = weight[(static_cast<int64_t>(oc) * in_channels + ic) * kernel_size + k];
                sum += val * w;
            }
        }
        output[idx] = sum;
    }
}
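torch::Tensor conv1d_forward(torch::Tensor input, torch::Tensor weight, int stride, int dilation) {
    // Validate arguments before touching device memory.
    TORCH_CHECK(input.is_cuda() && weight.is_cuda(), "input and weight must be CUDA tensors");
    TORCH_CHECK(input.device() == weight.device(), "input and weight must be on the same device");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32,
                "input and weight must be float32");
    TORCH_CHECK(input.dim() == 3 && weight.dim() == 3,
                "expected input [batch, in_channels, L_in] and weight [out_channels, in_channels, kernel_size]");
    TORCH_CHECK(weight.size(1) == input.size(1), "weight in_channels must match input in_channels");
    TORCH_CHECK(stride > 0 && dilation > 0, "stride and dilation must be positive");
    // Enforce operations on the correct CUDA device.
    c10::cuda::CUDAGuard device_guard(input.device());
    // The kernel indexes raw pointers, so it needs a contiguous row-major layout.
    input = input.contiguous();
    weight = weight.contiguous();
    // Get input dimensions: [batch, in_channels, L_in]
    int batch = input.size(0);
    int in_channels = input.size(1);
    int L_in = input.size(2);
    // Get convolution parameters from weight shape.
    // weight shape: [out_channels, in_channels, kernel_size]
    int out_channels = weight.size(0);
    int kernel_size = weight.size(2);
    // Compute output length (no padding):
    // L_out = (L_in - ((kernel_size - 1) * dilation + 1)) / stride + 1
    // Check that the dilated kernel fits first: integer division truncates toward zero,
    // so a too-short input could otherwise give L_out == 1 and out-of-bounds reads.
    int effective_kernel = (kernel_size - 1) * dilation + 1;
    TORCH_CHECK(L_in >= effective_kernel, "input length is shorter than the dilated kernel span");
    int L_out = (L_in - effective_kernel) / stride + 1;
    // Allocate output tensor of shape [batch, out_channels, L_out]
    auto output = torch::empty({batch, out_channels, L_out}, input.options());
    int64_t total_elements = static_cast<int64_t>(batch) * out_channels * L_out;
    // A launch with zero blocks is an invalid configuration, so return empty outputs directly.
    if (total_elements == 0) {
        return output;
    }
    int threads = 256;
    int blocks = static_cast<int>((total_elements + threads - 1) / threads);
    conv1d_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<float>(),
        weight.data_ptr<float>(),
        output.data_ptr<float>(),
        batch,
        in_channels,
        out_channels,
        L_in,
        kernel_size,
        stride,
        dilation,
        L_out
    );
    // Surface launch errors (e.g. an invalid configuration) immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "conv1d_kernel launch failed: ", cudaGetErrorString(err));
    return output;
}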
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("conv1d_forward", &conv1d_forward, "Fused 1D convolution with stride and dilation");
}
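The launch configuration can be checked on the host by replaying the kernel's grid-stride loop. This is a minimal sketch. `count_grid_stride_visits` and `blocks_for` are hypothetical helpers that count how often each flat output index is visited for a given `blocks`/`threads` pair. With blocks = ceil(total / threads), as in `conv1d_forward`, each thread runs at most one iteration. A smaller grid, hypothetically capped at 2 blocks below, makes the loop wrap around, but every index should still be visited exactly once. `conv1d_forward` itself does not cap the grid.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Replays, on the CPU, the loop
//   for (idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x)
// and returns how many times each flat output index is visited.
std::vector<int> count_grid_stride_visits(std::int64_t total, int blocks, int threads) {
    std::vector<int> visits(static_cast<std::size_t>(total), 0);
    const std::int64_t grid_stride = static_cast<std::int64_t>(blocks) * threads;
    for (int block = 0; block < blocks; ++block)
        for (int thread = 0; thread < threads; ++thread)
            for (std::int64_t idx = static_cast<std::int64_t>(block) * threads + thread;
                 idx < total; idx += grid_stride)
                ++visits[static_cast<std::size_t>(idx)];
    return visits;
}

// Same ceil-division that conv1d_forward uses for its block count.
int blocks_for(std::int64_t total, int threads) {
    return static_cast<int>((total + threads - 1) / threads);
}
```

Note that `blocks_for(0, 256)` is 0. A zero-block launch is rejected by CUDA as an invalid configuration, which is why an empty output has to be handled before the kernel launch.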