3.28. Vision Transformer Classification

Calculates image classification outputs by splitting an image into patches, embedding them along with a learnable class token and positional encodings, processing the resulting sequence through transformer encoder layers, and finally computing class scores via a multi-layer perceptron.
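The sequence-building step described above (prepend a learnable class token, then add positional encodings) can be sketched on the CPU as follows. This is a minimal illustration under stated assumptions, not the model's code. It handles one image, omits dropout, and assumes the class token sits at position 0 as in the usual ViT formulation. `build_sequence` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: builds the transformer input sequence for one image.
//   seq[0]     = cls_token + pos[0]
//   seq[1 + n] = patches[n] + pos[1 + n]   for n in [0, num_patches)
// patches is [num_patches, dim], cls_token is [dim], pos is [num_patches + 1, dim];
// the result is [num_patches + 1, dim]. All buffers are row-major.
std::vector<float> build_sequence(const std::vector<float>& patches,
                                  const std::vector<float>& cls_token,
                                  const std::vector<float>& pos,
                                  int num_patches, int dim) {
    const std::size_t seq_len = static_cast<std::size_t>(num_patches) + 1;
    std::vector<float> seq(seq_len * dim);
    // Position 0: the class token plus its positional encoding (assumed layout).
    for (int d = 0; d < dim; ++d) seq[d] = cls_token[d] + pos[d];
    // Positions 1..num_patches: patch embeddings shifted by one slot.
    for (int n = 0; n < num_patches; ++n)
        for (int d = 0; d < dim; ++d) {
            const std::size_t i = (static_cast<std::size_t>(n) + 1) * dim + d;
            seq[i] = patches[static_cast<std::size_t>(n) * dim + d] + pos[i];
        }
    return seq;
}
```

The sequence length is therefore `num_patches + 1`, which is why the positional encoding has one more row than there are patches.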

Reward: 1.15x

Reasoning

We can speed up the ViT model by fusing the two operations that make up the patch embedding step: extracting patches from the input image and projecting each patch with a linear transformation. In the original model, patch extraction is done in Python with tensor.unfold and reshape, followed by a linear layer (self.patch_to_embedding). We fuse these into a single custom CUDA kernel, patch_embed_kernel, which for each image patch in the input:
• reads a patch of size (channels, patch_size, patch_size) from the image;
• flattens it and computes the linear projection using the weight and bias (i.e. takes the dot product of the flattened patch with each row of the weight matrix and adds the bias).

The kernel runs with one thread per output element. The output tensor has shape [batch, num_patches, dim], where:
• batch is the image batch size (B),
• num_patches = (H/patch_size) * (W/patch_size) (assuming H and W are divisible by patch_size),
• dim is the embedding dimension.

This fused kernel reduces Python overhead and intermediate memory reads/writes, because the extracted patch tensor is never written to global memory. The remaining operations (inserting the class token, adding the positional embedding, dropout, the transformer encoder, and the MLP head) remain unchanged. The solution consists of a Python module that loads the CUDA extension with the -O3 and --use_fast_math flags, and the CUDA source below, which implements the fused patch embedding.
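To make the kernel's arithmetic concrete, the following CPU sketch computes the same fused patch embedding with the same flattening order (channel, then row, then column within the patch) and the same row-major weight layout [dim, patch_dim]. `patch_embed_ref` is a hypothetical helper intended as a test reference for the kernel; it is not part of the original solution.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference mirroring patch_embed_kernel's indexing.
// img: [B, C, H, W], weight: [dim, C * p * p] (row-major), bias: [dim].
// Returns [B, num_patches, dim]. All buffers are contiguous float arrays.
std::vector<float> patch_embed_ref(const std::vector<float>& img,
                                   const std::vector<float>& weight,
                                   const std::vector<float>& bias,
                                   int B, int C, int H, int W, int p, int dim) {
    const int patches_per_row = W / p;
    const int num_patches = (H / p) * patches_per_row;
    const int patch_area = p * p;
    const int patch_dim = C * patch_area;
    std::vector<float> out(static_cast<std::size_t>(B) * num_patches * dim);
    for (int b = 0; b < B; ++b)
        for (int n = 0; n < num_patches; ++n) {
            const int pr = n / patches_per_row;  // patch row in the patch grid
            const int pc = n % patches_per_row;  // patch column in the patch grid
            for (int d = 0; d < dim; ++d) {
                float sum = bias[d];
                // k walks the flattened patch in (channel, row, column) order.
                for (int k = 0; k < patch_dim; ++k) {
                    const int c = k / patch_area;
                    const int row = pr * p + (k % patch_area) / p;
                    const int col = pc * p + (k % patch_area) % p;
                    sum += img[((static_cast<std::size_t>(b) * C + c) * H + row) * W + col] *
                           weight[static_cast<std::size_t>(d) * patch_dim + k];
                }
                out[(static_cast<std::size_t>(b) * num_patches + n) * dim + d] = sum;
            }
        }
    return out;
}
```

For example, a 1x1x2x4 image holding 0..7 with patch_size 2 has two patches, {0, 1, 4, 5} and {2, 3, 6, 7}; an all-ones weight row with bias 0.5 gives 10.5 and 18.5. When comparing against the GPU output, a small tolerance is advisable, since the GPU build may round differently (for example through fused multiply-adds).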

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // c10::cuda::getCurrentCUDAStream
#include <cuda_runtime.h>

// CUDA kernel that fuses patch extraction and linear projection for a Vision Transformer.
// Input image tensor shape: [B, C, H, W]
// Weight shape: [dim, patch_dim] with patch_dim = C * patch_size * patch_size
// Bias shape: [dim]
// Output tensor shape: [B, num_patches, dim] where num_patches = (H/patch_size) * (W/patch_size)
__global__ void patch_embed_kernel(const float* __restrict__ img,
                                   const float* __restrict__ weight,
                                   const float* __restrict__ bias,
                                   float* __restrict__ out,
                                   int B, int C, int H, int W,
                                   int patch_size,
                                   int dim) {
    // Derive the patch grid and the flattened patch length from the image size.
    int patches_per_row = W / patch_size;
    int num_patches = (H / patch_size) * patches_per_row;
    int patch_area = patch_size * patch_size;
    int patch_dim = C * patch_area;  // Flattened patch size.
    
    // Each thread computes one output element in the patch embedding.
    // Total output elements: B * num_patches * dim.
    int total = B * num_patches * dim;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) return;
    
    // Determine batch index, patch index, and embedding dimension.
    int b = idx / (num_patches * dim);
    int rem = idx % (num_patches * dim);
    int patch_idx = rem / dim;
    int d = rem % dim;
    
    // Convert patch index into patch row and column.
    int pr = patch_idx / patches_per_row;
    int pc = patch_idx % patches_per_row;
    
    // Compute the linear projection: out = bias[d] + sum_{k=0}^{patch_dim-1} patch[k] * weight[d, k]
    float sum = bias[d];
    for (int k = 0; k < patch_dim; k++) {
        int c = k / patch_area;               // channel index
        int pos = k % patch_area;             // position within the patch
        int row_offset = pos / patch_size;    // row within the patch
        int col_offset = pos % patch_size;    // column within the patch
        int h = pr * patch_size + row_offset; // absolute image row
        int w = pc * patch_size + col_offset; // absolute image column
        int img_index = b * (C * H * W) + c * (H * W) + h * W + w;
        // Weight is stored in row-major order: weight[d * patch_dim + k]
        sum += img[img_index] * weight[d * patch_dim + k];
    }
    out[idx] = sum;
}

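torch::Tensor forward(torch::Tensor img, int patch_size, torch::Tensor weight, torch::Tensor bias) {
    // Validate inputs: the kernel assumes float32 CUDA tensors with the shapes documented above.
    TORCH_CHECK(img.is_cuda() && weight.is_cuda() && bias.is_cuda(), "all inputs must be CUDA tensors");
    TORCH_CHECK(img.scalar_type() == torch::kFloat32 && weight.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32, "only float32 inputs are supported");
    TORCH_CHECK(img.dim() == 4, "img must have shape [B, C, H, W]");
    TORCH_CHECK(patch_size > 0 && img.size(2) % patch_size == 0 && img.size(3) % patch_size == 0,
                "H and W must be divisible by patch_size");
    TORCH_CHECK(weight.dim() == 2 && weight.size(1) == img.size(1) * patch_size * patch_size,
                "weight must have shape [dim, C * patch_size * patch_size]");
    TORCH_CHECK(bias.dim() == 1 && bias.size(0) == weight.size(0), "bias must have shape [dim]");

    // Ensure operations occur on the correct device.
    c10::cuda::CUDAGuard device_guard(img.device());

    // The kernel indexes raw pointers, so all inputs must be contiguous.
    img = img.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();
    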
    // Image dimensions: [B, C, H, W].
    int B = img.size(0);
    int C = img.size(1);
    int H = img.size(2);
    int W = img.size(3);
    
    // Calculate the number of patches = (H / patch_size) * (W / patch_size).
    int num_patches = (H / patch_size) * (W / patch_size);
    // The embedding dimension is given by the weight tensor: weight shape is [dim, patch_dim].
    int dim = weight.size(0);
    
    // Allocate output tensor with shape [B, num_patches, dim].
    auto options = img.options();
    auto out = torch::empty({B, num_patches, dim}, options);
    
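    int total = B * num_patches * dim;
    if (total == 0) return out;  // Nothing to compute; avoid launching zero blocks.
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    patch_embed_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        img.data_ptr<float>(),
        weight.data_ptr<float>(),
        bias.data_ptr<float>(),
        out.data_ptr<float>(),
        B, C, H, W,
        patch_size,
        dim
    );
    // Surface launch errors (e.g. an invalid configuration) immediately.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "patch_embed_kernel launch failed: ", cudaGetErrorString(err));
    return out;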
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Fused patch extraction and linear projection for ViT");
}
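The thread-to-output mapping in patch_embed_kernel and the grid sizing in forward() are plain integer arithmetic, so they can be checked on the host. The C++ sketch below mirrors both. `OutCoord`, `decompose`, and `num_blocks` are hypothetical names introduced for illustration, and the sizes used in the example (196 patches, dim 768, batch 2) are assumed values, not taken from the benchmark.

```cpp
#include <cassert>

// Hypothetical host-side mirror of the kernel's index decomposition.
struct OutCoord { int b, patch, d; };

// Maps a flat output index to (batch, patch, embedding dim), exactly as the
// kernel does with idx / (num_patches * dim), rem / dim, and rem % dim.
OutCoord decompose(int idx, int num_patches, int dim) {
    OutCoord o;
    o.b = idx / (num_patches * dim);
    const int rem = idx % (num_patches * dim);
    o.patch = rem / dim;
    o.d = rem % dim;
    return o;
}

// Ceiling division, as in blocks = (total + threads - 1) / threads.
int num_blocks(int total, int threads) { return (total + threads - 1) / threads; }
```

With 2 * 196 * 768 = 301056 outputs and 256 threads per block, the grid has exactly 1176 blocks, and the last thread (index 301055) writes batch 1, patch 195, dimension 767. Like the kernel, the sketch uses 32-bit int, so B * num_patches * dim must stay below 2^31; larger outputs would need 64-bit indexing.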