
Understanding Latency Trade-offs in Multi-Query vs. Multi-Head AI Models

by Batching, February 24th, 2025

Too Long; Didn't Read

Multi-query models incur a small latency overhead at low context lengths and batch sizes, but outperform multi-head models when many tokens are decoded or when the context length or batch size is large. Bifurcated attention helps multi-head models remain competitive, reducing memory use and improving efficiency in batch sampling.


Authors:

(1) Ben Athiwaratkun, AWS AI Labs;

(2) Sujan Kumar Gonugondla, AWS AI Labs;

(3) Sanjay Krishna Gouda, AWS AI Labs;

(4) Haifeng Qian, AWS AI Labs;

(5) Hantian Ding, AWS AI Labs;

(6) Qing Sun, AWS AI Labs;

(7) Jun Wang, AWS AI Labs;

(8) Jiacheng Guo, AWS AI Labs;

(9) Liangfu Chen, AWS AI Labs;

(10) Parminder Bhatia, GE HealthCare (work done at AWS);

(11) Ramesh Nallapati, Amazon AGI (work done at AWS);

(12) Sudipta Sengupta, AWS AI Labs;

(13) Bing Xiang, Goldman Sachs (work done at AWS).

Abstract and 1 Introduction

2. Related Work

3. Background

3.1. Notation and 3.2. Language Model Inference

3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention

4. Context-Aware Bifurcated Attention and 4.1. Motivation

4.2. Formulation and 4.3. Memory IO Complexity

5. Experiments

5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention

5.2. Latencies of Capabilities-Equivalent Models

5.3. Applications

6. Conclusion and References


A. FAQs

B. Related Work

C. Setup

D. Multi-Group Attention Family

E. Context-Aware Bifurcated Attention

F. Applications: Additional Results

G. Compatibility with Speculative Decoding and Fast Decoding techniques

5.2. Latencies of Capabilities-Equivalent Models

As detailed in Section 5.1, a multi-query model must be scaled up to match the performance of a multi-head model. In this section, we examine the latency trade-offs across diverse scenarios using multi-query and multi-head models of similar performance capabilities. For these latency experiments, we use two models of approximately 1 billion parameters each: a multi-head model and a multi-query model (details in C.3). The multi-query model is larger by a multiplicative factor F = 1.1.


Overall, there is some overhead cost to using multi-query attention due to its larger size (see Figure 4 and Appendices D.3.1 and D.3.2 for analysis). That is, the context encoding latency of the multi-query model is slightly higher, as is incremental decoding in the low-context, low-batch regime. However, multi-query can have significantly lower latency than multi-head when the number of decoding steps is high, which makes the incremental decoding phase latency-dominating, or when the context length or batch size is large, which heavily impacts the memory IO of incremental decoding. We outline three different inference scenarios below.
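To make the trade-off concrete, the sketch below is a rough cost model for a memory-bandwidth-bound decoding step (the model dimensions are hypothetical, not the paper's exact configuration): each step reads the full weights plus the KV cache, and the KV-cache term scales with the number of KV heads, which is h for multi-head but 1 for multi-query.

```python
def per_step_io_gb(n_params, batch, ctx_len, n_layers, kv_heads, d_head, bytes_per=2):
    """Approximate GB read per incremental-decoding step (weights + KV cache),
    assuming the step is memory-bandwidth bound.
    kv_heads = h for multi-head attention, 1 for multi-query attention."""
    weight_io = n_params * bytes_per
    kv_io = 2 * batch * ctx_len * n_layers * kv_heads * d_head * bytes_per  # keys and values
    return (weight_io + kv_io) / 1e9

# Hypothetical configurations: a 1.0B MH model vs. a capability-matched 1.1B MQ model (F = 1.1)
n_layers, n_heads, d_head = 24, 16, 128
for b, m in [(1, 1_000), (1, 10_000), (8, 10_000)]:
    mh = per_step_io_gb(1.0e9, b, m, n_layers, n_heads, d_head)
    mq = per_step_io_gb(1.1e9, b, m, n_layers, 1, d_head)
    print(f"b={b}, m={m:>6}: MH ~{mh:5.2f} GB/step, MQ ~{mq:5.2f} GB/step")
```

Under these assumptions, MQ pays a small fixed premium for its extra weights, but its KV term barely grows, whereas the MH KV term dominates once m or b is large.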


Figure 4: High-level latency comparison between an MH model and a larger MQ model with comparable capabilities. Overall, there is an overhead cost for the initial context encoding due to the additional compute of the larger MQ model. For low context lengths and batch sizes, the per-step latency of MQ also starts slightly higher because of the memory IO required for the larger model, but it changes little as the context length m or batch size b grows, as opposed to the multi-head case, where the per-step latency can grow much more rapidly with m and b.


5.2.1. SINGLE CONTEXT SCENARIO


In the single-batch inference scenario, multi-query/multi-group attention can achieve lower latency when the context length and the number of generated tokens are high, as demonstrated in Figure 5. Implementations that load the KV cache more efficiently (such as lower-level kernels that avoid duplicated IO) can flatten the MH curves, but the overall trend remains: given a sufficiently long context m, MQ becomes faster than MH.
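Under the same hedged cost model (hypothetical dimensions, memory-bound decoding only), the crossover context length can be estimated in closed form: it is the point where MQ's extra weight IO, (F − 1) times the weight bytes, equals its per-step KV-IO savings of (1 − 1/h) per context token.

```python
def crossover_context_len(n_params, F, n_layers, n_heads, d_head, batch=1, bytes_per=2):
    """Context length m* beyond which a multi-query model with (F-1)x extra weights
    is expected to have lower per-step IO than the multi-head baseline."""
    extra_weight_io = (F - 1.0) * n_params * bytes_per
    kv_io_per_ctx_token = 2 * batch * n_layers * n_heads * d_head * bytes_per  # MH keys + values
    savings_per_ctx_token = kv_io_per_ctx_token * (1.0 - 1.0 / n_heads)
    return extra_weight_io / savings_per_ctx_token

# Hypothetical ~1B-parameter configuration with F = 1.1, as used in this section
print(int(crossover_context_len(1.0e9, 1.1, n_layers=24, n_heads=16, d_head=128)))
# -> on the order of a few thousand tokens; the actual crossover also depends on kernels and hardware
```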


5.2.2. SINGLE-CONTEXT BATCH SAMPLING


In this scenario, we are given a single context and generate multiple completions via temperature sampling. The context encoding is independent of the batch size b, since it is performed once on the single context and broadcast to the other batch indices (Figure 1). In contrast to the batch inference scenario, this is a more practical online inference setting, since we are not bottlenecked by the context encoding step. Our proposed context-aware bifurcated attention is directly applicable to this scenario.


Figure 5: Incremental decoding (per-step) latency and context encoding latency as a function of input context length. In this plot, we compare a multi-head model and a multi-query model of comparable capabilities, the latter being slightly larger. (Leftmost: per-step incremental decoding latency) For low context lengths such as m < 2500, the inference latency of the MQ model is higher due to its larger size. However, its growth with respect to context length is much lower (almost flat), resulting in lower per-step latency when the context length is high. (Second: context encoding latency) The context encoding latency depends on the FLOPs, where MH and MQ are quite similar. Note that the MQ model is slightly larger and therefore corresponds to a steeper curve. (Third, fourth: total latency for 15 or 256 generated steps) These two plots illustrate the total latency, which is the sum of the context encoding latency and the number of steps times the incremental decoding latency. The benefit of the MQ model becomes clear with many decoding steps (256), whereas with 15 generated tokens the total latency of MQ can still be slightly higher than MH.


Figure 6: Context-aware bifurcated attention with multi-head attention (a) and multi-query attention (b). Bifurcated attention loads the KV cache in a context-aware manner, resulting in significantly lower latency for sampling at high batch sizes. For instance, for multi-head attention with batch size 128 and context length 10,000, bifurcated attention yields ≈ 4× lower incremental decoding latency. Additionally, growth with respect to context length is relatively flat with bifurcated attention. With multi-query attention, bifurcated attention permits batch sizes as high as 256 or 512 with lower latency than in the multi-head case.


In this section, we demonstrate results with bifurcated attention in conjunction with both multi-head and multi-query attention.
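As a concrete illustration of the mechanism, here is a minimal PyTorch sketch of one bifurcated decoding step for single-context batch sampling (the function name and shapes are ours, and this is not the paper's optimized kernel): attention logits are computed separately against the shared prompt cache, which is stored once without a batch dimension, and against each sample's own generated tokens, then combined under a single softmax.

```python
import torch

def bifurcated_attention_step(q, k_ctx, v_ctx, k_dec, v_dec):
    """One incremental-decoding attention step for single-context batch sampling.

    q:            [b, h, 1, d]     query of the current step for each of b samples
    k_ctx, v_ctx: [h, m_ctx, d]    KV cache of the shared prompt, stored once (no batch dim)
    k_dec, v_dec: [b, h, m_dec, d] per-sample KV cache of previously generated tokens
    """
    scale = q.shape[-1] ** -0.5
    # Part 1: logits against the shared prefix; the prompt cache is read a single
    # time and broadcast over the batch instead of being replicated per sample.
    logits_ctx = torch.einsum("bhqd,hkd->bhqk", q, k_ctx) * scale
    # Part 2: logits against each sample's own generated tokens.
    logits_dec = torch.einsum("bhqd,bhkd->bhqk", q, k_dec) * scale
    # Softmax over the concatenated (prefix + generated) positions, then split back.
    weights = torch.softmax(torch.cat([logits_ctx, logits_dec], dim=-1), dim=-1)
    w_ctx, w_dec = weights.split([k_ctx.shape[1], k_dec.shape[2]], dim=-1)
    return (torch.einsum("bhqk,hkd->bhqd", w_ctx, v_ctx)
            + torch.einsum("bhqk,bhkd->bhqd", w_dec, v_dec))

# Tiny shape check with hypothetical dimensions
b, h, d, m_ctx, m_dec = 4, 16, 64, 32, 5
out = bifurcated_attention_step(torch.randn(b, h, 1, d),
                                torch.randn(h, m_ctx, d), torch.randn(h, m_ctx, d),
                                torch.randn(b, h, m_dec, d), torch.randn(b, h, m_dec, d))
print(out.shape)  # torch.Size([4, 16, 1, 64])
```

With multi-query attention the same split applies, except the KV tensors have a single head that is broadcast over the query heads.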


Multi-head benefits significantly from bifurcated attention. Figure 6a shows the per-step latency results for a multi-head model. For instance, with batch size 8, the per-step latency without bifurcated attention grows rapidly with context length, from ≈ 10 ms to ≈ 100 ms at a context length of 10,000. With bifurcated attention, the latency remains relatively flat with respect to context length. In practice, bifurcated attention also reduces memory consumption at high batch sizes and context lengths, avoiding out-of-memory errors that would otherwise occur earlier.
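A back-of-the-envelope IO count (again with hypothetical dimensions, not measured latencies) shows where the saving comes from: without bifurcation the shared prefix is read once per batch element at every step, whereas with bifurcation it is read once in total.

```python
def kv_read_gb_per_step(batch, m_ctx, m_dec, n_layers, kv_heads, d_head,
                        bifurcated, bytes_per=2):
    """Approximate GB of KV cache read per decoding step in single-context batch
    sampling: one shared prompt of length m_ctx plus m_dec generated tokens per sample."""
    per_token = 2 * n_layers * kv_heads * d_head * bytes_per          # keys and values
    ctx_reads = (1 if bifurcated else batch) * m_ctx * per_token      # shared prefix
    dec_reads = batch * m_dec * per_token                             # per-sample suffix
    return (ctx_reads + dec_reads) / 1e9

n_layers, n_heads, d_head = 24, 16, 128  # hypothetical multi-head configuration
for b in (8, 128):
    plain = kv_read_gb_per_step(b, 10_000, 128, n_layers, n_heads, d_head, bifurcated=False)
    bifur = kv_read_gb_per_step(b, 10_000, 128, n_layers, n_heads, d_head, bifurcated=True)
    print(f"batch {b:>3}: {plain:7.1f} GB/step without vs {bifur:5.2f} GB/step with bifurcation")
```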


Bifurcated attention + multi-head rivals multi-query. Figure 7 compares MH and MQ with and without bifurcated attention. Without bifurcated attention, MQ is clearly much more inference-efficient. With bifurcated attention, however, MQ and MH are comparable under moderate batch sizes (up to 64), where multi-head even has lower latency. The results indicate that, given an existing MH model, we can support batch sampling scenarios using bifurcated attention without needing a multi-query model, which would require training a new model or at least continued training (Ainslie et al., 2023). For more inference-intensive scenarios, including the batch inference scenario where bifurcated attention is not applicable, switching to multi-query can be worth the effort.


Figure 7: Latency comparison between a multi-head model and a larger multi-query model of equal capabilities. Without bifurcated attention, MQ is clearly much more inference-efficient. With bifurcated attention, however, MH can have better latency than MQ in moderate scenarios (up to batch size 64 in this case), while MQ handles more extreme scenarios better than MH.


Figure 8: Bifurcated attention improves accuracy by enabling more generated samples within a fixed latency budget, for both multi-head attention (CodeGen) and multi-query attention (StarCoder). Given n samples, pass@n reflects the execution pass rate of the best sample among the n, shown in (a) and (c). Filtering the n samples by mean log-probability ranking yields a subset of the best three samples, reflected by pass@top3 in (b) and (d). The increased number of samples within the same latency budget improves performance under either pass@n or pass@top-k.
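For reference, here is a small sketch of the two metrics used in Figure 8 (the array values are made up): pass@n asks whether any of the n samples passes execution, while pass@top3 first ranks the samples by mean token log-probability and keeps only the best three.

```python
import numpy as np

def pass_at_n(passed):
    """pass@n for one problem: does any of the n generated samples pass the tests?"""
    return float(any(passed))

def pass_at_top_k(passed, mean_logprobs, k=3):
    """Rank samples by mean token log-probability, keep the top k, then check for a pass."""
    top = np.argsort(mean_logprobs)[::-1][:k]
    return float(any(passed[i] for i in top))

# Hypothetical results for a single problem with n = 5 samples
passed = np.array([False, True, False, False, True])
mean_logprobs = np.array([-0.42, -0.31, -0.55, -0.60, -0.28])
print(pass_at_n(passed), pass_at_top_k(passed, mean_logprobs, k=3))  # 1.0 1.0
```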


Bifurcated attention with multi-query enables more extreme batch sizes and context lengths. Multi-query has h times lower memory IO overall and can already reduce latency in some inference scenarios. With bifurcated attention, the supported context lengths and batch sizes become much more extreme, as demonstrated in Figure 6b.




This paper is available on arXiv under a CC BY 4.0 DEED license.