Authors:
(1) Soham De, Google DeepMind (equal contribution);
(2) Samuel L. Smith, Google DeepMind (equal contribution);
(3) Anushan Fernando, Google DeepMind (equal contribution);
(4) Aleksandar Botev, Google DeepMind (equal contribution);
(5) George Cristian-Muraru, Google DeepMind (equal contribution);
(6) Albert Gu, Work done while at Google DeepMind;
(7) Ruba Haroun, Google DeepMind;
(8) Leonard Berrada, Google DeepMind;
(9) Yutian Chen, Google DeepMind;
(10) Srivatsan Srinivasan, Google DeepMind;
(11) Guillaume Desjardins, Google DeepMind;
(12) Arnaud Doucet, Google DeepMind;
(13) David Budden, Google DeepMind;
(14) Yee Whye Teh, Google DeepMind;
(15) Razvan Pascanu, Google DeepMind;
(16) Nando De Freitas, Google DeepMind;
(17) Caglar Gulcehre, Google DeepMind.
3 Recurrent Models Scale as Efficiently as Transformers
3.2. Evaluation on downstream tasks
4.2. Efficient linear recurrences on device
4.3. Training speed on longer sequences
5.1. A simple model of the decode step
6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts
6.2. Copy and retrieval capabilities
8. Conclusion, Acknowledgements, and References
B. Complex-Gated Linear Recurrent Unit (CG-LRU)
C. Model Scale Hyper-Parameters
D. Efficient Linear Recurrences on Device
E. The Local Attention Window Size of Griffin
G. Improving Next Token Prediction with Longer Contexts: Additional Results
H. Additional Details of the Copy and Retrieval Tasks
Current deep learning accelerators are optimized for classical architectures, which are composed largely of matrix multiplications and convolutions. These operations have a high FLOPs-to-byte ratio, which has motivated the development of specialized hardware units such as the TensorCores in Nvidia GPUs (Markidis et al., 2018) and the MXUs in Google TPUs (Norrie et al., 2021; Jouppi et al., 2021, 2023). Classical RNNs also benefit from these units because their recurrence matrices are dense. In contrast, our proposed RG-LRU layer, like other diagonal RNN models, has a low FLOPs-to-byte ratio. This fundamental difference poses a computational challenge, as existing accelerators are not optimized for such workloads. Since we run all our experiments on TPU-v3, we focus on developing an efficient implementation tailored to this device[3].
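The gap in FLOPs-to-byte ratio can be made concrete with a rough back-of-the-envelope count. The sketch below is illustrative only: the function names, the tensor sizes, the 2-byte element width (as for bfloat16), and the simple counting model (every operand read once from memory, every result written once, no tiling or cache effects) are all assumptions made for this example, not figures from the paper.

```python
def matmul_intensity(n, d, bytes_per_elem=2):
    """FLOPs per byte for an (n x d) @ (d x d) matrix multiplication."""
    # One multiply and one add for each term of each inner product.
    flops = 2 * n * d * d
    # Read the input (n*d) and the weights (d*d), write the output (n*d).
    bytes_moved = (n * d + d * d + n * d) * bytes_per_elem
    return flops / bytes_moved


def diagonal_recurrence_intensity(bytes_per_elem=2):
    """FLOPs per byte for one element of h_t = a_t * h_{t-1} + b_t."""
    # One multiply and one add per element per time step.
    flops = 2
    # Read a_t and b_t, write h_t; h_{t-1} is assumed to stay on chip.
    bytes_moved = 3 * bytes_per_elem
    return flops / bytes_moved


# Hypothetical sizes chosen only for illustration.
dense = matmul_intensity(n=2048, d=1024)
diag = diagonal_recurrence_intensity()
print(f"dense matmul:        {dense:.1f} FLOPs/byte")
print(f"diagonal recurrence: {diag:.2f} FLOPs/byte")
```

Under these assumptions the dense matrix multiplication performs hundreds of FLOPs per byte moved, while the element-wise recurrence performs well under one, so its speed is governed by memory bandwidth rather than by the matrix units.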
A custom linear scan. To address this, we have written a custom Pallas kernel that computes eq. (4) using a linear scan. This allows us to minimize memory transfers by keeping the hidden state in VMEM at all times, and to perform memory transfers in larger chunks rather than one at a time. In practice, this translates to an almost 3x speedup over the native JAX implementation of the linear scan. Additionally, we observe 10-20% lower training times per step for the full Hawk model, relative to the same model using the native JAX implementation (see Appendix D.2 for more details).
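The structure of such a scan can be sketched in plain Python. This is not the Pallas kernel and makes no claim about its internals. It assumes that eq. (4) is an element-wise (diagonal) linear recurrence of the form h_t = a_t * h_{t-1} + b_t, where b_t stands for the gated input term. The function names, the `chunk_size` parameter, and the toy data are hypothetical. The chunked variant only mimics the two ideas named above: the hidden state lives in a local variable throughout, and the sequence is consumed in blocks rather than step by step.

```python
def linear_scan(a, b, h0):
    """Sequentially compute h_t = a_t * h_{t-1} + b_t, element-wise.

    a, b: lists of length T, each entry a list of D floats.
    h0:   list of D floats, the initial hidden state.
    Returns the list of all T hidden states.
    """
    h = list(h0)
    out = []
    for a_t, b_t in zip(a, b):
        h = [a_i * h_i + b_i for a_i, h_i, b_i in zip(a_t, h, b_t)]
        out.append(h)
    return out


def chunked_linear_scan(a, b, h0, chunk_size):
    """Same recurrence, reading the inputs in blocks of chunk_size steps."""
    h = list(h0)  # carried across chunks, standing in for on-chip state
    out = []
    for start in range(0, len(a), chunk_size):
        # Each slice stands in for one bulk transfer of inputs.
        a_chunk = a[start:start + chunk_size]
        b_chunk = b[start:start + chunk_size]
        chunk_out = linear_scan(a_chunk, b_chunk, h)
        out.extend(chunk_out)
        h = chunk_out[-1]  # last state of this chunk seeds the next one
    return out


# Toy example: T = 5 steps, D = 2 channels.
a = [[0.5, 0.9] for _ in range(5)]
b = [[1.0, 1.0] for _ in range(5)]
h0 = [0.0, 0.0]
full = linear_scan(a, b, h0)
chunked = chunked_linear_scan(a, b, h0, chunk_size=2)
```

Both functions apply the same operations in the same order, so they return identical results. Chunking changes only how the inputs are fetched, not the arithmetic.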
This paper is available on arXiv under the CC BY 4.0 DEED license.
[3] The conclusions drawn here do not necessarily apply to other accelerators.