3.40. GRU Hidden State Calculation

Calculates the final recurrent hidden state obtained by running an input sequence through a multi-layer gated recurrent unit (GRU), starting from a randomly initialized hidden state. Note that the fused kernel below implements only the single-layer case (num_layers == 1).

Progress

Step 3 of 3
Reward
1.00x

Optimization Timeline

Step 3 Details

Reasoning

We can further optimize the kernel in two ways: make sure the inner dot-product routine is inlined, and read the weights through the read-only data cache via __ldg. In the updated CUDA kernel we add the __forceinline__ qualifier to the dot_vectorized routine and wrap its load accesses with __ldg. This should reduce function-call overhead and improve latency when reading the weights. The overall kernel organization remains the same:
- each block processes one batch element;
- each thread computes one hidden unit’s recurrent state;
- shared memory holds the per-time-step input vector and the previous hidden state.

Two caveats apply:
- __ldg is valid only on global-memory addresses. It must therefore be applied to the weight and bias loads, not to the shared-memory operands (the input vector and the hidden state).
- Each time step needs a __syncthreads() between the dot products and the write of the new hidden state into shared memory, because other threads may still be reading the previous state.

These changes should yield a small performance boost without altering functional correctness.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <math.h>

// Logistic sigmoid, 1 / (1 + exp(-x)), evaluated in single precision via expf.
__device__ inline float sigmoidf(float x) {
    const float denominator = 1.0f + expf(-x);
    return 1.0f / denominator;
}

// Inlined dot product sum_{k < len} a[k] * b[k].
//
// `a` must point into global memory that stays read-only for the kernel's
// lifetime: it is read through the read-only data cache with __ldg.
// `b` may point into shared OR global memory and is read with ordinary loads.
// __ldg compiles to a global-memory load, so it must never receive a
// shared-memory address.
//
// The float4 path is taken only when both pointers are 16-byte aligned, as
// reinterpret_cast<const float4*> requires. Otherwise (e.g. a row offset or a
// shared-memory offset that is not a multiple of 4 floats) the whole product
// is computed with scalar loads. Any len >= 0 is handled; the tail that does
// not fill a float4 is always processed by the scalar loop.
__device__ __forceinline__ float dot_vectorized(const float* a, const float* b, int len) {
    float sum = 0.0f;
    int scalar_start = 0;

    const bool both_aligned_16 =
        ((reinterpret_cast<size_t>(a) | reinterpret_cast<size_t>(b)) & 15u) == 0;
    if (both_aligned_16) {
        const int vec_len = len / 4;
        const float4* a4 = reinterpret_cast<const float4*>(a);
        const float4* b4 = reinterpret_cast<const float4*>(b);
        for (int k = 0; k < vec_len; ++k) {
            const float4 va = __ldg(a4 + k);  // global, read-only weights
            const float4 vb = b4[k];          // plain load: may be shared memory
            sum += va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w;
        }
        scalar_start = vec_len * 4;
    }

    for (int k = scalar_start; k < len; ++k) {
        sum += __ldg(a + k) * b[k];
    }
    return sum;
}

