
Efficient linear recurrences on device

by Gating Technology, January 14th, 2025

Too Long; Didn't Read

This research optimizes the RG-LRU layer with a custom Pallas kernel for TPU-v3, achieving a roughly 3x speedup over the native Jax linear scan and 10-20% faster training steps for the Hawk model.

Authors:

(1) Soham De, Google DeepMind, with equal contributions;

(2) Samuel L. Smith, Google DeepMind, with equal contributions;

(3) Anushan Fernando, Google DeepMind, with equal contributions;

(4) Aleksandar Botev, Google DeepMind, with equal contributions;

(5) George Cristian-Muraru, Google DeepMind, with equal contributions;

(6) Albert Gu, Work done while at Google DeepMind;

(7) Ruba Haroun, Google DeepMind;

(8) Leonard Berrada, Google DeepMind;

(9) Yutian Chen, Google DeepMind;

(10) Srivatsan Srinivasan, Google DeepMind;

(11) Guillaume Desjardins, Google DeepMind;

(12) Arnaud Doucet, Google DeepMind;

(13) David Budden, Google DeepMind;

(14) Yee Whye Teh, Google DeepMind;

(15) Razvan Pascanu, Google DeepMind;

(16) Nando De Freitas, Google DeepMind;

(17) Caglar Gulcehre, Google DeepMind.

1 Introduction

2 Model Architecture

3 Recurrent Models Scale as Efficiently as Transformers

3.1. Scaling curves

3.2. Evaluation on downstream tasks

4 Training Recurrent Models Efficiently on Device and 4.1. Model parallelism for large scale training

4.2. Efficient linear recurrences on device

4.3. Training speed on longer sequences

5. Inference Speed

5.1. A simple model of the decode step

5.2. Results

6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts

6.2. Copy and retrieval capabilities

7. Related Works

8. Conclusion, Acknowledgements, and References


A. RG-LRU Recurrence Gate

B. Complex-Gated Linear Recurrent Unit (CG-LRU)

C. Model Scale Hyper-Parameters

D. Efficient Linear Recurrences on Device

E. The Local Attention Window Size of Griffin

F. Inference Speeds

G. Improving Next Token Prediction with Longer Contexts: Additional Results

H. Additional Details of the Copy and Retrieval Tasks

4.2. Efficient linear recurrences on device

Current deep learning accelerators are optimized for classical architectures that are composed largely of matrix multiplications and convolutions. These operations have a high FLOPs-to-byte ratio, which has motivated the development of specialized hardware units such as Nvidia GPUs' TensorCores (Markidis et al., 2018) and Google TPUs' MXUs (Norrie et al., 2021; Jouppi et al., 2021, 2023). Classical RNNs also benefit from these units, since their recurrence matrices are dense. In contrast, our proposed RG-LRU layer, like other diagonal RNN models, has a low FLOPs-to-byte ratio. This fundamental difference poses a computational challenge, as existing accelerators are not optimized for such memory-bound workloads. Since we run all our experiments on TPU-v3, we focus on developing an efficient implementation tailored to this device[3].
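To make the contrast concrete, here is a back-of-the-envelope arithmetic-intensity comparison. The numbers below are our own illustration rather than figures from the paper, and they assume bfloat16 values with every operand read from memory and every result written back exactly once:

```python
# Rough FLOPs-to-byte estimates for a batched dense matmul vs. one step of a
# diagonal (element-wise) recurrence. Assumes bfloat16 (2 bytes per value) and
# that every operand/result crosses memory exactly once -- illustrative only.

def matmul_intensity(batch: int, d: int) -> float:
    flops = 2 * batch * d * d                  # (batch, d) @ (d, d)
    bytes_moved = 2 * (d * d + 2 * batch * d)  # weights + inputs + outputs
    return flops / bytes_moved

def diagonal_step_intensity(batch: int, d: int) -> float:
    flops = 2 * batch * d                      # one multiply and one add per channel
    bytes_moved = 2 * (4 * batch * d)          # decay, input, old state, new state
    return flops / bytes_moved

batch, d = 1024, 1024
print(f"dense matmul:  {matmul_intensity(batch, d):6.1f} FLOPs/byte")     # ~341
print(f"diagonal step: {diagonal_step_intensity(batch, d):6.2f} FLOPs/byte")  # 0.25
```

The matmul's arithmetic intensity grows with the batch size and model width, whereas the diagonal recurrence stays constant at well under one FLOP per byte, which is why such a layer ends up bound by memory bandwidth rather than by the MXUs.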



Figure 3 | Training durations per step, computed relative to our MQA baseline at 2K sequence length, as we vary the model size and sequence length for Griffin and MQA. Note that as we increase the sequence length we lower the batch size proportionally, so that the total number of tokens per batch stays fixed.


A custom linear scan. To address this, we have written a custom Pallas kernel that computes eq. (4) using a linear scan. This lets us minimize memory transfers by keeping the hidden state in VMEM at all times and by performing memory transfers in larger chunks rather than one element at a time. In practice, this translates to an almost 3x speed-up over the native Jax implementation of the linear scan. Additionally, we observe 10-20% lower training time per step for the full Hawk model, relative to the same model using the native Jax implementation (see Appendix D.2 for more details).
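The Pallas kernel itself is not reproduced here, but the baseline it is compared against is the native Jax linear scan. Below is a minimal sketch of that baseline using jax.lax.scan; the function and variable names are ours, and the sqrt(1 - a^2) input scaling is one plausible reading of the RG-LRU form of eq. (4), so treat it as illustrative rather than the exact layer:

```python
import jax
import jax.numpy as jnp

def rg_lru_scan(x, a, h0):
    """Native Jax linear scan over time -- the baseline the Pallas kernel beats.

    Shapes (illustrative):
      x:  (T, d) gated inputs
      a:  (T, d) per-step recurrent decay in (0, 1)
      h0: (d,)   initial hidden state
    """
    def step(h_prev, inputs):
        a_t, x_t = inputs
        # Element-wise recurrence in the spirit of eq. (4); the sqrt(1 - a^2)
        # input scaling mirrors the RG-LRU, with the gating simplified.
        h_t = a_t * h_prev + jnp.sqrt(1.0 - a_t**2) * x_t
        return h_t, h_t  # carry the new state and also emit it

    _, hs = jax.lax.scan(step, h0, (a, x))
    return hs  # (T, d) per-step hidden states

# Usage example with random data.
key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
T, d = 16, 8
x = jax.random.normal(key_x, (T, d))
a = jax.nn.sigmoid(jax.random.normal(key_a, (T, d)))  # decay in (0, 1)
hs = jax.jit(rg_lru_scan)(x, a, jnp.zeros(d))
print(hs.shape)  # (T, d)
```

Each scan step touches only O(d) values, so an implementation like this is dominated by memory transfers; the custom kernel's gains come from keeping the state resident in VMEM and moving data to and from HBM in larger chunks, as described above.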



This paper is available on arXiv under a CC BY 4.0 DEED license.


[3] The conclusions drawn here do not necessarily apply to other accelerators.