Date: 2025-02-25

Authors: (1) Ben Athiwaratkun, AWS AI Labs; (2) Sujan Kumar Gonugondla, AWS AI Labs; (3) Sanjay Krishna Gouda, AWS AI Labs; (4) Haifeng Qian, AWS AI Labs; (5) Sanjay Krishna Gouda, AWS AI Labs; (6) Hantian Ding, AWS AI Labs; (7) Qing Sun, AWS AI Labs; (8) Jun Wang, AWS AI Labs; (9) Jiacheng Guo, AWS AI Labs; (10) Liangfu Chen, AWS AI Labs; (11) Parminder Bhatia, GE HealthCare (work done at AWS); (12) Ramesh Nallapati, Amazon AGI (work done at AWS); (13) Sudipta Sengupta, AWS AI Labs; (14) Bing Xiang, Goldman Sachs (work done at AWS).

Abstract and 1 Introduction

2. Related Work

3. Background

3.1. Notation and 3.2. Language Model Inference

3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention

4. Context-Aware Bifurcated Attention and 4.1. Motivation

4.2. Formulation and 4.3. Memory IO Complexity

5. Experiments

5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention

5.2. Latencies of Capabilities-Equivalent Models

5.3. Applications

6. Conclusion and References

A. FAQs

B. Related Work

C. Setup

D. Multi-Group Attention Family

E. Context-Aware Bifurcated Attention

F. Applications: Additional Results

G. Compatibility with Speculative Decoding and Fast Decoding techniques

C.1. Model Training Details

We trained multiple models ranging from 125 million to 13 billion parameters on code data with a context size of 2048, adjusting the per-GPU batch size and the total number of steps according to model size. Training used multiple p4 instances, each equipped with eight 40GB NVIDIA A100 GPUs.

For our largest model family, the 13-billion-parameter model, we used a global batch size of 1024, which translates to approximately 2 million tokens per batch. Settings were kept consistent across all models within each model-size family. The remaining training hyperparameters are summarized in Table 1.
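The batch-size arithmetic can be checked with a short sketch. It assumes every sequence fills the full 2048-token context. The gradient-accumulation helper and its per-GPU batch size and GPU count are hypothetical, included only to illustrate how a fixed global batch can be kept while the per-GPU batch size varies; the paper does not report those values.

```python
def tokens_per_batch(global_batch_size: int, context_length: int) -> int:
    """Tokens per optimizer step, assuming every sequence is full length."""
    return global_batch_size * context_length


def accumulation_steps(global_batch_size: int, per_gpu_batch: int, num_gpus: int) -> int:
    """Hypothetical helper: micro-batches each GPU accumulates per optimizer step."""
    micro = per_gpu_batch * num_gpus
    if global_batch_size % micro != 0:
        raise ValueError("global batch size must be a multiple of per_gpu_batch * num_gpus")
    return global_batch_size // micro


print(tokens_per_batch(1024, 2048))  # 2097152, i.e. roughly 2 million tokens
# Hypothetical setup: 8 p4 instances x 8 GPUs = 64 GPUs, 4 sequences per GPU.
print(accumulation_steps(1024, per_gpu_batch=4, num_gpus=64))  # 4
```

Halving the per-GPU batch size in this sketch simply doubles the accumulation steps, leaving the tokens per optimizer step unchanged.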

C.2. Model Configurations

For each model size, we train three models with different attention variants: multi-head (g = h), multi-group (1 < g < h), and multi-query (g = 1). Additionally, for the 672m and 2.8b models, we train a multi-group variant in which the fan-out of the feedforward layer is reduced from 4×d to 2×d. Because each variant has a different total number of parameters, we group these models into model-size families. The detailed architectural choices for each model family are listed in Table 2.
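To make the role of g concrete, the following is a minimal NumPy sketch of one incremental-decoding step of multi-group attention. It assumes the h query heads are split into g equal groups and that consecutive query heads share one key/value head; the function name, tensor layout, and grouping order are illustrative assumptions, not the authors' implementation. Setting g = h recovers multi-head attention, and g = 1 recovers multi-query attention.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_group_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """One decoding step of multi-group attention (illustrative sketch).

    q:    (b, h, d_head)     one query per head for the current token
    k, v: (b, g, m, d_head)  cached keys/values: g heads over m positions
    Returns (b, h, d_head).
    """
    b, h, d_head = q.shape
    g = k.shape[1]
    assert h % g == 0, "number of heads must be divisible by number of groups"
    # Group consecutive query heads: head j uses key/value head j // (h // g).
    q = q.reshape(b, g, h // g, d_head)
    logits = np.einsum("bgpk,bgmk->bgpm", q, k) / np.sqrt(d_head)
    weights = softmax(logits, axis=-1)
    out = np.einsum("bgpm,bgmk->bgpk", weights, v)
    return out.reshape(b, h, d_head)
```

In this sketch the key/value cache has shape (b, g, m, d_head), so its size, and hence the memory read per decoding step, scales with g rather than h.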

C.3. Model Details of 1B Latency Experiment

In Section 5.2.2, we use candidate models of roughly 1B parameters to study the effect of bifurcated attention. Their hyperparameters are outlined below.

C.4. Ablation Studies: 2d Intermediate Feature Dimension

One could also argue that different values of g shift the balance between the number of parameters in the feedforward and attention components. To test this, we performed an ablation study in which we reduce the typical intermediate feature size from 4d to 2d and train models at three model sizes (which we refer to as the 2d experiment). The ablation reveals that the scaling-law curves for the 2d experiment cross the usual 4d curves (Figure 9), which implies that the reduced size of the attention component relative to the feedforward component does not, on its own, consistently explain model capabilities.
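The parameter balance in question can be estimated per layer with a rough sketch. It counts only projection weights (no biases, layer norms, or embeddings) and uses a hypothetical width d = 1024 with h = 16 heads rather than any configuration from Table 2. Lowering g shrinks only the key and value projections, whereas moving from a 4d to a 2d intermediate size halves the feedforward weights, so in this estimate the 2d setting raises the attention share for every g.

```python
def param_split(d_model: int, n_heads: int, n_groups: int, ffn_mult: int) -> tuple[int, int]:
    """Approximate per-layer weight counts as (attention, feedforward).

    Ignores biases, layer norms, and embeddings.
    """
    head_dim = d_model // n_heads
    # Query and output projections are d x d regardless of g;
    # key and value projections each map d -> g * head_dim.
    attn = 2 * d_model * d_model + 2 * d_model * n_groups * head_dim
    # Feedforward: up-projection d -> ffn_mult * d, then down-projection back to d.
    ffn = 2 * d_model * ffn_mult * d_model
    return attn, ffn


def attention_fraction(d_model: int, n_heads: int, n_groups: int, ffn_mult: int) -> float:
    attn, ffn = param_split(d_model, n_heads, n_groups, ffn_mult)
    return attn / (attn + ffn)


# Hypothetical layer: d = 1024, h = 16; compare 4d and 2d feedforward widths.
for g in (16, 4, 1):
    for mult in (4, 2):
        frac = attention_fraction(1024, 16, g, mult)
        print(f"g={g:2d}, {mult}d feedforward: attention share {frac:.3f}")
```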

C.5. Inference Setup

We use NVIDIA A100 GPUs as inference hardware (Choquette et al., 2021). We perform latency studies using DeepSpeed inference (Rasley et al., 2020) on top of Hugging Face Transformers (Wolf et al., 2019), for which we wrote custom code to handle the generalized multi-group attention as well as bifurcated attention. Future work includes extending the implementation to FasterTransformer (NVIDIA).
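The bifurcated attention handled by this custom code can be illustrated with a minimal NumPy sketch. It assumes a single shared context (prefix) whose keys and values are stored once, without a batch axis, while each of the b samples keeps its own cache of decoded tokens. Logits are computed separately against the two caches and concatenated before a single softmax, so the result equals ordinary multi-group attention over the concatenated cache. The names and shapes are illustrative assumptions; this is not the DeepSpeed-based code used for the latency studies.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def bifurcated_attention(q, k_ctx, v_ctx, k_dec, v_dec):
    """One decoding step with the KV cache split into context and decoded parts.

    q:            (b, h, d_head)       current-step queries for b samples
    k_ctx, v_ctx: (g, m_c, d_head)     shared-prefix keys/values, stored once
    k_dec, v_dec: (b, g, m_d, d_head)  per-sample keys/values of decoded tokens
    Returns (b, h, d_head).
    """
    b, h, d_head = q.shape
    g, m_c, _ = k_ctx.shape
    q = q.reshape(b, g, h // g, d_head)  # consecutive query heads share a KV head
    scale = 1.0 / np.sqrt(d_head)
    # Context part: the same k_ctx is used by every sample (no batch axis).
    logits_ctx = np.einsum("bgpk,gmk->bgpm", q, k_ctx) * scale
    # Decoding part: each sample attends to its own generated tokens.
    logits_dec = np.einsum("bgpk,bgmk->bgpm", q, k_dec) * scale
    # One softmax over all positions keeps the result exact.
    weights = softmax(np.concatenate([logits_ctx, logits_dec], axis=-1), axis=-1)
    w_ctx, w_dec = weights[..., :m_c], weights[..., m_c:]
    out = (np.einsum("bgpm,gmk->bgpk", w_ctx, v_ctx)
           + np.einsum("bgpm,bgmk->bgpk", w_dec, v_dec))
    return out.reshape(b, h, d_head)
```

Because k_ctx and v_ctx carry no batch axis, the context cache is read once per step regardless of b in this sketch, while the decoded part still grows with b.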
