4.9. BigBird-Roberta Logits Computation

Calculates output logits by performing a forward pass using a pre-trained BigBird-Roberta configuration on a batch of 32 randomly generated token sequences, each of length 256.
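For reference, the workload can be reproduced roughly as in the sketch below. It is illustrative only: the checkpoint name "google/bigbird-roberta-base" and the use of a masked-LM head to obtain logits are assumptions, not taken from the benchmark definition.

# Rough sketch of the workload (illustrative; not the benchmark's own harness).
# Assumes the config of the "google/bigbird-roberta-base" checkpoint and an MLM head.
import torch
from transformers import BigBirdConfig, BigBirdForMaskedLM

config = BigBirdConfig.from_pretrained("google/bigbird-roberta-base")
model = BigBirdForMaskedLM(config).cuda().eval()

batch_size, seq_length = 32, 256
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length), device="cuda")

with torch.no_grad():
    logits = model(input_ids).logits   # [32, 256, vocab_size]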

Progress: Step 2 of 2
Reward: 1.01x

Optimization Timeline

Step 2 Details

Reasoning

We can further optimize our custom embedding-lookup kernel by reorganizing the computation. In the previous approach, one thread was launched per output element (batch_size × seq_length × hidden_dim threads in total). In the new approach, we launch one CUDA block per token of the input tensor of shape [batch, seq_length], and all threads in that block collaborate to load the token's embedding vector from the weight matrix. Since every element of the embedding vector belongs to the same token, the loads from the weight matrix are contiguous, which improves memory coalescing. We also use the __ldg intrinsic to load weight values through the read-only cache.

The kernel works as follows:

1. The grid dimension is set to the total number of tokens (batch_size * seq_length); each block handles one token.
2. Each block first retrieves its token id from the flattened input tensor.
3. Each thread in the block then loads one or more elements of the embedding vector (depending on hidden_dim and blockDim.x) from weight using __ldg and writes them to the corresponding locations in the output.
4. The output is allocated with shape [batch, seq_length, hidden_dim].

This optimization reduces the total number of threads launched and improves cache efficiency; it is a genuine improvement because the parallelism is now tailored to the structure of an embedding lookup. We update the CUDA kernel accordingly, while the Python module remains largely the same and simply benefits from the more efficient embedding kernel.
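For concreteness, here is how the launch geometry works out for this task's shapes. This is a small illustrative calculation, not part of the extension; the hidden size of 768 is an assumption for a base-sized BigBird-Roberta configuration.

# Illustrative launch-geometry arithmetic for the one-block-per-token kernel.
# hidden_dim = 768 is assumed for a base-sized BigBird-Roberta configuration.
batch_size, seq_length, hidden_dim = 32, 256, 768

total_tokens = batch_size * seq_length            # 8192 blocks in the grid
block_size = min(hidden_dim, 256)                 # 256 threads per block
elems_per_thread = -(-hidden_dim // block_size)   # ceil(768 / 256) = 3 strided loads per thread
threads_now = total_tokens * block_size           # 2,097,152 threads launched

# Previous one-thread-per-output-element scheme, for comparison:
threads_before = batch_size * seq_length * hidden_dim  # 6,291,456 threads launched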

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <stdint.h>

// Optimized CUDA kernel for embedding lookup.
// Each CUDA block processes one token from the input tensor.
// The input tensor is of shape [batch_size, sequence_length] (flattened),
// and output tensor is of shape [batch_size, sequence_length, hidden_dim].
// Each block loads the embedding vector for the token using collaborative threads.
__global__ void embedding_lookup_kernel(const int64_t* __restrict__ input,
                                          const float* __restrict__ weight,
                                          float* __restrict__ output,
                                          int hidden_dim) {
    // Each block corresponds to one token.
    int token_idx = blockIdx.x;
    // Retrieve the token id for this block.
    int64_t token_id = input[token_idx];
    // Compute the starting offset in the output for this token (64-bit to avoid
    // index overflow for very large batch/sequence/hidden sizes).
    int64_t out_offset = static_cast<int64_t>(token_idx) * hidden_dim;
    // Loop over the hidden dimension with stride corresponding to blockDim.x.
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        // Use __ldg intrinsic for read-only cache.
        output[out_offset + i] = __ldg(&weight[token_id * hidden_dim + i]);
    }
}

// The embedding_forward function wraps the CUDA kernel.
// It takes an input token id tensor of shape [batch_size, sequence_length] and an embedding weight tensor of shape [vocab_size, hidden_dim],
// and returns an output tensor of shape [batch_size, sequence_length, hidden_dim].
torch::Tensor embedding_forward(torch::Tensor input, torch::Tensor weight) {
    // Ensure tensors are on CUDA and of proper type.
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == at::kLong, "input must be int64 (Long)");
    TORCH_CHECK(weight.scalar_type() == at::kFloat, "weight must be float");
    TORCH_CHECK(input.dim() == 2, "input must be 2-D [batch_size, sequence_length]");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2-D [vocab_size, hidden_dim]");

    input = input.contiguous();
    weight = weight.contiguous();
    
    // Get dimensions.
    int batch = input.size(0);
    int seq_length = input.size(1);
    int hidden_dim = weight.size(1);
    
    // Total number of tokens (each token will be processed by one block).
    int total_tokens = batch * seq_length;
    
    // Allocate the output tensor: [batch, seq_length, hidden_dim]
    auto output = torch::empty({batch, seq_length, hidden_dim}, weight.options());
    
    // Choose block size: use min(hidden_dim, 256) for efficiency.
    int blockSize = (hidden_dim < 256) ? hidden_dim : 256;
    
    // Launch one block per token on the current CUDA stream.
    c10::cuda::CUDAGuard device_guard(input.device());
    embedding_lookup_kernel<<<total_tokens, blockSize, 0, c10::cuda::getCurrentCUDAStream()>>>(
        input.data_ptr<int64_t>(),
        weight.data_ptr<float>(),
        output.data_ptr<float>(),
        hidden_dim);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("embedding_forward", &embedding_forward, "Optimized CUDA Embedding Lookup Kernel");
}
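
One way to build and sanity-check the extension from Python is torch.utils.cpp_extension.load. In the sketch below, the source file name "embedding_cuda.cu" and the extension name are placeholders, and the kernel's output is compared against torch.nn.functional.embedding.

# Build the extension inline and check it against PyTorch's own embedding lookup.
# "embedding_cuda.cu" and the extension name are placeholders for illustration.
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

ext = load(name="embedding_ext", sources=["embedding_cuda.cu"], verbose=True)

batch, seq_length, vocab, hidden_dim = 32, 256, 50000, 768
input_ids = torch.randint(0, vocab, (batch, seq_length), device="cuda")
weight = torch.randn(vocab, hidden_dim, device="cuda")

out = ext.embedding_forward(input_ids, weight)
torch.testing.assert_close(out, F.embedding(input_ids, weight))
print("custom embedding kernel matches F.embedding")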