3.35. Sequential LSTM Calculator

Calculates final output predictions from sequential input data by initializing random hidden and cell states, processing the sequence through multiple recurrent layers, and applying a linear transformation to the output of the last time step.
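
For reference, a minimal PyTorch sketch of the model described above might look like the following; the class name LSTMCalculator and the constructor arguments are illustrative assumptions rather than the benchmark's actual definition.

import torch
import torch.nn as nn

class LSTMCalculator(nn.Module):
    # Hypothetical sketch of the model described above.
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):  # x: [batch_size, seq_length, input_size]
        # Random initial hidden and cell states, as described above.
        h0 = torch.randn(self.lstm.num_layers, x.size(0), self.lstm.hidden_size, device=x.device)
        c0 = torch.randn_like(h0)
        out, _ = self.lstm(x, (h0, c0))   # out: [batch_size, seq_length, hidden_size]
        return self.fc(out[:, -1, :])     # linear projection of the last time step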

Progress: Step 3 of 3
Reward: 1.02x

Step 3 Details

Reasoning

We can further optimize our fused fully connected projection kernel by reducing loop overhead and memory latency in the dot-product computation. In the previous version, each thread iterated over hidden_size elements one at a time. Since hidden_size is often a multiple of 4, we can load four floats at a time as float4 and manually unroll the loop, cutting the number of iterations and exposing more independent loads per iteration. In addition, using the __ldg intrinsic explicitly when reading from global memory routes the loads through the read-only data cache, further improving throughput. Together these changes reduce the per-thread overhead and latency of the dot-product computation, which should yield a small but genuine speedup. Have my improvements genuinely helped? Yes: vectorized loads with manual unrolling and __ldg caching should improve performance without changing the kernel's overall algorithm.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream()
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK()
#include <cuda.h>
#include <cuda_runtime.h>

// Fused kernel for fully connected projection on the last time step of LSTM output.
// Each thread computes one output element (dot product between last time-step vector and weight row, plus bias).
// - lstm_out: [batch_size, seq_length, hidden_size]
// - weight: [output_size, hidden_size]
// - bias: [output_size]
// - output: [batch_size, output_size]
__global__ void fc_forward_kernel(const float* __restrict__ lstm_out,
                                  const float* __restrict__ weight,
                                  const float* __restrict__ bias,
                                  float* __restrict__ output,
                                  int batch_size,
                                  int seq_length,
                                  int hidden_size,
                                  int output_size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch_size * output_size;
    if (idx < total) {
        int i = idx / output_size;  // batch index
        int j = idx % output_size;  // output feature index

        // Compute the offset for the last time step of sample i.
        int last_offset = i * seq_length * hidden_size + (seq_length - 1) * hidden_size;
        float sum = 0.0f;

        // If hidden_size is divisible by 4, use vectorized loads with loop unrolling.
        if (hidden_size % 4 == 0) {
            int vec_count = hidden_size / 4;
            const float4* last_vec = reinterpret_cast<const float4*>(&lstm_out[last_offset]);
            const float4* weight_vec = reinterpret_cast<const float4*>(&weight[j * hidden_size]);
            int k = 0;
            // Unroll loop in groups of 4 iterations.
            for (; k + 3 < vec_count; k += 4) {
                float4 lv0 = __ldg(&last_vec[k]);
                float4 wv0 = __ldg(&weight_vec[k]);
                float4 lv1 = __ldg(&last_vec[k + 1]);
                float4 wv1 = __ldg(&weight_vec[k + 1]);
                float4 lv2 = __ldg(&last_vec[k + 2]);
                float4 wv2 = __ldg(&weight_vec[k + 2]);
                float4 lv3 = __ldg(&last_vec[k + 3]);
                float4 wv3 = __ldg(&weight_vec[k + 3]);
                sum += lv0.x * wv0.x + lv0.y * wv0.y + lv0.z * wv0.z + lv0.w * wv0.w;
                sum += lv1.x * wv1.x + lv1.y * wv1.y + lv1.z * wv1.z + lv1.w * wv1.w;
                sum += lv2.x * wv2.x + lv2.y * wv2.y + lv2.z * wv2.z + lv2.w * wv2.w;
                sum += lv3.x * wv3.x + lv3.y * wv3.y + lv3.z * wv3.z + lv3.w * wv3.w;
            }
            // Process remaining vectorized elements.
            for (; k < vec_count; ++k) {
                float4 lv = __ldg(&last_vec[k]);
                float4 wv = __ldg(&weight_vec[k]);
                sum += lv.x * wv.x + lv.y * wv.y + lv.z * wv.z + lv.w * wv.w;
            }
        } else {
            // Fallback to scalar computation.
            for (int k = 0; k < hidden_size; k++) {
                sum += __ldg(&lstm_out[last_offset + k]) * __ldg(&weight[j * hidden_size + k]);
            }
        }
        output[idx] = sum + __ldg(&bias[j]);
    }
}

torch::Tensor fc_forward(torch::Tensor lstm_out, torch::Tensor fc_weight, torch::Tensor fc_bias) {
    // Use c10::cuda::CUDAGuard to ensure the current device is set.
    c10::cuda::CUDAGuard device_guard(lstm_out.device());
    
    TORCH_CHECK(lstm_out.is_cuda() && fc_weight.is_cuda() && fc_bias.is_cuda(), "all inputs must be CUDA tensors");
    TORCH_CHECK(lstm_out.scalar_type() == torch::kFloat32, "lstm_out must be float32 (kernel assumes float)");
    TORCH_CHECK(lstm_out.is_contiguous() && fc_weight.is_contiguous() && fc_bias.is_contiguous(), "all inputs must be contiguous");
    TORCH_CHECK(lstm_out.dim() == 3, "lstm_out should have shape [batch_size, seq_length, hidden_size]");
    TORCH_CHECK(fc_weight.dim() == 2, "fc_weight should have shape [output_size, hidden_size]");
    TORCH_CHECK(fc_bias.dim() == 1, "fc_bias should have shape [output_size]");
    
    int batch_size = lstm_out.size(0);
    int seq_length = lstm_out.size(1);
    int hidden_size = lstm_out.size(2);
    int output_size = fc_weight.size(0);
    
    auto output = torch::empty({batch_size, output_size}, lstm_out.options());
    
    int total = batch_size * output_size;
    int threads = 256;
    int blocks = (total + threads - 1) / threads;
    
    fc_forward_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        lstm_out.data_ptr<float>(),
        fc_weight.data_ptr<float>(),
        fc_bias.data_ptr<float>(),
        output.data_ptr<float>(),
        batch_size,
        seq_length,
        hidden_size,
        output_size
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();  // surface any kernel launch errors immediately
    
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fc_forward", &fc_forward, "Fused extraction of last time-step and fully connected projection with vectorized unrolling");
}
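
Assuming the source above is saved as a .cu file, a rough Python usage sketch might look like the following; the extension name fused_fc, the file name fc_forward.cu, and the tensor sizes are assumptions for illustration, and inputs must be contiguous float32 CUDA tensors.

import torch
from torch.utils.cpp_extension import load

# Build the extension from the source above (file name is an assumption).
fused_fc = load(name="fused_fc", sources=["fc_forward.cu"], verbose=True)

batch_size, seq_length, hidden_size, output_size = 8, 32, 256, 10
lstm_out = torch.randn(batch_size, seq_length, hidden_size, device="cuda")
fc_weight = torch.randn(output_size, hidden_size, device="cuda")
fc_bias = torch.randn(output_size, device="cuda")

out = fused_fc.fc_forward(lstm_out, fc_weight, fc_bias)

# Compare against the equivalent PyTorch computation on the last time step.
ref = torch.nn.functional.linear(lstm_out[:, -1, :], fc_weight, fc_bias)
print(torch.allclose(out, ref, atol=1e-4))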