3.34. Vanilla RNN Hidden State Update
Calculates the updated hidden state for a recurrent cell by concatenating the current input with the prior hidden state, applying a linear transformation followed by a tanh activation, and then projecting the resulting state to compute an output.
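As a point of reference for the description above, the following is a minimal single-sample CPU sketch of the cell. It is not the benchmarked implementation: the names `RnnStep`, `rnn_step_reference`, `W_h`, `b_h`, `W_o`, and `b_o`, and the row-major weight layouts, are assumptions made for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Result of one RNN step for a single sample.
struct RnnStep {
    std::vector<float> hidden;  // updated hidden state, length hidden_size
    std::vector<float> output;  // projected output, length output_size
};

// Hypothetical CPU reference for one sample (names and layouts are assumptions):
//   W_h: hidden_size x (input_size + hidden_size), row-major; b_h: hidden_size
//   W_o: output_size x hidden_size, row-major;                b_o: output_size
RnnStep rnn_step_reference(const std::vector<float>& x,
                           const std::vector<float>& h_prev,
                           const std::vector<float>& W_h,
                           const std::vector<float>& b_h,
                           const std::vector<float>& W_o,
                           const std::vector<float>& b_o) {
    const std::size_t hidden_size = h_prev.size();
    const std::size_t K = x.size() + hidden_size;
    assert(W_h.size() == hidden_size * K && b_h.size() == hidden_size);
    assert(W_o.size() == b_o.size() * hidden_size);

    // Concatenate the current input with the previous hidden state: [x; h_prev].
    std::vector<float> combined(x);
    combined.insert(combined.end(), h_prev.begin(), h_prev.end());

    RnnStep step;
    // Hidden update: tanh(W_h * combined + b_h).
    step.hidden.resize(hidden_size);
    for (std::size_t j = 0; j < hidden_size; ++j) {
        float acc = b_h[j];
        for (std::size_t k = 0; k < K; ++k) acc += W_h[j * K + k] * combined[k];
        step.hidden[j] = std::tanh(acc);
    }
    // Output projection: W_o * hidden + b_o.
    step.output.resize(b_o.size());
    for (std::size_t o = 0; o < b_o.size(); ++o) {
        float acc = b_o[o];
        for (std::size_t j = 0; j < hidden_size; ++j) acc += W_o[o * hidden_size + j] * step.hidden[j];
        step.output[o] = acc;
    }
    return step;
}
```

A reference like this is mainly useful for checking a GPU kernel's hidden-state output on small inputs, within a floating-point tolerance.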
Step 3 Details
Reasoning
Our candidate kernel was correct but slower than the native counterpart because each thread computed an entire dot product sequentially. We can improve performance by parallelizing the dot product itself. In the new design, we assign one CUDA block per output element, i.e. per (batch, hidden) pair, so that all threads in the block cooperate on the dot product over the concatenated vector of length K = input_size + hidden_size. Each thread accumulates a partial sum over a strided subset of the K terms (k = threadIdx.x, threadIdx.x + blockDim.x, ...), the partial sums are combined by an in-block tree reduction in shared memory, and thread 0 then adds the bias and applies the tanh activation. With 256 threads per block, each thread handles roughly K/256 terms instead of all K, which should expose more parallelism to the GPU and reduce the latency of each dot product.
Before ending the reasoning, we ask: "Have these improvements genuinely helped?" We expect so: by spreading each dot product over 256 threads and reducing the sequential work per thread, we are likely to see better performance than the earlier kernel, which unrolled the loops inside a single thread. The changes are genuine micro-optimizations; the full computation is still performed, without any "cheating". Below are the updated Python and CUDA files.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>  // declares c10::cuda::getCurrentCUDAStream
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
// Optimized CUDA kernel for the fused RNN cell using parallel reduction.
// For each batch element i and hidden unit j, the kernel computes:
//   new_hidden[i, j] = tanh( b[j] + sum_{k=0}^{K-1} (combined[i, k] * W[j, k]) )
// where K = input_size + hidden_size and combined[i, k] is defined as:
//   if (k < input_size) then x[i, k]
//   else hidden_prev[i, k - input_size]
// Each CUDA block is responsible for one output element (i.e. one (batch, hidden) pair) and uses
// all its threads to compute the dot product over k in parallel.
__global__ void rnn_cell_kernel_optimized(const float *__restrict__ x,
                                          const float *__restrict__ hidden_prev,
                                          const float *__restrict__ W,
                                          const float *__restrict__ b,
                                          float *new_hidden,
                                          int batch,
                                          int input_size,
                                          int hidden_size) {
    int K = input_size + hidden_size;

    // Each block corresponds to one output element: use blockIdx.x to index (batch, hidden).
    int blockId = blockIdx.x;
    int i = blockId / hidden_size;  // batch index
    int j = blockId % hidden_size;  // hidden unit index

    // One float per thread; the size is supplied at launch time.
    extern __shared__ float sdata[];
    float sum = 0.0f;

    // Each thread computes a partial sum over indices k = threadIdx.x, threadIdx.x + blockDim.x, ... < K.
    for (int k = threadIdx.x; k < K; k += blockDim.x) {
        float combined_val;
        if (k < input_size) {
            combined_val = x[i * input_size + k];
        } else {
            combined_val = hidden_prev[i * hidden_size + (k - input_size)];
        }
        // Weight for neuron j is stored in row-major order: W[j*K + k]
        sum += combined_val * W[j * K + k];
    }

    // Store each thread's partial sum into shared memory.
    sdata[threadIdx.x] = sum;
    __syncthreads();

    // Parallel tree reduction in shared memory.
    // This loop assumes blockDim.x is a power of two (the host launches 256 threads).
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            sdata[threadIdx.x] += sdata[threadIdx.x + s];
        }
        __syncthreads();
    }

    // Thread 0 adds the bias and applies tanh activation.
    if (threadIdx.x == 0) {
        float dot = b[j] + sdata[0];
        new_hidden[i * hidden_size + j] = tanhf(dot);
    }
}

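torch::Tensor forward(torch::Tensor x, torch::Tensor hidden_prev, torch::Tensor W, torch::Tensor b, int input_size, int hidden_size) {
    // x: [batch, input_size]
    // hidden_prev: [batch, hidden_size]
    // W: [hidden_size, input_size + hidden_size]
    // b: [hidden_size]
    // Output: new_hidden: [batch, hidden_size]

    // The kernel indexes raw pointers, so reject inputs it cannot safely read.
    TORCH_CHECK(x.is_cuda() && hidden_prev.is_cuda() && W.is_cuda() && b.is_cuda(),
                "all inputs must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && hidden_prev.is_contiguous() && W.is_contiguous() && b.is_contiguous(),
                "all inputs must be contiguous");
    TORCH_CHECK(x.size(1) == input_size && hidden_prev.size(1) == hidden_size,
                "x / hidden_prev shapes do not match input_size / hidden_size");
    TORCH_CHECK(W.size(0) == hidden_size && W.size(1) == input_size + hidden_size && b.size(0) == hidden_size,
                "W must be [hidden_size, input_size + hidden_size] and b must be [hidden_size]");

    c10::cuda::CUDAGuard device_guard(x.device());
    int batch = static_cast<int>(x.size(0));
    auto new_hidden = torch::empty({batch, hidden_size}, x.options());

    // Launch one block per output element: total blocks = batch * hidden_size.
    int total_blocks = batch * hidden_size;
    int threads = 256;  // must stay a power of two for the in-kernel reduction
    size_t shared_mem = threads * sizeof(float);  // one partial sum per thread

    rnn_cell_kernel_optimized<<<total_blocks, threads, shared_mem, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        hidden_prev.data_ptr<float>(),
        W.data_ptr<float>(),
        b.data_ptr<float>(),
        new_hidden.data_ptr<float>(),
        batch,
        input_size,
        hidden_size
    );
    return new_hidden;
}
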
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "Optimized fused RNN cell: computes tanh(W*[x; h] + b) using parallel reduction");
}
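
The two-phase scheme used by the kernel (strided partial sums, then a halving reduction) can be emulated on the host to check the indexing. The sketch below is an illustration only and is not part of the extension: `block_dot` and `sequential_dot` are hypothetical helper names, `block_dim` plays the role of blockDim.x, and the loops over `t` stand in for threads that would run concurrently on the GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation of one CUDA block computing a single dot product.
// block_dim must be a power of two, because the halving loop pairs
// slot t with slot t + s on every pass.
float block_dot(const std::vector<float>& a, const std::vector<float>& w, unsigned block_dim) {
    assert(a.size() == w.size());
    assert(block_dim > 0 && (block_dim & (block_dim - 1)) == 0);
    std::vector<float> sdata(block_dim, 0.0f);  // stands in for shared memory

    // Phase 1: "thread" t accumulates terms t, t + block_dim, t + 2 * block_dim, ...
    for (unsigned t = 0; t < block_dim; ++t) {
        float sum = 0.0f;
        for (std::size_t k = t; k < a.size(); k += block_dim) sum += a[k] * w[k];
        sdata[t] = sum;
    }

    // Phase 2: tree reduction; each pass halves the number of active slots.
    for (unsigned s = block_dim / 2; s > 0; s >>= 1) {
        for (unsigned t = 0; t < s; ++t) sdata[t] += sdata[t + s];
    }
    return sdata[0];
}

// Plain sequential dot product, used only as a reference.
float sequential_dot(const std::vector<float>& a, const std::vector<float>& w) {
    assert(a.size() == w.size());
    float sum = 0.0f;
    for (std::size_t k = 0; k < a.size(); ++k) sum += a[k] * w[k];
    return sum;
}
```

With `block_dim = 256` the reduction takes 8 halving passes. Because the parallel and sequential versions add the terms in a different order, their results can differ in the last bits for general float inputs, so comparisons against a reference should use a tolerance.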