3.40. GRU Hidden State Calculation
Calculates the final recurrent hidden state by processing an input sequence through multiple layers of a gated recurrent unit (GRU), starting from an initially randomized hidden state.
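For reference, the kernel in this step implements the standard GRU cell update (the same formulation used by torch.nn.GRU), applied once per time step with reset gate r_t, update gate z_t, and candidate state n_t:

    r_t = sigmoid(W_ir · x_t + b_ir + W_hr · h_{t-1} + b_hr)
    z_t = sigmoid(W_iz · x_t + b_iz + W_hz · h_{t-1} + b_hz)
    n_t = tanh(W_in · x_t + b_in + r_t ⊙ (W_hn · h_{t-1} + b_hn))
    h_t = (1 − z_t) ⊙ n_t + z_t ⊙ h_{t-1}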
Step 3 Details
Reasoning
We can further optimize the kernel by ensuring the inner dot-product routine is inlined and by routing weight reads through the read-only cache with __ldg. The updated CUDA kernel adds the __forceinline__ qualifier to dot_vectorized and wraps the weight and bias loads in __ldg; this reduces function-call overhead and improves memory latency when reading the weight matrices. The overall kernel organization is unchanged: each block processes one batch element and each thread computes one hidden unit of the recurrent state, while shared memory holds the per-time-step input vector and the previous hidden state. The necessary synchronizations are kept. These improvements should yield a small performance gain without altering functional correctness.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <math.h>
#include <stdint.h>

// Device inline function for the logistic sigmoid.
__device__ inline float sigmoidf(float x) {
    return 1.0f / (1.0f + expf(-x));
}
// Optimized, force-inlined dot product using vectorized float4 loads.
// The weight row 'a' lives in global memory and is read through the read-only cache via __ldg;
// 'b' points to shared memory, so it is read with plain loads (__ldg is only valid for global memory).
__device__ __forceinline__ float dot_vectorized(const float* a, const float* b, int len) {
    float sum = 0.0f;
    // float4 loads require 16-byte alignment; fall back to a scalar loop if either pointer
    // is misaligned (e.g. when the row length is not a multiple of 4 for rows past the first).
    if ((reinterpret_cast<uintptr_t>(a) | reinterpret_cast<uintptr_t>(b)) & 0xF) {
        for (int i = 0; i < len; i++) {
            sum += __ldg(a + i) * b[i];
        }
        return sum;
    }
    int vec_len = len / 4; // Number of full float4 chunks; the remainder is handled below.
    const float4* a4 = reinterpret_cast<const float4*>(a);
    const float4* b4 = reinterpret_cast<const float4*>(b);
    for (int i = 0; i < vec_len; i++) {
        float4 va = __ldg(a4 + i);  // Global weights: read-only data cache.
        float4 vb = b4[i];          // Shared memory: plain load.
        sum += va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w;
    }
    for (int i = vec_len * 4; i < len; i++) {
        sum += __ldg(a + i) * b[i];
    }
    return sum;
}
// Optimized fused GRU kernel for a single-layer GRU.
// Each block processes one batch element.
// Each thread (indexed by threadIdx.x) computes one element of the hidden state.
// Shared memory layout: first 'hidden_size' floats for the hidden state, then 'input_size' floats for the input vector.
__global__ void fused_gru_single_layer_optimized_kernel(const float* __restrict__ x,
                                                        const float* __restrict__ h0,
                                                        float* __restrict__ h_out,
                                                        int seq_len, int batch_size, int input_size, int hidden_size,
                                                        const float* __restrict__ weight_ih,
                                                        const float* __restrict__ weight_hh,
                                                        const float* __restrict__ bias_ih,
                                                        const float* __restrict__ bias_hh) {
    int b = blockIdx.x;   // Each block processes one batch element.
    int i = threadIdx.x;  // Each thread processes one hidden unit.
    // Shared memory: first hidden_size floats for the hidden state, then input_size floats for the input vector.
    extern __shared__ float smem[];
    float* s_h = smem;                // Hidden state buffer.
    float* s_x = smem + hidden_size;  // Input vector buffer.
    // Load the initial hidden state for element i.
    if (i < hidden_size) {
        s_h[i] = h0[b * hidden_size + i];
    }
    __syncthreads();
    // Loop over time steps.
    for (int t = 0; t < seq_len; t++) {
        const float* x_ptr = x + t * (batch_size * input_size) + b * input_size;
        // Cooperatively load the input vector for this time step into shared memory.
        for (int idx = threadIdx.x; idx < input_size; idx += blockDim.x) {
            s_x[idx] = x_ptr[idx];
        }
        __syncthreads();
        // Row indices for the three gates in the stacked (3 * hidden_size) weight/bias layout.
        int r_row = i;                    // Reset gate row index.
        int z_row = hidden_size + i;      // Update gate row index.
        int n_row = 2 * hidden_size + i;  // New (candidate) gate row index.
        // Dot products for the input contribution.
        float r_input = dot_vectorized(weight_ih + r_row * input_size, s_x, input_size);
        float z_input = dot_vectorized(weight_ih + z_row * input_size, s_x, input_size);
        float n_input = dot_vectorized(weight_ih + n_row * input_size, s_x, input_size);
        // Dot products for the hidden-state contribution.
        float r_hidden = dot_vectorized(weight_hh + r_row * hidden_size, s_h, hidden_size);
        float z_hidden = dot_vectorized(weight_hh + z_row * hidden_size, s_h, hidden_size);
        float n_hidden = dot_vectorized(weight_hh + n_row * hidden_size, s_h, hidden_size);
        // Add biases and compute gate activations.
        float r_sum = r_input + r_hidden + __ldg(bias_ih + r_row) + __ldg(bias_hh + r_row);
        float z_sum = z_input + z_hidden + __ldg(bias_ih + z_row) + __ldg(bias_hh + z_row);
        float r_val = sigmoidf(r_sum);
        float z_val = sigmoidf(z_sum);
        float n_sum = n_input + __ldg(bias_ih + n_row) + r_val * (n_hidden + __ldg(bias_hh + n_row));
        float n_val = tanhf(n_sum);
        float h_new = (1.0f - z_val) * n_val + z_val * s_h[i];
        // Ensure every thread has finished reading s_h before it is overwritten.
        __syncthreads();
        s_h[i] = h_new;
        __syncthreads();
    }
    // Write the final hidden state back to global memory.
    if (i < hidden_size) {
        h_out[b * hidden_size + i] = s_h[i];
    }
}
// Host function for the fused GRU forward pass (only for num_layers == 1).
torch::Tensor fused_gru_forward(torch::Tensor x, torch::Tensor h0,
                                int input_size, int hidden_size, int num_layers,
                                torch::Tensor weight_ih, torch::Tensor weight_hh,
                                torch::Tensor bias_ih, torch::Tensor bias_hh) {
    TORCH_CHECK(num_layers == 1, "Fused GRU kernel supports only num_layers == 1");
    TORCH_CHECK(x.is_cuda() && h0.is_cuda(), "x and h0 must be CUDA tensors");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "Fused GRU kernel supports only float32 tensors");
    TORCH_CHECK(x.is_contiguous() && h0.is_contiguous(), "x and h0 must be contiguous");
    TORCH_CHECK(hidden_size >= 1 && hidden_size <= 1024,
                "hidden_size must fit in a single thread block (1..1024)");
    c10::cuda::CUDAGuard device_guard(x.device());
    int seq_len = x.size(0);
    int batch_size = x.size(1);
    auto h_out = torch::empty({batch_size, hidden_size}, x.options());
    // Launch one block per batch element; blockDim.x == hidden_size (one thread per hidden unit).
    dim3 grid(batch_size);
    dim3 block(hidden_size);
    size_t shared_mem_size = (hidden_size + input_size) * sizeof(float);
    fused_gru_single_layer_optimized_kernel<<<grid, block, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        h0.data_ptr<float>(),
        h_out.data_ptr<float>(),
        seq_len,
        batch_size,
        input_size,
        hidden_size,
        weight_ih.data_ptr<float>(),
        weight_hh.data_ptr<float>(),
        bias_ih.data_ptr<float>(),
        bias_hh.data_ptr<float>());
    return h_out.view({1, batch_size, hidden_size});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fused_gru_forward, "Optimized fused GRU forward kernel for num_layers == 1");
}
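As a sanity check, the extension's output can be compared against torch::nn::GRU on random data. The harness below is a hypothetical sketch using the libtorch C++ frontend; it assumes it is compiled together with the source above (so fused_gru_forward is in scope), and the tensor sizes are illustrative.

// Hypothetical verification harness (libtorch C++ frontend). Assumes it is built
// alongside the extension source so fused_gru_forward is visible; sizes are illustrative.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int seq_len = 32, batch_size = 8, input_size = 128, hidden_size = 256;
    torch::NoGradGuard no_grad;
    auto x  = torch::randn({seq_len, batch_size, input_size}, torch::kCUDA);
    auto h0 = torch::randn({1, batch_size, hidden_size}, torch::kCUDA);

    // Reference single-layer GRU whose layer-0 weights are fed to the fused kernel.
    torch::nn::GRU ref(torch::nn::GRUOptions(input_size, hidden_size).num_layers(1));
    ref->to(torch::kCUDA);
    auto params = ref->named_parameters();

    auto h_fused = fused_gru_forward(
        x.contiguous(), h0.contiguous(), input_size, hidden_size, /*num_layers=*/1,
        params["weight_ih_l0"].contiguous(), params["weight_hh_l0"].contiguous(),
        params["bias_ih_l0"].contiguous(), params["bias_hh_l0"].contiguous());

    auto h_ref = std::get<1>(ref->forward(x, h0));  // Final hidden state from the reference GRU.
    std::cout << "max abs diff: "
              << (h_fused - h_ref).abs().max().item<float>() << std::endl;
    return 0;
}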