Date: 2025-02-24

Authors: (1) Ben Athiwaratkun, AWS AI Labs; (2) Sujan Kumar Gonugondla, AWS AI Labs; (3) Sanjay Krishna Gouda, AWS AI Labs; (4) Haifeng Qian, AWS AI Labs; (5) Sanjay Krishna Gouda, AWS AI Labs; (6) Hantian Ding, AWS AI Labs; (7) Qing Sun, AWS AI Labs; (8) Jun Wang, AWS AI Labs; (9) Jiacheng Guo, AWS AI Labs; (10) Liangfu Chen, AWS AI Labs; (11) Parminder Bhatia, GE HealthCare (work done at AWS); (12) Ramesh Nallapati, Amazon AGI (work done at AWS); (13) Sudipta Sengupta, AWS AI Labs; (14) Bing Xiang, Goldman Sachs (work done at AWS).

Abstract and 1 Introduction

2. Related Work

3. Background

3.1. Notation and 3.2. Language Model Inference

3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention

4. Context-Aware Bifurcated Attention and 4.1. Motivation

4.2. Formulation and 4.3. Memory IO Complexity

5. Experiments

5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention

5.2. Latencies of Capabilities-Equivalent Models

5.3. Applications

6. Conclusion and References





A. FAQs

B. Related Work

C. Setup

D. Multi-Group Attention Family

E. Context-Aware Bifurcated Attention

F. Applications: Additional Results

G. Compatibility with Speculative Decoding and Fast Decoding techniques

2. Related Work

In the literature, there are multiple avenues to improve inference latency. Quantization reduces memory usage by using low-bitwidth representations such as int8, int4, and fp8 (Wei et al., 2023; Yao et al., 2022; Dettmers et al., 2022; Frantar et al., 2022; Kuzmin et al., 2022; Xiao et al., 2022). Quantization, when applied only to model parameters, offers diminishing returns at longer sequence lengths and large batch sizes, where the memory access and compute associated with dot-product attention dominate the overall inference latency.

Sparse attention (Beltagy et al., 2020; Child et al., 2019; Zaheer et al., 2020) has been extensively studied as a way to reduce the complexity of attention for longer contexts and faster inference. Pope et al. (2022) investigate the generative inference efficiency of large language models using multi-dimensional partitioning techniques optimized for TPUs (collective einsum) to achieve a Pareto frontier on latency and model FLOPs utilization. The paper also shows that multi-query attention allows scaling up to 32x larger context length, with an emphasis on efficiency under high batch size. Paged attention (Kwon et al., 2023) improves memory management of the KV cache by dividing it into blocks and using a block table for mapping. This approach accommodates dynamic workload shifts and reduces memory storage requirements by sharing the prompt’s KV cache across multiple output sequences. However, it does not reduce the memory reads of the KV cache.





Speculative decoding and its variants use a smaller draft model to propose multiple sequential tokens, which are processed in parallel by the main model to accept or reject such tokens (Chen et al., 2023; Leviathan et al., 2022; Li et al., 2024; Cai et al., 2024; Fu et al., 2023). The key idea is to enable decoding of multiple tokens at every step, thereby amortizing the memory IO of the main model. However, the latency of decoding will still be dominated by KV cache IO bandwidth at large context sizes, where bifurcated attention can enhance the decoding speed further. In short, speculative decoding focuses on lowering the amortized memory IO of model loading, while multi-query and bifurcated attention lower the memory IO of the KV cache.
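
The propose-then-verify loop described above can be illustrated with a minimal sketch. It assumes greedy decoding and exact token matching, which is a simplification and not the rejection-sampling scheme of the cited works. The names `speculative_step`, `draft_next`, `target_next`, and `n_draft` are hypothetical, and the toy "models" are plain functions that map a token sequence to the next token.

```python
def speculative_step(prefix, draft_next, target_next, n_draft=4):
    """One greedy speculative decoding step; returns the newly accepted tokens."""
    # The draft model proposes n_draft tokens one after another.
    proposed = []
    for _ in range(n_draft):
        proposed.append(draft_next(prefix + proposed))

    # The main model checks each proposed position. A real implementation
    # scores all positions in one parallel forward pass; this loop only
    # mimics the resulting accept/reject decisions.
    accepted = []
    for tok in proposed:
        expected = target_next(prefix + accepted)
        if tok != expected:
            # First mismatch: keep the main model's token and stop.
            accepted.append(expected)
            return accepted
        accepted.append(tok)

    # Every proposal matched; the same pass also yields one extra token.
    accepted.append(target_next(prefix + accepted))
    return accepted


# Toy models (assumptions for illustration): the target counts upward mod 10,
# and the draft agrees except right after token 3.
target = lambda seq: (seq[-1] + 1) % 10
draft = lambda seq: 0 if seq[-1] == 3 else (seq[-1] + 1) % 10

tokens = speculative_step([0], draft, target)
```

In this toy setup one verification pass of the main model yields several tokens, which is the amortization of model-loading IO the paragraph refers to. The KV cache that the main model reads is unchanged by this procedure.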

3. Background

3.1. Notation

We use the following notation throughout the paper.













3.2. Language Model Inference

There are many inference scenarios for language models, including batch inference and single-context batch sampling (Figure 1). Batch inference refers to the case where we process multiple inputs together in a batch and generate subsequent tokens for each batch index independently. When the batch size is 1, this reduces to single-context inference. Another scenario is single-context batch sampling, where we generate multiple sequences based on a single context. The difference from batch inference is that the prefill only needs to be done for a single context to obtain the KV cache, which is then broadcast to the other batch indices.
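
The contrast between the two scenarios can be sketched in a few lines. This is a toy illustration, not the paper's implementation: `ToyModel` is a hypothetical stand-in that only counts how many token positions pass through prefill, and the returned "cache" is a placeholder dictionary instead of real key and value tensors.

```python
class ToyModel:
    """Hypothetical stand-in for a language model; only counts prefill work."""

    def __init__(self):
        self.prefill_tokens = 0

    def prefill(self, context):
        # A real model would compute key/value tensors for every position here.
        self.prefill_tokens += len(context)
        return {"kv_for": tuple(context)}


def batch_inference(model, contexts):
    # Every batch index has its own context, so each one needs its own prefill.
    return [model.prefill(c) for c in contexts]


def single_context_batch_sampling(model, context, b):
    # Prefill once, then share (broadcast) the same cache across b samples.
    cache = model.prefill(context)
    return [cache for _ in range(b)]


context = [11, 12, 13]

m_batch = ToyModel()
batch_inference(m_batch, [context] * 4)

m_single = ToyModel()
shared = single_context_batch_sampling(m_single, context, 4)
```

Here `m_batch.prefill_tokens` counts twelve positions while `m_single.prefill_tokens` counts three, and every entry of `shared` refers to the same cache object.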





Figure 1 also illustrates the two phases of language model inference: (a) the context encoding or prefilling and (b) the incremental decoding. The context encoding refers to a single forward pass that computes the key and value tensors for all token positions in the context. Once the key and value tensors are computed, we cache these key and value tensors to be used for the attention mechanism during the incremental decoding phase, which sequentially generates one token at a time[2].
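
The two phases can be sketched with a toy single-head attention in pure Python. This is a minimal sketch under strong assumptions: keys and values are taken to be the token embeddings themselves (no learned projections), there is one layer and one head, and the names `attend`, `prefill`, and `decode_step` are hypothetical.

```python
import math


def attend(q, keys, values):
    """Softmax attention of a single query vector over cached keys and values."""
    scale = math.sqrt(len(q))
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]


def prefill(context):
    """Context encoding: build the key/value cache for all context positions."""
    # Toy assumption: key = value = embedding, so no projection matrices.
    return {"k": [list(x) for x in context], "v": [list(x) for x in context]}


def decode_step(cache, x):
    """Incremental decoding: add one token to the cache, then attend over it."""
    cache["k"].append(list(x))
    cache["v"].append(list(x))
    return attend(x, cache["k"], cache["v"])


context = [[1.0, 0.0], [0.0, 1.0]]
cache = prefill(context)                 # phase (a): one pass over the context
out = decode_step(cache, [1.0, 1.0])     # phase (b): one new token per step
```

Each call to `decode_step` reads the whole cache for a single query, which is why the cache grows by one entry per generated token and is reread at every step.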

















During the context encoding phase, the number of floating point operations relative to the memory input/output (IO) operations is high, which corresponds to the compute-bound regime where latency is driven by the FLOPs. However, during incremental decoding, where we perform attention on a single query token, inference falls into a memory-bound regime where the ratio of computation to memory access is roughly 1-to-1 (see Appendix D.1 for details). Memory IO refers to the read and write operations between the high bandwidth memory (HBM) (Jia et al., 2018) and the fast on-chip SRAM where the actual computation happens. The memory IO of incremental decoding itself consists of two components: (1) model parameter loading and (2) KV cache loading. Component (1) is constant regardless of the context length m or batch size b, whereas component (2) depends on both m and b and dominates the overall memory IO when m or b is large, which can become a significant bottleneck for inference. Our work primarily focuses on reducing component (2).
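
A back-of-the-envelope estimate makes the two components concrete. The sketch below is an assumption-laden approximation, not a formula taken from the paper: besides m and b it introduces a layer count `n_layers`, a number of key/value heads `g`, and a head dimension `k`, assumes 2 bytes per element (fp16 or bf16), and uses a hypothetical 7-billion-parameter configuration.

```python
def decode_step_memory_io(n_params, n_layers, b, m, g, k,
                          bytes_per_param=2, bytes_per_kv=2):
    """Rough bytes read from HBM for one incremental decoding step."""
    # Component (1): each parameter is read once, independent of m and b.
    param_bytes = n_params * bytes_per_param
    # Component (2): keys and values (factor 2) for every layer, batch index,
    # context position, key/value head, and head dimension.
    kv_bytes = 2 * n_layers * b * m * g * k * bytes_per_kv
    return param_bytes, kv_bytes


# Hypothetical model: 7e9 parameters, 32 layers, 32 KV heads of dimension 128.
config = dict(n_params=7_000_000_000, n_layers=32, g=32, k=128)

params_small, kv_small = decode_step_memory_io(b=1, m=2048, **config)
params_large, kv_large = decode_step_memory_io(b=32, m=2048, **config)
```

Under these assumptions the KV cache read is about 1 GiB at b = 1, well below the 14 GB of parameters, but grows linearly to about 34 GB at b = 32, at which point component (2) is the larger term.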





This paper is available on arXiv under the CC BY 4.0 DEED license.



