Authors:
(1) Ben Athiwaratkun, AWS AI Labs;
(2) Sujan Kumar Gonugondla, AWS AI Labs;
(3) Sanjay Krishna Gouda, AWS AI Labs;
(4) Haifeng Qian, AWS AI Labs;
(5) Hantian Ding, AWS AI Labs;
(6) Qing Sun, AWS AI Labs;
(7) Jun Wang, AWS AI Labs;
(8) Jiacheng Guo, AWS AI Labs;
(9) Liangfu Chen, AWS AI Labs;
(10) Parminder Bhatia, GE HealthCare (work done at AWS);
(11) Ramesh Nallapati, Amazon AGI (work done at AWS);
(12) Sudipta Sengupta, AWS AI Labs;
(13) Bing Xiang, Goldman Sachs (work done at AWS).
5.2. Latencies of Capabilities-Equivalent Models
As detailed in Section 5.1, we observed that a multi-query model must be larger than a multi-head model to match its performance. In this section, we examine the latency trade-offs of multi-query and multi-head models with similar capabilities across diverse scenarios. For these latency experiments, we use two models of approximately 1 billion parameters each: a multi-head model and a multi-query model (details in Appendix C.3). The multi-query model is larger by a multiplicative factor F = 1.1.
Overall, multi-query attention incurs some overhead because of its larger model size (see Figure 4 and Appendices D.3.1 and D.3.2 for analysis). That is, the multi-query model has slightly higher context encoding latency, and slightly higher incremental decoding latency when both the context length and the batch size are low. However, multi-query can have significantly lower latency than multi-head when there are many decoding steps, so that the incremental decoding phase dominates latency, and when the context length or batch size is high, which heavily increases the memory IO of incremental decoding. We outline three different inference scenarios below.
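The trade-off above can be made concrete with a back-of-the-envelope, memory-bound cost model in which one incremental decoding step reads every weight once plus the whole KV cache. This is our own simplification, not the paper's latency model: the layer count, head count, head dimension, and 1e9 parameter count are hypothetical stand-ins (not the configuration in Appendix C.3), and the MQ model is assumed to differ from the MH model only by having F = 1.1 times as many weights and a single key-value head.

```python
# Hypothetical memory-bound model: per-step cost ~ bytes read from memory.
def step_io_bytes(n_params, n_layers, kv_heads, head_dim, b, m, bytes_per_el=2):
    """Approximate bytes read in one incremental decoding step."""
    weight_bytes = n_params * bytes_per_el  # every weight is read once per step
    # Keys and values: one vector per layer, KV head, batch element, and cached position.
    kv_bytes = 2 * n_layers * kv_heads * head_dim * b * m * bytes_per_el
    return weight_bytes + kv_bytes


F = 1.1                # MQ size factor used in Section 5.2
P = 1.0e9              # hypothetical MH parameter count
L, H, K = 24, 16, 128  # hypothetical layers, attention heads, head dimension


def mh_step(b, m):
    return step_io_bytes(P, L, H, K, b, m)      # multi-head: H key-value heads


def mq_step(b, m):
    return step_io_bytes(F * P, L, 1, K, b, m)  # multi-query: 1 key-value head


for b, m in [(1, 100), (8, 10_000)]:
    print(f"b={b:>2} m={m:>6}  MH {mh_step(b, m) / 1e9:6.2f} GB  MQ {mq_step(b, m) / 1e9:6.2f} GB")
```

With these assumed numbers, MQ reads about 2.20 GB versus 2.02 GB per step at b = 1, m = 100, but about 3.18 GB versus 17.73 GB at b = 8, m = 10,000, which mirrors the qualitative pattern described above. Real latencies also depend on compute, kernel quality, and the context encoding phase, none of which this sketch models.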
5.2.1. SINGLE CONTEXT SCENARIO
In the single-context inference scenario (batch size 1), multi-query/multi-group attention can achieve lower latency when the context length and the number of generated tokens are high, as demonstrated in Figure 5. Implementations that load the KV cache more efficiently (for example, lower-level kernels that avoid duplicated IO) can make the MH curves flatter. However, the overall trend remains: given a sufficiently long context m, MQ eventually becomes faster than MH.
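Under the same kind of memory-bound assumption (our simplification, with hypothetical model dimensions rather than the paper's), the context length at which MQ overtakes MH can be solved in closed form: MQ's extra weight traffic (F − 1)P is paid back once its KV-cache savings 2Lk(h − 1)bm exceed it, where L is the number of layers, h the number of heads, and k the head dimension. The helper name `crossover_context` is ours.

```python
def crossover_context(F, n_params, n_layers, n_heads, head_dim, b=1):
    """Context length m* above which MQ reads fewer elements per decoding step than MH.

    Solves  P + 2*L*h*k*b*m  =  F*P + 2*L*1*k*b*m  for m, assuming a purely
    memory-bound step that reads all weights plus the full KV cache.
    """
    return (F - 1.0) * n_params / (2 * n_layers * head_dim * (n_heads - 1) * b)


# Hypothetical ~1B-parameter configuration with the F = 1.1 size factor, batch size 1.
m_star = crossover_context(F=1.1, n_params=1.0e9, n_layers=24, n_heads=16, head_dim=128)
print(round(m_star))  # 1085
```

The measured crossover in Figure 5 need not match this estimate: a more IO-efficient MH kernel flattens the MH curve and pushes the crossover toward longer contexts, and the model ignores compute entirely. The estimate also shrinks as 1/b, which is why larger batches favor MQ sooner.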
5.2.2. SINGLE-CONTEXT BATCH SAMPLING
In this scenario, we are given a single context and generate multiple completions from it using temperature sampling. The context encoding is therefore independent of the batch size b, since it is performed once on the single context and broadcast to the other batch indices (Figure 1). In contrast to the batch inference scenario, this is a more practical online inference scenario because we are not bottlenecked by the context encoding step. Our proposed context-aware bifurcated attention applies directly to this scenario, and in this section we demonstrate its results in conjunction with both multi-head and multi-query attention.
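A minimal NumPy sketch of this setup, with stand-in random tensors in place of a real model: the context KV cache is produced once with batch size 1, expanded to b samples with `np.broadcast_to` (a read-only view, so no copy is stored), and the first token of each of the b completions is drawn by temperature sampling from the shared next-token logits. The sizes, the temperature, and the helper `temperature_sample` are hypothetical.

```python
import numpy as np


def temperature_sample(logits, b, temperature, rng):
    """Draw b token ids from one logits vector via temperature sampling."""
    z = logits / temperature
    z = z - z.max()  # subtract max for numerical stability
    p = np.exp(z)
    p /= p.sum()
    return rng.choice(logits.shape[-1], size=b, p=p)


rng = np.random.default_rng(0)
m, g, k, vocab, b = 512, 4, 64, 1000, 8  # hypothetical: context len, KV heads, head dim, vocab, samples

# Context encoding runs once, on a batch of size 1 (random stand-ins for the real KV cache and logits).
context_kv = rng.standard_normal((1, m, g, k)).astype(np.float32)
last_logits = rng.standard_normal(vocab)

# Broadcast to b samples as a read-only view: the context KV is not copied.
batched_kv = np.broadcast_to(context_kv, (b, m, g, k))
first_tokens = temperature_sample(last_logits, b, temperature=0.8, rng=rng)
print(batched_kv.shape, first_tokens)
```

A view avoids storing b copies, but a standard attention kernel that consumes `batched_kv` still reads the context keys and values once per sample; that repeated IO is what bifurcated attention is designed to remove.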
Multi-head benefits significantly from bifurcated attention. Figure 6a shows the per-step latency for a multi-head model. For instance, with batch size 8, the per-step latency without bifurcated attention grows rapidly with context length, from ≈ 10 ms to ≈ 100 ms at context length 10,000. With bifurcated attention, however, the latency remains relatively flat with respect to context length. In practice, bifurcated attention also reduces memory consumption, so high batch sizes and long contexts run out of memory later than they do without it.
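The mechanism behind the flat curves can be sketched in NumPy for one decoding step of multi-head attention. The attention logits are computed in two parts, against the shared context keys (which carry no batch axis) and against each sample's own decoded keys; a single softmax is taken over their concatenation, and the outputs from the two value blocks are summed. This follows the context/decoding split of bifurcated attention as we understand it, but the shapes, function names, and NumPy formulation are our own assumptions. NumPy only checks the arithmetic here; the IO saving would come from a kernel that reads the shared context block once for the whole batch. The head dimension is written k, as in the paper's notation.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def bifurcated_attention(q, k_ctx, v_ctx, k_dec, v_dec):
    """One decoding step.  q: [b,h,k]; k_ctx, v_ctx: [h,m,k] (shared by all samples);
    k_dec, v_dec: [b,h,n,k] (per-sample keys/values from earlier decoding steps)."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = k_ctx.shape[1]
    s_ctx = np.einsum("bhk,hmk->bhm", q, k_ctx) * scale   # context part: keys have no batch axis
    s_dec = np.einsum("bhk,bhnk->bhn", q, k_dec) * scale  # decoded part: per-sample keys
    w = softmax(np.concatenate([s_ctx, s_dec], axis=-1))  # one softmax over all m + n positions
    w_ctx, w_dec = w[..., :m], w[..., m:]
    return (np.einsum("bhm,hmk->bhk", w_ctx, v_ctx)
            + np.einsum("bhn,bhnk->bhk", w_dec, v_dec))


def replicated_attention(q, k_ctx, v_ctx, k_dec, v_dec):
    """Reference: copy the context KV to every sample, then run ordinary attention."""
    b = q.shape[0]
    keys = np.concatenate([np.broadcast_to(k_ctx, (b,) + k_ctx.shape), k_dec], axis=2)
    vals = np.concatenate([np.broadcast_to(v_ctx, (b,) + v_ctx.shape), v_dec], axis=2)
    s = np.einsum("bhk,bhtk->bht", q, keys) / np.sqrt(q.shape[-1])
    return np.einsum("bht,bhtk->bhk", softmax(s), vals)


rng = np.random.default_rng(0)
b, h, m, n, k = 4, 2, 16, 3, 8  # hypothetical sizes
q = rng.standard_normal((b, h, k))
k_ctx, v_ctx = rng.standard_normal((2, h, m, k))
k_dec, v_dec = rng.standard_normal((2, b, h, n, k))
out = bifurcated_attention(q, k_ctx, v_ctx, k_dec, v_dec)
print(np.allclose(out, replicated_attention(q, k_ctx, v_ctx, k_dec, v_dec)))  # True
```

Both functions compute the same result; they differ in what must be read. `replicated_attention` touches b·h·(m + n)·k key entries, while `bifurcated_attention` touches only h·m·k + b·h·n·k.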
Bifurcated attention + multi-head rivals multi-query. Figure 7 compares MH and MQ with and without bifurcated attention. Without bifurcated attention, MQ is clearly much more inference-efficient. With bifurcated attention, however, MQ and MH appear comparable in moderate batch size scenarios (up to 64), and multi-head even has lower latency. The results indicate that, given an existing MH model, we can support batch sampling scenarios using bifurcated attention without needing a multi-query model, which requires training a new model or at least continued training (Ainslie et al., 2023). For more inference-intensive scenarios, including batch inference where bifurcated attention is not applicable, switching to multi-query can be worth the effort.
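To see why bifurcation lets MH compete, it helps to count the KV-cache elements read per layer in one decoding step. The hypothetical helper below does this for MH and MQ with and without bifurcation, assuming that a bifurcated kernel reads the shared context keys and values once per step rather than once per sample. It is an IO count, not a latency model: measured latencies such as those in Figure 7 also include weight loading (where the F = 1.1 times larger MQ model pays extra), compute, and kernel overheads.

```python
def kv_elements_read(b, m, n, kv_heads, head_dim, bifurcated):
    """KV-cache elements (keys + values) read per layer in one decoding step.

    b: samples, m: shared context length, n: tokens decoded so far per sample,
    kv_heads: h for multi-head, 1 for multi-query.
    """
    if bifurcated:
        positions = m + b * n    # context block read once, decoded blocks per sample
    else:
        positions = b * (m + n)  # context replicated into every sample's cache
    return 2 * kv_heads * head_dim * positions


h, k, m, n = 16, 128, 2000, 100  # hypothetical sizes
for b in (8, 64):
    row = {
        "MH": kv_elements_read(b, m, n, h, k, False),
        "MH+bif": kv_elements_read(b, m, n, h, k, True),
        "MQ": kv_elements_read(b, m, n, 1, k, False),
        "MQ+bif": kv_elements_read(b, m, n, 1, k, True),
    }
    print(b, row)
```

In this assumed setting, bifurcation cuts MH's KV reads by 6× at b = 8 and by 16× at b = 64, and at b = 64 the MH + bifurcated count happens to equal plain MQ's. The exact ratios depend on m, n, and h.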
Bifurcated attention with multi-query enables more extreme batch sizes and context lengths. Multi-query attention already has h times lower memory IO for loading keys and values (where h is the number of attention heads) and can reduce latency in some inference scenarios on its own. With bifurcated attention, the supported context lengths and batch sizes can become much more extreme, as demonstrated in Figure 6b.
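The effect on supported batch sizes can also be sketched from the memory-capacity side. The hypothetical helper below computes the largest batch whose KV cache fits in a given byte budget, assuming that without bifurcation every sample stores its own copy of the context keys and values, while with bifurcation the context block is stored once. The budget, model dimensions, and lengths are illustrative assumptions, and weights and activations are ignored.

```python
def max_batch(budget_bytes, m, n, n_layers, kv_heads, head_dim, bifurcated, bytes_per_el=2):
    """Largest b whose KV cache fits in budget_bytes (weights/activations ignored)."""
    per_position = 2 * n_layers * kv_heads * head_dim * bytes_per_el  # keys + values
    positions = budget_bytes // per_position  # number of cached positions that fit
    if bifurcated:
        # m context positions stored once, plus n decoded positions per sample
        return max(0, (positions - m) // n)
    # every sample holds its own m context positions plus n decoded positions
    return positions // (m + n)


GiB = 2**30
budget, m, n, L, h, k = 16 * GiB, 10_000, 256, 24, 16, 128  # hypothetical values
for name, g in (("MH", h), ("MQ", 1)):
    print(name,
          max_batch(budget, m, n, L, g, k, bifurcated=False),
          max_batch(budget, m, n, L, g, k, bifurcated=True))
```

Under these assumptions the largest batch grows from 8 to 302 for MH and from 136 to 5422 for MQ when the context is stored once, in the same direction as the trend described above; real limits also depend on weights, activations, and allocator behavior.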
This paper is available on arXiv under the CC BY 4.0 DEED license.