
Hardware-aware Algorithm For Selective SSMs


Too Long; Didn't Read

Linear Attention variants optimize autoregressive modeling with kernel approximations and efficient normalization. Long-context models promise scalability, but few prove performance gains—Selective SSMs aim to bridge the gap.


Authors:

(1) Albert Gu, Machine Learning Department, Carnegie Mellon University, equal contribution (agu@cs.cmu.edu);

(2) Tri Dao, Department of Computer Science, Princeton University, equal contribution (tri@tridao.me).

Abstract and 1. Introduction

2 State Space Models

3 Selective State Space Models and 3.1 Motivation: Selection as a Means of Compression

3.2 Improving SSMs with Selection

3.3 Efficient Implementation of Selective SSMs

3.4 A Simplified SSM Architecture

3.5 Properties of Selection Mechanisms

3.6 Additional Model Details

4 Empirical Evaluation and 4.1 Synthetic Tasks

4.2 Language Modeling

4.3 DNA Modeling

4.4 Audio Modeling and Generation

4.5 Speed and Memory Benchmarks

4.6 Model Ablations

5 Discussion

6 Conclusion, Acknowledgments and References

A Discussion: Selection Mechanism

B Related Work and B.1 S4 Variants and Derivatives

B.2 SSM Architectures

B.3 Relationship to RNNs

B.4 Linear Attention and B.5 Long Context Models

C Mechanics of Selective SSMs

D Hardware-aware Algorithm For Selective SSMs

E Experimental Details and Additional Results and E.1 Synthetic Tasks

E.2 Language Modeling

E.3 DNA Modeling

E.4 Audio Details

E.5 Efficiency Benchmark

D Hardware-aware Algorithm For Selective SSMs

Speed. On modern hardware accelerators (GPUs), most operations (except matrix multiplies) are bounded by memory bandwidth (Dao, Fu, Ermon, et al. 2022; Ivanov et al. 2021; Williams, Waterman, and Patterson 2009). This is the case for our scan operation, and we use kernel fusion to reduce the amount of memory IOs, leading to a significant speedup compared to a standard implementation.
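For reference, the recurrence that the fused kernel evaluates can be written as a plain sequential loop. Below is a minimal PyTorch sketch, assuming a diagonal A and the simplified discretization A̅ = exp(Δ·A), B̅ ≈ Δ·B; it illustrates what the scan computes and is not the authors' fused CUDA implementation.

```python
import torch

def selective_scan_reference(u, delta, A, B, C):
    """u, delta: (batch, L, D); A: (D, N) diagonal state matrix;
    B, C: (batch, L, N). Returns y: (batch, L, D)."""
    batch, length, d = u.shape
    n = A.shape[-1]
    h = torch.zeros(batch, d, n, device=u.device, dtype=u.dtype)  # hidden state
    ys = []
    for t in range(length):
        dA = torch.exp(delta[:, t].unsqueeze(-1) * A)          # (batch, D, N)
        dB = delta[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)   # (batch, D, N)
        h = dA * h + dB * u[:, t].unsqueeze(-1)                 # state update
        ys.append(torch.einsum("bdn,bn->bd", h, C[:, t]))       # y_t = C_t h_t
    return torch.stack(ys, dim=1)
```

A naive implementation like this materializes the discretized parameters and intermediate states in HBM; kernel fusion avoids these round trips by performing the discretization and scan within SRAM and writing only the final outputs back.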

When the sequence length L is too long for the sequence to fit in SRAM (which is much smaller than HBM), we split the sequence into chunks and perform the fused scan on each chunk. As long as we keep the intermediate scan states, we can continue the scan with the next chunk.
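A minimal sketch of this chunking logic, assuming a hypothetical fused_scan_chunk kernel that accepts an initial scan state and returns the chunk outputs together with the final state:

```python
import torch

def scan_in_chunks(u, delta, A, B, C, chunk_len=2048):
    """Run the fused scan chunk by chunk, carrying the state across chunks."""
    batch, length, d = u.shape
    n = A.shape[-1]
    h = torch.zeros(batch, d, n, device=u.device, dtype=u.dtype)  # carried state
    outputs = []
    for start in range(0, length, chunk_len):
        end = min(start + chunk_len, length)
        # fused_scan_chunk is a hypothetical stand-in for the fused kernel;
        # it returns this chunk's outputs and the final scan state.
        y, h = fused_scan_chunk(
            u[:, start:end], delta[:, start:end],
            A, B[:, start:end], C[:, start:end], initial_state=h)
        outputs.append(y)
    return torch.cat(outputs, dim=1)
```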


Memory. We describe how we use the classical technique of recomputation to reduce the total amount of memory required to train selective SSM layers.

Beyond optimizing the memory requirement of the scan operation itself, we also use recomputation to optimize the memory requirement of the entire selective SSM block (input projection, convolution, activation, scan, output projection). In particular, we do not save intermediate activations that take a lot of memory but are fast to recompute (e.g., the output of the activation function or of the short convolution). As a result, the selective SSM layer has the same memory requirement as an optimized Transformer implementation with FlashAttention. Specifically, each attention layer (FlashAttention) stores around 12 bytes of activations per token, and each MLP layer stores around 20 bytes of activations per token, for a total of 32 bytes (assuming mixed-precision training in FP16 or BF16). Each selective SSM layer stores around 16 bytes of activations per token. Hence two layers of selective SSMs have around the same activation memory as an attention layer plus an MLP layer.
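As an illustration, recomputation of this kind can be expressed with activation checkpointing, for example via torch.utils.checkpoint. The block below is a sketch with placeholder internals (input projection, short convolution, SiLU) and assumed dimensions; it is not the authors' implementation.

```python
import torch
from torch.utils.checkpoint import checkpoint

class SelectiveSSMBlock(torch.nn.Module):
    def __init__(self, d_model=768, d_inner=1536, d_conv=4):
        super().__init__()
        self.in_proj = torch.nn.Linear(d_model, d_inner)
        self.conv = torch.nn.Conv1d(d_inner, d_inner, kernel_size=d_conv,
                                    padding=d_conv - 1, groups=d_inner)
        self.out_proj = torch.nn.Linear(d_inner, d_model)

    def _forward_impl(self, x):
        # The convolution and activation outputs are large but cheap to
        # recompute, so they are not kept for the backward pass.
        z = self.in_proj(x)
        z = self.conv(z.transpose(1, 2))[..., : x.shape[1]].transpose(1, 2)
        z = torch.nn.functional.silu(z)
        # ... the selective scan would go here ...
        return self.out_proj(z)

    def forward(self, x):
        # Recompute intermediate activations during backward instead of
        # storing them, trading a small amount of compute for memory.
        return checkpoint(self._forward_impl, x, use_reentrant=False)
```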


E Experimental Details and Additional Results

E.1 Synthetic Tasks

Selective Copying. Our setting uses sequences of length 4096, with a vocabulary of 16 possible tokens (including the white “noise” token from Figure 2), and requires models to memorize 16 “data” tokens. We use 2-layer models with a model dimension of D = 64.


Models are trained for 400K steps at a constant learning rate of 0.0001 with a batch size of 64.
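A hedged sketch of how a batch for this task might be generated under the settings above (length 4096, vocabulary of 16 with token 0 as the assumed noise token, 16 data tokens, batch size 64); the exact token layout and target format are assumptions for illustration, not the authors' data pipeline.

```python
import torch

def selective_copying_batch(batch_size=64, seq_len=4096, vocab=16, n_data=16):
    """Inputs are mostly noise tokens with n_data data tokens scattered at
    random positions; the target is the data tokens in their original order."""
    noise_tok = 0                                            # assumed noise id
    data_vals = torch.randint(1, vocab, (batch_size, n_data))
    x = torch.full((batch_size, seq_len), noise_tok, dtype=torch.long)
    for b in range(batch_size):
        pos = torch.sort(torch.randperm(seq_len)[:n_data]).values
        x[b, pos] = data_vals[b]               # order of data tokens preserved
    return x, data_vals
```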


Induction Heads. Training consists of randomly generating data every step, with a batch size of 8. We choose an “epoch” size of 8192 steps, and track the accuracy on fixed validation sets (also randomly generated) of each target sequence length. For the MHA-Abs and Mamba models, results are reported after the 25th epoch (8192 × 25 = 204800 steps). For the MHA-RoPE and MHA-xPos models, results are reported after the 50th epoch (8192 × 50 = 409600 steps). For the LTI H3 and Hyena models, results are reported after the 10th epoch (81920 steps) because they had converged by then and failed to improve further.
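The schedule above can be summarized as a training skeleton; make_induction_batch, train_step, evaluate, and model are hypothetical stand-ins, and the validation lengths are assumptions for illustration.

```python
# make_induction_batch, train_step, evaluate, model: hypothetical stand-ins.
STEPS_PER_EPOCH = 8192            # one "epoch" of freshly generated data
NUM_EPOCHS = 50                   # 25 epochs suffice for MHA-Abs and Mamba

# Fixed, randomly generated validation sets, one per target sequence length.
val_lengths = [256, 1024, 4096, 16384]          # assumed target lengths
val_sets = {L: make_induction_batch(length=L) for L in val_lengths}

for epoch in range(NUM_EPOCHS):
    for step in range(STEPS_PER_EPOCH):
        batch = make_induction_batch(batch_size=8)  # new random data each step
        train_step(model, batch)
    accs = {L: evaluate(model, val_sets[L]) for L in val_lengths}
```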


Table 12: (Scaling Law Model Sizes.) Our model sizes and hyperparameters for scaling experiments. (Model dimension and number of heads apply only to Transformer models.)




This paper is available on arXiv under a CC BY 4.0 DEED license.