// Fused single-layer GRU kernel: runs the recurrence over all seq_len time
// steps and writes only the final hidden state.
//
// Launch layout expected:
//   grid  = (batch_size)  one block per batch element.
//   block = blockDim.x >= hidden_size (fused_gru_forward uses exactly
//           hidden_size). Thread i < hidden_size owns hidden unit i; any extra
//           threads only help stage the input vector.
//   dynamic shared memory = (hidden_size + input_size) * sizeof(float):
//           [0, hidden_size)                        previous hidden state (s_h)
//           [hidden_size, hidden_size + input_size) current input x_t   (s_x)
//
// Contiguous float32 layouts implied by the indexing below:
//   x         [seq_len, batch_size, input_size]
//   h0, h_out [batch_size, hidden_size]
//   weight_ih [3*hidden_size, input_size],  rows grouped (r, z, n)
//   weight_hh [3*hidden_size, hidden_size], rows grouped (r, z, n)
//   bias_ih, bias_hh [3*hidden_size],       grouped (r, z, n)
//
// Per time step (same formulation as torch.nn.GRU):
//   r  = sigmoid(W_ir x + W_hr h + b_ir + b_hr)
//   z  = sigmoid(W_iz x + W_hz h + b_iz + b_hz)
//   n  = tanh(W_in x + b_in + r * (W_hn h + b_hn))
//   h' = (1 - z) * n + z * h
__global__ void fused_gru_single_layer_optimized_kernel(const float* __restrict__ x,
                                                        const float* __restrict__ h0,
                                                        float* __restrict__ h_out,
                                                        int seq_len, int batch_size, int input_size, int hidden_size,
                                                        const float* __restrict__ weight_ih,
                                                        const float* __restrict__ weight_hh,
                                                        const float* __restrict__ bias_ih,
                                                        const float* __restrict__ bias_hh) {
    const int b = blockIdx.x;           // batch element handled by this block
    const int i = threadIdx.x;          // hidden unit handled by this thread
    const bool active = i < hidden_size;

    extern __shared__ float smem[];
    float* s_h = smem;                  // previous hidden state, shared by the block
    float* s_x = smem + hidden_size;    // input vector for the current time step

    // 64-bit sizes so row and time-step offsets cannot overflow int.
    const size_t H = static_cast<size_t>(hidden_size);
    const size_t I = static_cast<size_t>(input_size);

    // Loop-invariant per-thread state: this unit's weight rows, its biases and
    // its own hidden value (kept in a register; s_h[i] mirrors it for sharing).
    const float* w_ir = weight_ih;
    const float* w_iz = weight_ih;
    const float* w_in = weight_ih;
    const float* w_hr = weight_hh;
    const float* w_hz = weight_hh;
    const float* w_hn = weight_hh;
    float b_ir = 0.0f, b_hr = 0.0f;
    float b_iz = 0.0f, b_hz = 0.0f;
    float b_in = 0.0f, b_hn = 0.0f;
    float h_i = 0.0f;
    if (active) {
        const size_t r_row = static_cast<size_t>(i);  // reset gate row
        const size_t z_row = H + r_row;               // update gate row
        const size_t n_row = 2 * H + r_row;           // new gate row
        w_ir = weight_ih + r_row * I;
        w_iz = weight_ih + z_row * I;
        w_in = weight_ih + n_row * I;
        w_hr = weight_hh + r_row * H;
        w_hz = weight_hh + z_row * H;
        w_hn = weight_hh + n_row * H;
        b_ir = __ldg(bias_ih + r_row);
        b_hr = __ldg(bias_hh + r_row);
        b_iz = __ldg(bias_ih + z_row);
        b_hz = __ldg(bias_hh + z_row);
        b_in = __ldg(bias_ih + n_row);
        b_hn = __ldg(bias_hh + n_row);
        h_i = h0[static_cast<size_t>(b) * H + r_row];
        s_h[i] = h_i;  // published by the first barrier inside the loop
    }

    const size_t x_time_stride = static_cast<size_t>(batch_size) * I;
    const float* x_b = x + static_cast<size_t>(b) * I;

    for (int t = 0; t < seq_len; ++t) {
        // Cooperative load of x_t into shared memory (all threads participate).
        const float* x_ptr = x_b + static_cast<size_t>(t) * x_time_stride;
        for (int idx = threadIdx.x; idx < input_size; idx += blockDim.x) {
            s_x[idx] = x_ptr[idx];
        }
        // Makes x_t and the previous step's s_h writes visible to every thread.
        __syncthreads();

        float h_new = h_i;
        if (active) {
            const float r_input  = dot_vectorized(w_ir, s_x, input_size);
            const float z_input  = dot_vectorized(w_iz, s_x, input_size);
            const float n_input  = dot_vectorized(w_in, s_x, input_size);
            const float r_hidden = dot_vectorized(w_hr, s_h, hidden_size);
            const float z_hidden = dot_vectorized(w_hz, s_h, hidden_size);
            const float n_hidden = dot_vectorized(w_hn, s_h, hidden_size);

            const float r_val = sigmoidf(r_input + r_hidden + b_ir + b_hr);
            const float z_val = sigmoidf(z_input + z_hidden + b_iz + b_hz);
            const float n_val = tanhf(n_input + b_in + r_val * (n_hidden + b_hn));
            h_new = (1.0f - z_val) * n_val + z_val * h_i;
        }

        // Every thread must finish reading s_h and s_x for this step before any
        // thread overwrites them (s_h just below, s_x at the top of the next step).
        __syncthreads();
        if (active) {
            s_h[i] = h_new;
            h_i = h_new;
        }
    }

    // Each thread writes back only the value it owns; no barrier needed.
    if (active) {
        h_out[static_cast<size_t>(b) * H + i] = h_i;
    }
}

