Authors:
(1) Albert Gu, Machine Learning Department, Carnegie Mellon University, equal contribution (agu@cs.cmu.edu);
(2) Tri Dao, Department of Computer Science, Princeton University, equal contribution (tri@tridao.me).
3 Selective State Space Models and 3.1 Motivation: Selection as a Means of Compression
3.2 Improving SSMs with Selection
3.3 Efficient Implementation of Selective SSMs
3.4 A Simplified SSM Architecture
3.5 Properties of Selection Mechanisms
4 Empirical Evaluation and 4.1 Synthetic Tasks
4.4 Audio Modeling and Generation
4.5 Speed and Memory Benchmarks
6 Conclusion, Acknowledgments and References
A Discussion: Selection Mechanism
B Related Work and B.1 S4 Variants and Derivatives
B.4 Linear Attention and B.5 Long Context Models
D Hardware-aware Algorithm For Selective SSMs
E Experimental Details and Additional Results and E.1 Synthetic Tasks
All models were trained on the Pile.
Model Sizes. Table 12 specifies the model sizes we use for scaling laws. This is taken directly from the GPT3 specifications (Brown et al. 2020), with very minor modifications. First, we changed the batch size of the 1.3B model from 1M tokens to 0.5M tokens, since we did not use enough parallelization to require the larger batch size. Second, we changed the number of training steps and total tokens to roughly match Chinchilla scaling laws (Hoffmann et al. 2022), which specify that training tokens should increase proportionally to model size.
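As a rough illustration of that proportionality, the sketch below computes token budgets from the commonly cited Chinchilla heuristic of roughly 20 training tokens per parameter. Both the constant and the model sizes listed here are illustrative assumptions, not the exact values of Table 12 (which is not reproduced in this excerpt).

```python
# Rough sketch of Chinchilla-proportional token budgets. The ~20 tokens-per-parameter
# constant is the commonly cited heuristic from Hoffmann et al. (2022); the actual
# budgets used in the paper come from Table 12 and may differ.

def chinchilla_tokens(n_params: float, tokens_per_param: float = 20.0) -> float:
    """Approximate compute-optimal number of training tokens for a given parameter count."""
    return tokens_per_param * n_params

# Illustrative GPT3-family model sizes (not necessarily the full set used in the paper).
for n_params in (125e6, 350e6, 760e6, 1.3e9):
    print(f"{n_params / 1e9:.3f}B params -> ~{chinchilla_tokens(n_params) / 1e9:.0f}B tokens")
```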
Training Recipes. All models used the AdamW optimizer with
• gradient clip value 1.0
• weight decay 0.1
• no dropout
• linear learning rate warmup with cosine decay
By default, the peak learning rate follows the GPT3 specification.
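A minimal PyTorch sketch of this base recipe follows; the model, peak learning rate, and schedule lengths are placeholders rather than the paper's exact values.

```python
# Minimal sketch of the base training recipe above: AdamW, weight decay 0.1,
# gradient clipping at 1.0, and linear warmup followed by cosine decay.
# The model, peak LR, and step counts are placeholders, not the paper's exact settings.
import math
import torch

model = torch.nn.Linear(512, 512)  # stand-in for the actual language model (no dropout anywhere)
# A peak LR of 6e-4 is assumed here for illustration; the paper uses the GPT3 specification per model size.
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)

warmup_steps, total_steps = 1000, 100_000  # illustrative schedule lengths

def lr_lambda(step: int) -> float:
    # Linear warmup, then cosine decay, as described above.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def training_step(batch_x, batch_y):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(batch_x), batch_y)  # placeholder loss; a real run uses token cross-entropy
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)       # gradient clip value 1.0
    optimizer.step()
    scheduler.step()
    return loss.item()
```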
We give several models an “improved recipe”, inspired by changes adopted by popular large language models such as PaLM (Chowdhery et al. 2023) and LLaMa (Touvron et al. 2023). These include:
• linear learning rate warmup with cosine decay to 1e-5, with a peak value of 5× the GPT3 value
• no linear bias terms
• RMSNorm instead of LayerNorm
• AdamW hyperparameter β = (.9, .95) (the GPT3 value) instead of the PyTorch default of β = (.9, .999)
Architecture and Training Details. Our models are:
• Transformer: The standard Transformer based on GPT3 (Table 12).
• Transformer++: A Transformer with an improved architecture, namely rotary positional encodings (Su et al. 2021) and a SwiGLU MLP (Shazeer 2020), together with the improved training recipe above (a SwiGLU sketch follows this list).
• Hyena: Interleaving a Hyena block (the H3 block with S4 replaced by a global convolution parameterized by an MLP) with standard MLP blocks. The MLP blocks have expansion factor 2 instead of 4 and the number of layers is correspondingly increased by 1.5× to preserve parameter count.
• H3++: The H3 architecture with a few modifications, including (i) using the same “thin” Hyena dimensions above (ii) the improved training recipe above (iii) a linear attention head dimension of 8.
• RWKV: The default RWKV model from B. Peng et al. (2023), including its modified MLP block. We also used as much of its specified training recipe as possible, such as increasing the learning rates by 2× or 3× on certain parameters.
• RetNet: The default RetNet model from Y. Sun et al. (2023). We also gave it the improved training recipe above.
• Mamba: The standard Mamba architecture, with the improved training recipe.
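For reference, the sketch below spells out the SwiGLU MLP block mentioned in the Transformer++ and Mamba entries above. It is a generic rendering of Shazeer (2020), not the authors' exact module; the expansion factor and dimensions are illustrative (recall that the Hyena/H3++ variants use expansion factor 2 instead of 4).

```python
# Generic SwiGLU MLP block (Shazeer 2020); dimensions and expansion factor are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, d_model: int, expand: int = 4):
        super().__init__()
        d_inner = expand * d_model
        self.w_gate = nn.Linear(d_model, d_inner, bias=False)  # gate branch
        self.w_up = nn.Linear(d_model, d_inner, bias=False)    # value branch
        self.w_down = nn.Linear(d_inner, d_model, bias=False)  # projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(gate) * value, then project back down.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

x = torch.randn(2, 16, 512)  # (batch, sequence, d_model)
y = SwiGLU(512)(x)           # output has the same shape as x
```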
We perform additional ablations on the architecture using the same protocol as the 2k context length scaling laws in Figure 4 (Left).
Mamba Architecture: Interleaving Blocks. We test the effect of different architectural blocks combined with the Mamba block. We focus on the viewpoint that the Mamba block is simply the standard SwiGLU block with an extra conv → SSM path added. This leads to two natural ablations (sketched in code after this list):
• What if the Mamba block is interleaved with a standard MLP block, instead of stacked homogeneously? This can also be interpreted as taking Mamba and removing half of the SSMs.
• What if the Mamba block is interleaved with MHA (multi-head attention) blocks? This can also be interpreted as taking a Transformer with SwiGLU MLPs (i.e. what we call Transformer++) and simply adding SSMs to the MLP blocks.
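The sketch below makes these two ablations concrete: a homogeneous Mamba stack versus stacks that alternate Mamba with SwiGLU MLP or multi-head attention blocks. The `mamba_layer` factory and the pre-norm residual wrapper are hypothetical placeholders standing in for the authors' actual block, the attention wrapper omits causal masking for brevity, and `SwiGLU` is the module sketched after the architecture list above.

```python
# Hypothetical sketch of the interleaving ablations. `mamba_layer` is a placeholder
# factory for the actual Mamba block; SwiGLU is the module sketched earlier.
import torch.nn as nn

class Residual(nn.Module):
    """Pre-norm residual wrapper around a sequence mixer."""
    def __init__(self, d_model: int, mixer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mixer = mixer

    def forward(self, x):
        return x + self.mixer(self.norm(x))

class SelfAttention(nn.Module):
    """Thin wrapper so nn.MultiheadAttention acts as a sequence mixer (causal mask omitted for brevity)."""
    def __init__(self, d_model: int, n_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

def build_stack(d_model: int, n_layers: int, mamba_layer, variant: str = "mamba") -> nn.Sequential:
    """variant: 'mamba' (homogeneous), 'mamba-mlp', or 'mamba-mha' (alternating blocks)."""
    layers = []
    for i in range(n_layers):
        if variant == "mamba" or i % 2 == 0:
            layers.append(Residual(d_model, mamba_layer(d_model)))    # Mamba block at every (or every other) layer
        elif variant == "mamba-mlp":
            layers.append(Residual(d_model, SwiGLU(d_model)))         # replace half the SSMs with MLP blocks
        else:
            layers.append(Residual(d_model, SelfAttention(d_model)))  # 'mamba-mha': interleave attention with SSMs
    return nn.Sequential(*layers)
```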
Figure 9 (Right) shows these variants compared to the original (homogenous) Mamba architecture. Interestingly, neither change matters too much. The Mamba-MLP architecture is only slightly worse, and still better than all models except Transformer++. The Mamba-MHA architecture is only slightly better, which is somewhat surprising in light of the fact that many recent works have found that combining (LTI) SSMs with Attention can lead to substantial improvements (Dao, Fu, Saab, et al. 2023; Fathi et al. 2023; Fathullah et al. 2023; Saon, Gupta, and Cui 2023; Zuo et al. 2022).
H3 Architecture: Training Recipes. Next we ablate differences between the Hyena and H3++ models, our weakest and strongest models outside of Transformer++ and Mamba, particularly to isolate the effect of training recipes.
• Hyena: The Hyena block with its original architecture and GPT3 training recipe (same as Figure 4).
• Hyena+: The same architecture but with the improved training recipe described above.
• H3+: The same architecture as Hyena+, but with the Hyena convolution kernel swapped out for an S4D convolution kernel.
• H3++: The same as H3+, but with a linear attention head dimension of 8. This increases computation inside the SSM recurrence but does not increase parameters.
Our general convention is that “Model+” represents the base model with the improved training recipe, and “Model++” also allows for architectural changes.
Figure 9 (Right) shows that
• A large improvement is achieved by the improved training recipe, which was used for many of the models in the main Figure 4 (RetNet, H3++, Transformer++, Mamba).
• The choice of the inner LTI SSM does not matter (e.g. Hyena vs. S4), consistent with findings throughout this paper.
• The head dimension expansion improves performance, consistent with one of our main themes that expanded state dimension improves performance for SSMs (Section 3).
Figure 9: (Scaling laws: extra ablations.)
This pretraining procedure is the same as the scaling law protocol, but extended to 300B tokens. For the 1.3B model, we use a batch size of 1M tokens to be consistent with the GPT3 specifications. We report the perplexity on the Pile validation set, and for this metric only compare to models trained on the same dataset and with the same tokenizer, in particular Pythia and RWKV.
For downstream evaluation, we use the LM evaluation harness from EleutherAI (L. Gao, Tow, et al. 2021), as done by most work in this area. We evaluate on the following tasks/datasets that measure common sense reasoning:
• LAMBADA (Paperno et al. 2016).
• HellaSwag (Zellers et al. 2019).
• PIQA (Bisk et al. 2020).
• ARC-challenge (P. Clark et al. 2018).
• ARC-easy: the easy partition of ARC (P. Clark et al. 2018).
• WinoGrande (Sakaguchi et al. 2021).
We report accuracy for LAMBADA, WinoGrande, PIQA, and ARC-easy, and accuracy normalized by sequence length for HellaSwag and ARC-challenge (since normalized accuracy is higher for almost all models on these tasks).
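For readers who want to run this style of evaluation, the sketch below calls the EleutherAI harness on the listed tasks. It assumes a recent `lm-eval` (lm-evaluation-harness) release that exposes `simple_evaluate` and these task names; the checkpoint string is a placeholder, and the harness version and settings used in the paper may differ.

```python
# Hedged sketch of evaluating a model with the EleutherAI LM evaluation harness.
# Assumes a recent `lm-eval` release exposing `simple_evaluate`; the paper's exact
# harness version, flags, and models may differ.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",                                      # HuggingFace-backed model loader
    model_args="pretrained=EleutherAI/pythia-160m",  # placeholder checkpoint, not one of the paper's models
    tasks=[
        "lambada_openai", "hellaswag", "piqa",
        "arc_easy", "arc_challenge", "winogrande",
    ],
    batch_size=8,
)

for task, metrics in results["results"].items():
    print(task, metrics)  # reports accuracy; normalized accuracy where the harness defines it
```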
This paper is available on arXiv under the CC BY 4.0 DEED license.