4.11. Electra Small Discriminator Logits Computation
Calculates output logits by processing a sequence of token IDs through a transformer-based language model. The operation initializes the model using a pre-trained configuration and performs a forward pass on a randomly generated input sequence to produce prediction scores.
Progress
Optimization Timeline
Step 2 Details
Reasoning
We can further optimize the fused embedding lookup by vectorizing memory accesses when the hidden size is divisible by 4. Instead of each thread copying one float (a 4-byte load and store), each thread copies a float4 (one 16-byte load and store). The scalar kernel's accesses are already coalesced, because adjacent threads read adjacent elements of the same embedding row. The gain from float4 comes from issuing a quarter as many memory instructions and index computations for the same number of bytes, which improves achieved memory throughput. Note that float4 accesses additionally require the weight and output pointers to be 16-byte aligned. In the updated CUDA code, we check whether weight.size(1) (i.e., the hidden size) is divisible by 4. If it is, we reinterpret the weight and output tensors as float4 and launch a vectorized kernel; otherwise, we fall back to the scalar kernel. The Python module remains largely the same, but the CUDA operator now dispatches at runtime between the vectorized and scalar kernels. This is a genuine improvement in memory-bandwidth utilization rather than a benchmark-specific hack.
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>
// Scalar kernel for embedding lookup: output[b, s, d] = weight[indices[b, s], d].
//
// The indexing assumes contiguous row-major buffers: `indices` is [B, S],
// `weight` is [vocab, hidden_size], and `output` is [B, S, hidden_size].
// Index values are NOT range-checked; an id outside [0, vocab) reads out of
// bounds.
//
// Launch: any 1D grid of 1D blocks, no shared memory. Each thread walks the
// flattened output with a grid-stride loop, so a grid with fewer threads than
// B * S * hidden_size is still correct. Offsets are 64-bit so the element
// count cannot overflow a 32-bit int.
__global__ void embedding_lookup_kernel(const int64_t* __restrict__ indices,
                                        const float* __restrict__ weight,
                                        float* __restrict__ output,
                                        int B, int S, int hidden_size) {
  const int64_t total = static_cast<int64_t>(B) * S * hidden_size;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < total; idx += stride) {
    // Flattened (b, s) position; equals b * S + s for row-major [B, S] indices.
    const int64_t token = idx / hidden_size;
    const int64_t d = idx - token * hidden_size;
    const int64_t word_id = indices[token];
    output[idx] = weight[word_id * hidden_size + d];
  }
}
// Vectorized embedding lookup: copies rows in float4 (16-byte) units.
//
// Preconditions (not checked here):
//   - hidden_size is divisible by 4 and vec_hidden == hidden_size / 4.
//   - `weight` and `output` point to 16-byte-aligned memory, as float4
//     accesses require.
//   - The buffers are contiguous row-major: `indices` is [B, S], `weight` is
//     [vocab, vec_hidden] float4, and `output` is [B, S, vec_hidden] float4.
// Index values are NOT range-checked; an id outside [0, vocab) reads out of
// bounds.
//
// Launch: any 1D grid of 1D blocks, no shared memory. A grid-stride loop with
// 64-bit offsets covers all B * S * vec_hidden vectors for any grid size.
__global__ void embedding_lookup_vec_kernel(const int64_t* __restrict__ indices,
                                            const float4* __restrict__ weight,
                                            float4* __restrict__ output,
                                            int B, int S, int vec_hidden) {
  const int64_t total = static_cast<int64_t>(B) * S * vec_hidden;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < total; idx += stride) {
    // Flattened (b, s) position; equals b * S + s for row-major [B, S] indices.
    const int64_t token = idx / vec_hidden;
    const int64_t d = idx - token * vec_hidden;
    const int64_t word_id = indices[token];
    output[idx] = weight[word_id * vec_hidden + d];
  }
}
// Host entry point exposed to Python as `embedding_forward`.
//
// Computes output[b, s, :] = weight[indices[b, s], :] on the GPU.
//
// Args:
//   indices: 2D [B, S] int64 CUDA tensor of row ids into `weight`.
//   weight:  2D [vocab, hidden_size] float32 CUDA tensor on the same device.
// Returns:
//   A new [B, S, hidden_size] tensor created with weight's options.
// Raises (c10::Error via TORCH_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK):
//   - wrong rank, dtype, or device;
//   - B * S * hidden_size exceeds INT_MAX (the kernels take 32-bit sizes);
//   - the kernel launch fails.
//
// Dispatch: the float4 kernel is used when hidden_size % 4 == 0 and both
// data pointers are 16-byte aligned. Otherwise the scalar kernel copies one
// float per element. Index values are NOT range-checked; an out-of-range id
// causes an out-of-bounds device read.
torch::Tensor embedding_forward(torch::Tensor indices, torch::Tensor weight) {
  TORCH_CHECK(indices.dim() == 2, "indices tensor must be 2D");
  TORCH_CHECK(weight.dim() == 2, "weight tensor must be 2D");
  TORCH_CHECK(indices.is_cuda() && weight.is_cuda(),
              "indices and weight must be CUDA tensors");
  TORCH_CHECK(indices.device() == weight.device(),
              "indices and weight must be on the same device");
  TORCH_CHECK(indices.scalar_type() == torch::kLong, "indices must be int64");
  TORCH_CHECK(weight.scalar_type() == torch::kFloat, "weight must be float32");

  // Run on the inputs' device (and that device's current stream).
  c10::cuda::CUDAGuard device_guard(weight.device());

  // The kernels index with raw row-major offsets, so make both inputs dense.
  // contiguous() returns the tensor itself when it is already contiguous.
  const auto indices_c = indices.contiguous();
  const auto weight_c = weight.contiguous();

  const int64_t B = indices_c.size(0);
  const int64_t S = indices_c.size(1);
  const int64_t hidden_size = weight_c.size(1);

  // Create output tensor: shape [B, S, hidden_size].
  auto output = torch::empty({B, S, hidden_size}, weight_c.options());
  // Nothing to copy; launching a zero-block grid would be an invalid
  // configuration error.
  if (output.numel() == 0) {
    return output;
  }
  TORCH_CHECK(output.numel() <= std::numeric_limits<int>::max(),
              "embedding_forward: B * S * hidden_size exceeds INT_MAX");

  const int threads = 256;
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  const int64_t* idx_ptr = indices_c.data_ptr<int64_t>();
  const float* w_ptr = weight_c.data_ptr<float>();
  float* out_ptr = output.data_ptr<float>();

  // float4 accesses must be 16-byte aligned. A weight view with a storage
  // offset may not be, even when hidden_size is a multiple of 4.
  const bool aligned =
      reinterpret_cast<std::uintptr_t>(w_ptr) % alignof(float4) == 0 &&
      reinterpret_cast<std::uintptr_t>(out_ptr) % alignof(float4) == 0;

  if (hidden_size % 4 == 0 && aligned) {
    const int64_t vec_hidden = hidden_size / 4;
    const int64_t total = B * S * vec_hidden;
    const int blocks = static_cast<int>((total + threads - 1) / threads);
    embedding_lookup_vec_kernel<<<blocks, threads, 0, stream>>>(
        idx_ptr,
        reinterpret_cast<const float4*>(w_ptr),
        reinterpret_cast<float4*>(out_ptr),
        static_cast<int>(B), static_cast<int>(S),
        static_cast<int>(vec_hidden));
  } else {
    const int64_t total = B * S * hidden_size;
    const int blocks = static_cast<int>((total + threads - 1) / threads);
    embedding_lookup_kernel<<<blocks, threads, 0, stream>>>(
        idx_ptr, w_ptr, out_ptr,
        static_cast<int>(B), static_cast<int>(S),
        static_cast<int>(hidden_size));
  }
  // Surface launch-configuration errors here instead of at some later call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
}
// Python bindings: exposes embedding_forward(indices, weight) -> Tensor
// under the extension module name chosen at build time.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
  module.def("embedding_forward",
             &embedding_forward,
             "Fused Embedding Lookup (CUDA) with vectorized support");
}