// Host entry point for the fused single-layer GRU forward pass.
//
// Args:
//   x           float32 CUDA tensor of shape [seq_len, batch, input_size].
//   h0          float32 tensor with batch * hidden_size elements, read as
//               [batch, hidden_size] (e.g. shape [1, batch, hidden_size]).
//   input_size  must equal x.size(2).
//   hidden_size number of hidden units. One thread per unit, so it may not
//               exceed the device's max threads per block.
//   num_layers  must be 1.
//   weight_ih   [3*hidden_size, input_size]; weight_hh [3*hidden_size, hidden_size];
//   bias_ih, bias_hh [3*hidden_size]; gate rows grouped (r, z, n).
//
// Returns:
//   The final hidden state as a [1, batch, hidden_size] tensor (h0's values
//   when seq_len == 0).
//
// Non-contiguous inputs are made contiguous before the launch. Invalid
// arguments and launch failures raise through TORCH_CHECK. The kernel is
// enqueued on the current CUDA stream; this function does not synchronize.
torch::Tensor fused_gru_forward(torch::Tensor x, torch::Tensor h0,
                                int input_size, int hidden_size, int num_layers,
                                torch::Tensor weight_ih, torch::Tensor weight_hh,
                                torch::Tensor bias_ih, torch::Tensor bias_hh) {
    TORCH_CHECK(num_layers == 1, "Fused GRU kernel supports only num_layers == 1");
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(x.dim() == 3, "x must have shape [seq_len, batch, input_size], got ", x.dim(), " dims");
    TORCH_CHECK(hidden_size >= 0, "hidden_size must be non-negative, got ", hidden_size);
    TORCH_CHECK(x.size(2) == input_size,
                "x.size(2) (", x.size(2), ") does not match input_size (", input_size, ")");

    const int64_t seq_len64 = x.size(0);
    const int64_t batch64 = x.size(1);
    const int64_t H = hidden_size;
    const int64_t I = input_size;

    // Every operand must be float32, live on x's device, and hold exactly the
    // number of elements the kernel indexes.
    const auto check_operand = [&](const torch::Tensor& t, const char* name, int64_t expected_numel) {
        TORCH_CHECK(t.device() == x.device(), name, " must be on the same device as x");
        TORCH_CHECK(t.scalar_type() == at::kFloat, name, " must be float32");
        TORCH_CHECK(t.numel() == expected_numel,
                    name, " has ", t.numel(), " elements, expected ", expected_numel);
    };
    check_operand(x, "x", seq_len64 * batch64 * I);
    check_operand(h0, "h0", batch64 * H);
    check_operand(weight_ih, "weight_ih", 3 * H * I);
    check_operand(weight_hh, "weight_hh", 3 * H * H);
    check_operand(bias_ih, "bias_ih", 3 * H);
    check_operand(bias_hh, "bias_hh", 3 * H);

    c10::cuda::CUDAGuard device_guard(x.device());
    const int seq_len = static_cast<int>(seq_len64);
    const int batch_size = static_cast<int>(batch64);

    auto h_out = torch::empty({batch_size, hidden_size}, x.options());

    // Nothing to compute. A zero-sized grid or block would be an invalid launch.
    if (batch_size == 0 || hidden_size == 0) {
        return h_out.view({1, batch_size, hidden_size});
    }

    // Validate the launch configuration up front so misuse fails with a clear message.
    const int device_index = x.get_device();
    int max_threads_per_block = 0;
    int max_smem_per_block = 0;
    cudaError_t err = cudaDeviceGetAttribute(&max_threads_per_block,
                                             cudaDevAttrMaxThreadsPerBlock, device_index);
    TORCH_CHECK(err == cudaSuccess, "cudaDeviceGetAttribute failed: ", cudaGetErrorString(err));
    err = cudaDeviceGetAttribute(&max_smem_per_block,
                                 cudaDevAttrMaxSharedMemoryPerBlock, device_index);
    TORCH_CHECK(err == cudaSuccess, "cudaDeviceGetAttribute failed: ", cudaGetErrorString(err));

    TORCH_CHECK(hidden_size <= max_threads_per_block,
                "hidden_size (", hidden_size, ") exceeds the max threads per block (",
                max_threads_per_block, ") supported by this kernel");
    const size_t shared_mem_size = static_cast<size_t>(H + I) * sizeof(float);
    TORCH_CHECK(shared_mem_size <= static_cast<size_t>(max_smem_per_block),
                "required shared memory (", shared_mem_size, " bytes) exceeds the per-block limit (",
                max_smem_per_block, " bytes)");

    // The kernel assumes dense row-major layouts.
    const torch::Tensor x_c = x.contiguous();
    const torch::Tensor h0_c = h0.contiguous();
    const torch::Tensor w_ih_c = weight_ih.contiguous();
    const torch::Tensor w_hh_c = weight_hh.contiguous();
    const torch::Tensor b_ih_c = bias_ih.contiguous();
    const torch::Tensor b_hh_c = bias_hh.contiguous();

    // One block per batch element, one thread per hidden unit.
    const dim3 grid(batch_size);
    const dim3 block(hidden_size);

    fused_gru_single_layer_optimized_kernel<<<grid, block, shared_mem_size, c10::cuda::getCurrentCUDAStream()>>>(
        x_c.data_ptr<float>(),
        h0_c.data_ptr<float>(),
        h_out.data_ptr<float>(),
        seq_len,
        batch_size,
        input_size,
        hidden_size,
        w_ih_c.data_ptr<float>(),
        w_hh_c.data_ptr<float>(),
        b_ih_c.data_ptr<float>(),
        b_hh_c.data_ptr<float>()
    );
    // Surface launch-configuration errors (e.g. too many registers requested)
    // instead of returning an uninitialized h_out.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "fused GRU kernel launch failed: ", cudaGetErrorString(err));

    return h_out.view({1, batch_size, hidden_size});
}

// Python binding: exposes fused_gru_forward to Python as `forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
    module.def("forward",
               &fused_gru_forward,
               "Optimized fused GRU forward kernel for num_layers==1");
}