
How Mamba and Hyena Are Changing the Way AI Learns and Remembers


Authors:

(1) Albert Gu, Machine Learning Department, Carnegie Mellon University (equal contribution);

(2) Tri Dao, Department of Computer Science, Princeton University (equal contribution).

Table of Links

Abstract and 1 Introduction

2 State Space Models

3 Selective State Space Models and 3.1 Motivation: Selection as a Means of Compression

3.2 Improving SSMs with Selection

3.3 Efficient Implementation of Selective SSMs

3.4 A Simplified SSM Architecture

3.5 Properties of Selection Mechanisms

3.6 Additional Model Details

4 Empirical Evaluation and 4.1 Synthetic Tasks

4.2 Language Modeling

4.3 DNA Modeling

4.4 Audio Modeling and Generation

4.5 Speed and Memory Benchmarks

4.6 Model Ablations

5 Discussion

6 Conclusion and References


A Discussion: Selection Mechanism

B Related Work

C Mechanics of Selective SSMs

D Hardware-aware Algorithm For Selective SSMs

E Experimental Details and Additional Results

E Experimental Details and Additional Results

E.1 Synthetic Tasks

Selective Copying. We use sequences of length 4096 with a vocabulary of 16 possible tokens (including the white “noise” token from Figure 2), and require models to memorize 16 “data” tokens. We use 2-layer models with a model dimension of D = 64.


Models are trained for 400K steps at a constant learning rate of 0.0001 with a batch size of 64.
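To make the task setup concrete, here is a minimal sketch of how one training example might be generated. It assumes (our assumption, not stated above) that token 0 is the noise token, that the data tokens are drawn from the other 15 ids and placed at random distinct positions, and that the target is the data tokens in order of appearance. The function name is hypothetical.

```python
import random


def make_selective_copying_example(seq_len=4096, vocab_size=16, num_data=16,
                                   noise_token=0, rng=None):
    """Build one (input, target) pair for a selective-copying style task.

    Illustrative layout only: mostly noise tokens, with `num_data` data tokens
    at random distinct positions; the target lists the data tokens in order.
    """
    rng = rng or random.Random()
    # Data tokens come from every id except the noise token.
    data_vocab = [t for t in range(vocab_size) if t != noise_token]
    data = [rng.choice(data_vocab) for _ in range(num_data)]
    # Sorted distinct positions, so "order of appearance" equals the order of `data`.
    positions = sorted(rng.sample(range(seq_len), num_data))
    inputs = [noise_token] * seq_len
    for pos, tok in zip(positions, data):
        inputs[pos] = tok
    return inputs, data


x, y = make_selective_copying_example(rng=random.Random(0))
print(len(x), len(y))  # 4096 16
```

Because the positions change from example to example, a model cannot solve this with a fixed (time-invariant) filter; it has to decide, based on content, which tokens to keep.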


Induction Heads. Training consists of randomly generating data every step, with a batch size of 8. We choose an “epoch” size of 8192 steps, and track the accuracy on fixed validation sets (also randomly generated) of each target sequence length. For the MHA-Abs and Mamba models, results are reported after the 25th epoch (8192 × 25 = 204800 steps). For the MHA-RoPE and MHA-xPos models, results are reported after the 50th epoch (8192 × 50 = 409600 steps). For the LTI H3 and Hyena models, results are reported after the 10th epoch (81920 steps) because they had converged by then and failed to improve further.


Table 12: (Scaling Law Model Sizes.) Our model sizes and hyperparameters for scaling experiments. (Model dimension and number of heads apply only to Transformer models.)


We use the Adam optimizer with no weight decay. All models are trained at constant learning rates of 2e-4 and 1e-3, and the better result is reported for each model (2e-4 for all models except Mamba). The attention and Hyena models did not learn at LR 1e-3. H3 learned at both LRs, but interestingly generalized better to shorter sequences at the smaller LR of 2e-4. Mamba learned at both LRs, but extrapolated better at the larger LR of 1e-3.
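The data pipeline and evaluation schedule can be sketched as follows. The example layout is an illustrative assumption (a hypothetical special token 0 appears once mid-sequence followed by an "answer" token, then again as the final token, and the model must recall the answer), as are the defaults seq_len=256 and vocab_size=16. Only the fresh-data-every-step generation, the batch size of 8, and the 8192-step "epoch" come from the text above.

```python
import random

STEPS_PER_EPOCH = 8192  # the "epoch" size used for tracking validation accuracy


def make_induction_heads_example(seq_len=256, vocab_size=16, special=0, rng=None):
    """One induction-heads style example (hypothetical layout, for illustration)."""
    rng = rng or random.Random()
    ordinary = [t for t in range(vocab_size) if t != special]
    seq = [rng.choice(ordinary) for _ in range(seq_len)]
    # p + 1 <= seq_len - 2, so the answer never overwrites the final query slot.
    p = rng.randrange(0, seq_len - 2)
    answer = rng.choice(ordinary)
    seq[p], seq[p + 1] = special, answer
    seq[-1] = special  # query: which token followed the special token earlier?
    return seq, answer


def training_batches(batch_size=8, rng=None, **kwargs):
    """Fresh random data at every step, so there is no fixed training set."""
    rng = rng or random.Random()
    while True:
        yield [make_induction_heads_example(rng=rng, **kwargs) for _ in range(batch_size)]


def steps_after(epochs):
    return epochs * STEPS_PER_EPOCH


print(steps_after(25), steps_after(50), steps_after(10))  # 204800 409600 81920
```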

E.2 Language Modeling

E.2.1 Scaling Law Details


All models were trained on the Pile.


Model Sizes. Table 12 specifies the model sizes we use for scaling laws. This is taken directly from the GPT3 specifications (Brown et al. 2020), with very minor modifications. First, we changed the batch size of the 1.3B model from 1M tokens to 0.5M tokens, since we did not use enough parallelization to require the larger batch size. Second, we changed the number of training steps and total tokens to roughly match Chinchilla scaling laws (Hoffmann et al. 2022), which specify that training tokens should increase proportionally to model size.
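As a rough illustration of "training tokens proportional to model size", the sketch below assumes the commonly cited Chinchilla rule of thumb of about 20 training tokens per parameter; that ratio is our assumption, and the token counts actually used are those in Table 12. The number of optimizer steps then follows from the batch size in tokens, here treating the 0.5M-token batch of the 1.3B model as exactly 500,000 tokens.

```python
TOKENS_PER_PARAM = 20  # assumption: common reading of Hoffmann et al. (2022)


def training_budget(num_params, batch_tokens):
    """Return (total_tokens, optimizer_steps) for a Chinchilla-style budget."""
    total_tokens = TOKENS_PER_PARAM * num_params
    steps = -(-total_tokens // batch_tokens)  # ceiling division
    return total_tokens, steps


tokens, steps = training_budget(1_300_000_000, 500_000)
print(tokens, steps)  # 26000000000 52000
```

Halving the batch size at a fixed token budget doubles the number of steps, which is why the batch-size change for the 1.3B model is listed alongside the change to training steps.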


Training Recipes. All models used the AdamW optimizer with


• gradient clip value 1.0


• weight decay 0.1


• no dropout


• linear learning rate warmup with cosine decay


By default, the peak learning rate follows the GPT3 specification.
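A minimal sketch of the learning-rate shape described above (linear warmup, then cosine decay). The warmup length, total steps, and final floor min_lr are hypothetical parameters, since they are not given here; the example peak of 6e-4 is the GPT3 value for its 125M model.

```python
import math


def lr_at(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at `step` (0-indexed): linear warmup, then cosine decay.

    warmup_steps, total_steps and min_lr are hypothetical knobs; the text only
    fixes the shape of the schedule and that the peak follows GPT3.
    """
    if step < warmup_steps:
        # Ramp linearly so that the last warmup step reaches peak_lr.
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    # Half a cosine period: peak_lr at progress 0, min_lr at progress 1.
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


print(lr_at(99, 6e-4, warmup_steps=100, total_steps=1000))  # peak reached at end of warmup
```

In PyTorch, one way to wire this up (our assumption, not the authors' code) is torch.optim.AdamW(params, lr=peak_lr, weight_decay=0.1) together with torch.optim.lr_scheduler.LambdaLR(opt, lambda s: lr_at(s, peak_lr, W, T) / peak_lr), calling torch.nn.utils.clip_grad_norm_(params, 1.0) before each optimizer step.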


We give several models an “improved recipe”, inspired by changes adopted by popular large language models such as PaLM (Chowdhery et al. 2023) and LLaMa (Touvron et al. 2023). These include:



Architecture and Training Details. Our models are:


• Transformer: The standard Transformer based on GPT3 (Table 12).


• Transformer++: A Transformer with an improved architecture, namely rotary positional encodings (Su et al. 2021) and SwiGLU MLP (Shazeer 2020), and the improved training recipe above.


• Hyena: Interleaving a Hyena block (the H3 block with S4 replaced by a global convolution parameterized by an MLP) with standard MLP blocks. The MLP blocks have expansion factor 2 instead of 4 and the number of layers is correspondingly increased by 1.5× to preserve parameter count.


• H3++: The H3 architecture with a few modifications, including (i) using the same “thin” Hyena dimensions above, (ii) the improved training recipe above, and (iii) a linear attention head dimension of 8.


• RWKV: The default RWKV model from B. Peng et al. (2023), including its modified MLP block. We also used as much of its specified training recipe as possible, such as increasing the learning rates by 2× or 3× on certain parameters.


• RetNet: The default RetNet model from Y. Sun et al. (2023). We also gave it the improved training recipe above.


• Mamba: The standard Mamba architecture, with the improved training recipe.
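The 1.5× layer increase for Hyena can be checked with back-of-the-envelope parameter counts. This sketch counts only the large D×D projection matrices and assumes, for illustration, that a Hyena/H3 mixing block has about 4D² such parameters (like attention's Q, K, V, and output projections); biases, norms, and the small MLP that parameterizes the convolution filter are ignored.

```python
def mlp_params(d_model, expansion):
    # Up-projection D -> expansion*D plus down-projection back to D.
    return 2 * expansion * d_model * d_model


def mixer_params(d_model):
    # Assumption: four D x D projections (e.g. Q, K, V, output).
    return 4 * d_model * d_model


def layer_params(d_model, expansion):
    return mixer_params(d_model) + mlp_params(d_model, expansion)


D = 768
# Transformer layer: 4D^2 + 8D^2 = 12D^2; "thin" Hyena layer: 4D^2 + 4D^2 = 8D^2.
print(layer_params(D, 4) / layer_params(D, 2))  # 1.5
```

So with MLP expansion 2 instead of 4, each layer is 2/3 the size, and 1.5× as many layers restores the parameter count.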


E.2.2 Additional Scaling Law Ablations


We perform additional ablations on the architecture using the same protocol as the 2k context length scaling laws in Figure 4 (Left).


Mamba Architecture: Interleaving Blocks. We test the effect of different architectural blocks combined with the Mamba block. We focus on the viewpoint that the Mamba block is simply the standard SwiGLU block with an extra conv → SSM path added. This leads to two natural ablations:


• What if the Mamba block is interleaved with a standard MLP block, instead of stacked homogeneously? This can also be interpreted as taking Mamba and removing half of the SSMs.


• What if the Mamba block is interleaved with MHA (multi-head attention) blocks? This can also be interpreted as taking a Transformer with SwiGLU MLPs (i.e. what we call Transformer++) and simply adding SSMs to the MLP blocks.
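To make the "SwiGLU block plus a conv → SSM path" viewpoint concrete, here is a toy, pure-Python sketch of the data flow. It is illustrative only: weights are random, the expansion factor 2 and depthwise causal convolution width 4 are assumed values, and the SSM is replaced by a fixed per-channel decay h_t = a·h_{t-1} + u_t rather than the selective SSM of Section 3. With use_ssm=False the block reduces to the gated form W_out(W_x x ⊙ SiLU(W_z x)), i.e. a SwiGLU-style MLP applied position by position.

```python
import math
import random


def silu(v):
    # Numerically stable SiLU: v * sigmoid(v).
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, cols ** -0.5) for _ in range(cols)] for _ in range(rows)]


class ToyMambaBlock:
    """Gated MLP with an optional conv -> SiLU -> (toy) SSM path on the value branch."""

    def __init__(self, d_model, expand=2, conv_width=4, decay=0.9, use_ssm=True, seed=0):
        rng = random.Random(seed)
        d_inner = expand * d_model
        self.W_x = rand_matrix(d_inner, d_model, rng)    # value branch ("up" projection)
        self.W_z = rand_matrix(d_inner, d_model, rng)    # gate branch
        self.W_out = rand_matrix(d_model, d_inner, rng)  # down projection
        self.conv = [[rng.gauss(0.0, 0.5) for _ in range(conv_width)] for _ in range(d_inner)]
        self.decay = decay  # stand-in for the SSM dynamics (not selective)
        self.use_ssm = use_ssm

    def _causal_conv(self, seq):
        # Depthwise: channel c sees only its own current and past values; taps[k] weights step t-k.
        return [[sum(taps[k] * seq[t - k][c] for k in range(len(taps)) if t >= k)
                 for c, taps in enumerate(self.conv)]
                for t in range(len(seq))]

    def _toy_ssm(self, seq):
        # Per-channel linear recurrence h_t = decay * h_{t-1} + u_t, output y_t = h_t.
        h, out = [0.0] * len(seq[0]), []
        for u_t in seq:
            h = [self.decay * hc + uc for hc, uc in zip(h, u_t)]
            out.append(h)
        return out

    def __call__(self, xs):
        u = [matvec(self.W_x, x) for x in xs]
        if self.use_ssm:
            u = self._toy_ssm([[silu(v) for v in row] for row in self._causal_conv(u)])
        gate = [[silu(v) for v in matvec(self.W_z, x)] for x in xs]
        return [matvec(self.W_out, [a * g for a, g in zip(u_t, g_t)])
                for u_t, g_t in zip(u, gate)]
```

Read this way, "Mamba-MLP" interleaves blocks with and without the extra path, while "Mamba-MHA" pairs the block with attention.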


Figure 9 (Right) shows these variants compared to the original (homogeneous) Mamba architecture. Interestingly, neither change matters much. The Mamba-MLP architecture is only slightly worse, and still better than all models except Transformer++. The Mamba-MHA architecture is only slightly better, which is somewhat surprising in light of the fact that many recent works have found that combining (LTI) SSMs with Attention can lead to substantial improvements (Dao, Fu, Saab, et al. 2023; Fathi et al. 2023; Fathullah et al. 2023; Saon, Gupta, and Cui 2023; Zuo et al. 2022).


H3 Architecture: Training Recipes. Next we ablate differences between the Hyena and H3++ models, our weakest and strongest models outside of Transformer++ and Mamba, particularly to isolate the effect of training recipes.


• Hyena: The Hyena block with its original architecture and GPT3 training recipe (same as Figure 4).


• Hyena+: The same architecture but with the improved training recipe described above.


• H3+: The same architecture as Hyena+ but with the Hyena convolution kernel swapped out for S4D convolution kernel.


• H3++: The same as H3+, but with a linear attention head dimension of 8. This increases computation inside the SSM recurrence but does not increase parameters.


Our general convention is that “Model+” represents the base model with the improved training recipe, and “Model++” also allows for architectural changes.


Figure 9 (Right) shows that


• A large improvement is achieved by the improved training recipe, which was used for many of the models in the main Figure 4 (RetNet, H3++, Transformer++, Mamba).


• The choice of the inner LTI SSM does not matter (e.g. Hyena vs. S4), consistent with findings throughout this paper.


• The head dimension expansion improves performance, consistent with one of our main themes that expanded state dimension improves performance for SSMs (Section 3).


Figure 9: (Scaling laws: extra ablations.)


E.2.3 Downstream Evaluation Details


This pretraining procedure is the same as the scaling law protocol, but extended to 300B tokens. For the 1.3B model, we use a batch size of 1M tokens to be consistent with the GPT3 specifications. We report the perplexity on the Pile validation set, and for this metric only compare to models trained on the same dataset and with the same tokenizer, in particular Pythia and RWKV.


For downstream evaluation, we use the LM evaluation harness from EleutherAI (L. Gao, Tow, et al. 2021), as done by most work in this area. We evaluate on the following tasks/datasets that measure common sense reasoning:


• LAMBADA (Paperno et al. 2016).


• HellaSwag (Zellers et al. 2019).


• PIQA (Bisk et al. 2020).


• ARC-challenge (P. Clark et al. 2018).


• ARC-easy: the easy partition of the ARC dataset (P. Clark et al. 2018), disjoint from ARC-challenge.


• WinoGrande (Sakaguchi et al. 2021).


We report accuracy for LAMBADA, WinoGrande, PIQA, and ARC-easy, and accuracy normalized by sequence length for HellaSwag and ARC-challenge (since normalized accuracy is higher for almost all models on these tasks).
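The difference between the two metrics can be sketched as follows. Each answer choice receives a log-likelihood from the model; plain accuracy picks the choice with the highest log-likelihood, while normalized accuracy first divides by the length of the choice, so longer answers are not penalized merely for containing more tokens. We assume normalization by character length, which is our understanding of how the LM evaluation harness computes acc_norm; the log-likelihoods below are hypothetical values for illustration only.

```python
def pick(logliks, choices, normalize):
    """Index of the predicted choice; optionally length-normalize log-likelihoods."""
    scores = [ll / len(c) if normalize else ll for ll, c in zip(logliks, choices)]
    return max(range(len(scores)), key=scores.__getitem__)


choices = [" a cat", " an extremely long and detailed answer"]
logliks = [-6.0, -12.0]  # hypothetical values, for illustration only
print(pick(logliks, choices, normalize=False))  # 0
print(pick(logliks, choices, normalize=True))   # 1
```

Accuracy (normalized or not) is then the fraction of examples where the picked index matches the gold answer.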

E.3 DNA Modeling

E.3.1 Pretraining Details


We describe the dataset and training procedure of the HG38 pretraining task in more detail.



E.3.2 Scaling: Model Size Details


Models. The models we consider are:


• Transformer++: a Transformer with improved architecture, notably the usage of RoPE positional encodings (Su et al. 2021). Informally, we found these to be noticeably better than the vanilla positional encodings of Vaswani et al. (2017).


• HyenaDNA: the Hyena model from Nguyen, Poli, et al. (2023) and Poli et al. (2023), which is roughly a Transformer with the MHA block replaced by an H3 block using a global convolution parameterized by an MLP.


• Mamba: the standard Mamba architecture.


Model Sizes. We use the following model sizes.



Note that the number of blocks for Mamba is doubled, because one Transformer “layer” includes both the MHA and MLP blocks (and similarly for Hyena), which requires two Mamba blocks to match parameters (Section 3.4).
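The parameter matching behind "two Mamba blocks per Transformer layer" can be checked with rough counts that keep only the large projection matrices. This sketch assumes an expansion factor E = 2 for the Mamba block, an input projection D → 2ED (one half for each branch), and an output projection ED → D, giving 3ED²; attention is counted as 4D² and the MLP (expansion 4) as 8D². Convolution, SSM, and normalization parameters are ignored.

```python
def transformer_layer_params(d):
    attention = 4 * d * d   # Q, K, V and output projections
    mlp = 2 * 4 * d * d     # D -> 4D -> D
    return attention + mlp  # 12 D^2


def mamba_block_params(d, expand=2):
    in_proj = d * (2 * expand * d)  # D -> 2ED: value branch and gate branch
    out_proj = (expand * d) * d     # ED -> D
    return in_proj + out_proj       # 3 E D^2 = 6 D^2 for E = 2


d = 256
print(transformer_layer_params(d) == 2 * mamba_block_params(d))  # True
```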


Training. For each model (Transformer++, HyenaDNA, Mamba), we swept the learning rate across {1e-3, 2e-3, 4e-3, 8e-3}. The optimal Transformer and HyenaDNA learning rates were 2e-3 across all sizes. The optimal Mamba learning rate was 8e-3; note that Mamba performed better than the baselines at matched learning rates (2e-3), but was more stable and improved even more at higher learning rates. (Furthermore, as this LR is at the upper end of the sweep, it is possible that our results are still suboptimal.)


Note that, in contrast to standard LM scaling laws (Table 12), our LR is held constant across model sizes for simplicity. The optimal LR should go down for larger models, but we did not find a noticeable effect at the small model sizes (at most a few million parameters) we considered.


E.3.3 Scaling: Context Length Details



The learning rate used was 0.008 for Mamba and 0.001 for HyenaDNA; we initially attempted to use the same learning rate of 0.002 from the previous section for HyenaDNA, but found that it was unstable at the longest context length.



Unlike HyenaDNA, we always control for the number of tokens per gradient update, so the batch size is successively halved as the sequence lengths are doubled in each stage.
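Holding tokens per gradient update fixed means the batch size at each stage is the token budget divided by the sequence length. A minimal sketch, with a hypothetical budget of 2^20 tokens per step and hypothetical stage lengths from 1K to 1M (neither value is stated here):

```python
def stage_batch_sizes(tokens_per_step, seq_lens):
    """Batch size per stage so that batch_size * seq_len stays constant."""
    sizes = []
    for L in seq_lens:
        if tokens_per_step % L:
            raise ValueError(f"{L} does not divide {tokens_per_step}")
        sizes.append(tokens_per_step // L)
    return sizes


lens = [2 ** k for k in range(10, 21)]         # 1K ... 1M, doubling each stage
print(stage_batch_sizes(2 ** 20, lens)[:3])    # [1024, 512, 256]
```

Each doubling of the sequence length halves the batch size, so every stage sees the same number of tokens per update.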



Remark E.1. We also note that the schedule was not tuned, and we never experimented with turning off sequence length warmup (SLW) for these pretraining experiments. We later found that SLW did not help noticeably for audio pretraining at similar lengths (Section 4.4), and it is possible that it is not necessary for DNA pretraining either.


E.3.4 Species (Great Apes) Classification


Models are causal, so only the last element (along the sequence dimension) of the model’s output is used by the classification head. We control for the total number of elements in the loss function per gradient step. The pretraining objective includes all positions across the sequence length, so batch_size × sequence_length is held constant; in other words, the batch size decreases as the sequence length increases. For a classification task, however, only the last position enters the loss, so the batch size itself is held constant. This also means that fine-tuning models on longer sequences is more computationally expensive.


Training consists of 10 epochs, each of which has 1024 gradient steps. Each gradient step uses a batch of 64 sequences, each drawn independently by uniformly picking a species, uniformly picking a chromosome, and then uniformly picking a contiguous segment of DNA.
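The three-level sampling can be sketched as follows. The genome representation (a dict from species to a dict of chromosome name → sequence string), the segment length, and the toy data are hypothetical; only the uniform species → chromosome → segment procedure and the batch size of 64 come from the text.

```python
import random


def sample_example(genomes, seg_len, rng):
    """Uniform species, then uniform chromosome, then a uniform contiguous segment."""
    species = rng.choice(sorted(genomes))
    chrom = rng.choice(sorted(genomes[species]))
    seq = genomes[species][chrom]
    start = rng.randrange(len(seq) - seg_len + 1)  # any valid start position
    return seq[start:start + seg_len], species


def sample_batch(genomes, seg_len, batch_size=64, seed=None):
    rng = random.Random(seed)
    return [sample_example(genomes, seg_len, rng) for _ in range(batch_size)]


toy = {"human": {"chr1": "ACGT" * 50}, "chimp": {"chr1": "TTGCA" * 40, "chr2": "G" * 300}}
batch = sample_batch(toy, seg_len=32, seed=0)
print(len(batch), len(batch[0][0]))  # 64 32
```

Note that sampling the species first gives each species equal weight regardless of genome size, and sampling the chromosome before the position gives each chromosome equal weight regardless of its length.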



Results for the Species classification task are in Table 13.

E.4 Audio Details

E.4.1 YouTubeMix Audio Pretraining


Model. We use a model with 3 blocks per stage (3 × 5 = 15 total Mamba blocks), pooling factor p = 16, and outer dimension D = 64, for about 3.5M parameters.


Dataset. The data is mu-law encoded at 8 bits, so the model is modeling discrete tokens with a vocab size of 256.


The dataset consists of clips up to 1 minute long, or length 960000, which are subsampled and divided into segments of any desired sequence length. Since the architecture involves two stages of pooling by a factor of 16, and we want the resulting sequence length to be a multiple of 8 for hardware efficiency, every sequence length must be a multiple of 16 × 16 × 8 = 2048, and the longest possible sequence is 468 × 2048 = 958464. The rest of our sequence lengths are defined by successively halving this and rounding up to the nearest multiple of 2048.


Table 14: YouTubeMix length scaling sequence lengths and batch sizes.


Figure 10: (Audio Pretraining (YouTubeMix) Ablations.) As a uniformly sampled “continuous” signal modality, audio waveforms actually benefit from LTI models, which have a matching inductive bias. (Left) Homogeneous models (all blocks have the same parameterization). (Right) Only the center U-Net blocks are ablated; the outer blocks are Mamba-S4. The purple line is the same as in the left figure.


Table 14 lists the specifications used in Figure 7. Beyond the varying batch sizes, the number of valid segments in the training set varied between different sequence lengths (e.g. the number of training steps per epoch was not constant for different points in the graph), which may have contributed to kinks in the scaling curves.
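The sequence lengths follow from simple arithmetic: two poolings by 16 followed by a multiple of 8 means every length is a multiple of 16 × 16 × 8 = 2048. The sketch below interprets "successively halving and rounding up" as halving the previous length and rounding up to the next multiple of 2048; the number of stages is a parameter here, since the full list lives in Table 14.

```python
def youtubemix_lengths(max_len=960_000, pool=16, num_pools=2, multiple=8, stages=8):
    unit = pool ** num_pools * multiple  # 16 * 16 * 8 = 2048
    m = max_len // unit                  # largest whole number of units that fits: 468
    lengths = []
    for _ in range(stages):
        lengths.append(m * unit)
        m = (m + 1) // 2                 # halve, rounding up to a whole unit
    return lengths


print(youtubemix_lengths()[:4])  # [958464, 479232, 239616, 120832]
```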


Training. Models were trained for 200K training steps with a maximum learning rate of 0.002, 20K (10%) warmup steps, and weight decay 0.1 (similar to our general pretraining recipe across domains).



Figure 10 shows that the change from S4 → S6 (i.e. the selection mechanism) is not always beneficial. On long-form audio waveforms, it in fact significantly hampers performance, which may be intuitive from the point of view that audio is uniformly sampled and very smooth, and therefore benefits from continuous linear time-invariant (LTI) methods. After ablating away the selection mechanism, the resulting model is a Mamba block with an S4 layer inside. To disambiguate, we call this Mamba-S4, as opposed to the default Mamba architecture, Mamba-S6.


However, on the right side, we keep the outer layers of the U-Net Mamba-S4 and ablate only the inner layers. The performance differences shrink dramatically; this reinforces the hypothesis that layers closer to the raw audio signal should be LTI, but once they are “tokenized” and compressed by the outer layers, the inner layers no longer need to be LTI. In this setting however, the real-valued SSM still underperforms the complex-valued one.
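The S4 → S6 distinction ablated here can be illustrated on a single channel with a scalar state. In the LTI version, the step size Δ is a fixed parameter, so the same linear filter is applied at every position; in the selective version, Δ is computed from the current input, so how strongly the state is updated depends on content. The discretization below (exp(ΔA) for A, Δ·B for B) and the softplus-of-a-linear-map form of Δ are simplifying assumptions for illustration, and all numeric values are arbitrary.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return math.log1p(math.exp(-abs(v))) + max(v, 0.0)


def scan(xs, A=-1.0, B=1.0, C=1.0, delta=None, w=2.0, b=0.0):
    """Scalar SSM: h_t = exp(dt*A) * h_{t-1} + dt*B*x_t, y_t = C * h_t.

    delta=None  -> selective: dt = softplus(w * x_t + b) depends on the input.
    delta=float -> LTI: the same dt at every step.
    """
    h, ys = 0.0, []
    for x in xs:
        dt = softplus(w * x + b) if delta is None else delta
        h = math.exp(dt * A) * h + dt * B * x
        ys.append(C * h)
    return ys


def lti_kernel(n, A=-1.0, B=1.0, C=1.0, delta=0.5):
    """Convolution kernel K_k = C * exp(delta*A)^k * delta*B equivalent to the LTI scan."""
    a = math.exp(delta * A)
    return [C * a ** k * delta * B for k in range(n)]
```

Because the LTI scan equals a fixed convolution with lti_kernel, it is linear in its input and can be computed with the FFT convolution mentioned in Section E.5; the selective scan is neither, which is what gives it content-dependent behavior and also why it needs a scan.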


E.4.2 SC09 Speech Generation


Autoregressive training largely followed the autoregressive language modeling protocol, such as



We used a learning rate of 0.002 and 200000 training steps at a batch size of 16.


The large Mamba model in Table 4 has 15 layers per stage with an outer dimension of D = 96 and pooling factor 4. We note that this dataset is small (training went through 100 epochs), and for this large model there was significant overfitting as measured by BPB or NLL. However, automated metrics of generated samples continued to improve throughout training.


E.5 Efficiency Benchmark

Scan Operation. We compare the core operation of selective SSMs, which is the parallel scan (Section 3.3), against convolution and attention, measured on an A100 80GB PCIe GPU. Note that these do not include the cost of other operations outside of this core operation, such as computing the convolutional kernel in global-convolution models, or computing the QKV projections in attention.



Our scan implementation fuses the discretization step and the parallel scan, avoiding the cost of materializing all the large parameters in HBM.
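Why a parallel scan applies at all: the discretized recurrence h_t = a_t·h_{t-1} + b_t (with a_t = exp(Δ_t A) and b_t = Δ_t B x_t under the simplified discretization) composes associatively, since applying step (a1, b1) and then (a2, b2) equals the single step (a1·a2, a2·b1 + b2). This pure-Python sketch checks a Hillis-Steele style inclusive scan against the sequential loop. It only illustrates the algebra with scalar states; it is not the fused, hardware-aware kernel described above, and the discretize helper is a hypothetical name.

```python
import math


def combine(left, right):
    """Compose two affine steps h -> a*h + b: first `left`, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(pairs):
    h, out = 0.0, []
    for a, b in pairs:
        h = a * h + b
        out.append(h)
    return out


def parallel_scan(pairs):
    """Hillis-Steele inclusive scan: O(log L) rounds, each fully parallelizable."""
    acc = list(pairs)
    offset = 1
    while offset < len(acc):
        # Every update in a round reads only the previous round's values.
        acc = [acc[i] if i < offset else combine(acc[i - offset], acc[i])
               for i in range(len(acc))]
        offset *= 2
    return [b for _, b in acc]  # with h_0 = 0, the accumulated offset b equals h_t


def discretize(xs, deltas, A=-1.0, B=1.0):
    # Per-step (a_t, b_t); a fused kernel would compute these inside the scan itself.
    return [(math.exp(d * A), d * B * x) for x, d in zip(xs, deltas)]
```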


For convolution, we use the standard implementation in PyTorch, which performs FFTs on the inputs and the filters separately, multiplies them in the frequency domain, and then performs an inverse FFT to obtain the result. The theoretical complexity is O(L log(L)) for sequence length L.
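The FFT route for a causal convolution can be sketched in pure Python: zero-pad the input and filter to at least 2L so that the circular convolution computed by the FFT does not wrap around, transform both, multiply pointwise, inverse-transform, and keep the first L outputs. The recursive radix-2 FFT below is a textbook version written for clarity; it is not the PyTorch implementation referenced above, which would use the torch.fft module.

```python
import cmath


def fft(x, invert=False):
    """Recursive radix-2 FFT; len(x) must be a power of two. The inverse is unscaled."""
    n = len(x)
    if n == 1:
        return list(x)
    even, odd = fft(x[0::2], invert), fft(x[1::2], invert)
    sign = 1 if invert else -1
    out = [0j] * n
    for k in range(n // 2):
        tw = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + tw
        out[k + n // 2] = even[k] - tw
    return out


def fft_causal_conv(u, k):
    """y[t] = sum_{s<=t} k[s] * u[t-s] for t < L, computed in O(L log L)."""
    L = len(u)
    n = 1
    while n < 2 * L:  # pad to >= 2L so the circular convolution does not wrap
        n *= 2
    kk = list(k[:L])
    U = fft([complex(v) for v in u] + [0j] * (n - L))
    K = fft([complex(v) for v in kk] + [0j] * (n - len(kk)))
    y = fft([a * b for a, b in zip(U, K)], invert=True)
    return [(v / n).real for v in y[:L]]


def direct_causal_conv(u, k):
    # O(L^2) reference implementation.
    return [sum(k[s] * u[t - s] for s in range(min(t + 1, len(k)))) for t in range(len(u))]
```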


For attention, we compare against the fastest implementation that we are aware of (FlashAttention-2 (Dao 2023)), with causal mask. Note that FlashAttention-2 with causal mask is about 1.7× faster than without causal mask, since approximately only half of the attention entries are computed.



End-to-end Inference. We measure the inference throughput of a Mamba 1.4B model and an untrained Mamba 6.9B model, against a standard Transformer (GPT3 architecture) at 1.3B and 6.7B size. We use the standard Transformer implementation in the Huggingface transformers library.


We set the prompt length to 2048 and the generation length to 128. We vary the batch size over 1, 2, 4, 8, 16, 32, 64, and 128, and measure the time taken to generate 128 tokens. We then calculate the throughput (tokens/s) as batch size × 128 ∕ time taken. We repeat the measurements 3 times and take the average. Measurements are done on an A100 80GB PCIe GPU.
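The measurement loop can be sketched as follows. generate_fn is a hypothetical callable standing in for whichever generation call is benchmarked (it is not a specific library API), and we read "take the average" as averaging the elapsed time before computing throughput. On a GPU one would also need to synchronize the device before reading the clock, which this sketch leaves to generate_fn.

```python
import time

BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]


def throughput(generate_fn, batch_size, gen_len=128, repeats=3, clock=time.perf_counter):
    """Tokens/s = batch_size * gen_len / (mean elapsed time over `repeats` runs)."""
    times = []
    for _ in range(repeats):
        start = clock()
        generate_fn(batch_size, gen_len)  # hypothetical: prompt of 2048 tokens + gen_len new tokens
        times.append(clock() - start)
    return batch_size * gen_len / (sum(times) / len(times))
```

The clock parameter exists so the arithmetic can be checked deterministically; for real measurements the default time.perf_counter is used.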


Memory Benchmark. The memory usage simply scales proportionally to the size of the activation tensors, as with most deep sequence models. We report measurements of the training memory requirements of 125M models on 1 A100 80GB GPU. Each batch consists of sequences of length 2048. We compare to the most memory-efficient Transformer implementation we are aware of (with kernel fusion from torch.compile and with FlashAttention-2). Table 15 shows that Mamba’s memory requirement is comparable to a similar-sized Transformer with an extremely optimized implementation, and we expect further improvement in Mamba’s memory footprint in the future.


Table 15: (Memory benchmark.) Mamba’s memory footprint is comparable to the most optimized Transformer. Results for 125M models.


This paper is available on arxiv under CC BY 4.0 DEED license.

