Understanding Multi-Group Attention in AI Models

Written by batching | Published 2025/02/26
Tech Story Tags: ai-code-generation | ai-inference | bifurcated-attention | memory-io-optimization | low-latency-ai | llm-batch-sampling | transformer-model-efficiency | multi-query-attention

TLDR: Multi-group attention optimizes AI model efficiency by reducing memory IO costs. FLOPs remain proportional to parameters, ensuring scalability across architectures.

Authors:

(1) Ben Athiwaratkun, AWS AI Labs;

(2) Sujan Kumar Gonugondla, AWS AI Labs;

(3) Sanjay Krishna Gouda, AWS AI Labs;

(4) Haifeng Qian, AWS AI Labs;

(5) Hantian Ding, AWS AI Labs;

(6) Qing Sun, AWS AI Labs;

(7) Jun Wang, AWS AI Labs;

(8) Jiacheng Guo, AWS AI Labs;

(9) Liangfu Chen, AWS AI Labs;

(10) Parminder Bhatia, GE HealthCare (work done at AWS);

(11) Ramesh Nallapati, Amazon AGI (work done at AWS);

(12) Sudipta Sengupta, AWS AI Labs;

(13) Bing Xiang, Goldman Sachs (work done at AWS).

Table of Links

Abstract and 1 Introduction

2. Related Work

3. Background

3.1. Notation and 3.2. Language Model Inference

3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention

4. Context-Aware Bifurcated Attention and 4.1. Motivation

4.2. Formulation and 4.3. Memory IO Complexity

5. Experiments

5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention

5.2. Latencies of Capabilities-Equivalent Models

5.3. Applications

6. Conclusion and References

A. FAQs

B. Related Work

C. Setup

D. Multi-Group Attention Family

E. Context-Aware Bifurcated Attention

F. Applications: Additional Results

G. Compatibility with Speculative Decoding and Fast Decoding techniques

D. Multi-Group Attention Family

D.1. Detailed Analysis on Memory Access

We show in Table 4 that the memory IO cost for ⟨q, K⟩ is dominated by the loading of K, which costs bmhk in the multi-head case where g = h. This cost is particularly high due to the coupling of batch size b, context length m, and the entire hidden dimension d. Compared to the number of computations, which has complexity bmd, this attention module requires one memory IO per tensor operation (memory-IO bound). In contrast, other operations such as the feedforward layers have a much lower ratio of memory IO per compute (compute bound). These attention computations can therefore be the main bottleneck for incremental decoding, and our paper aims to tackle this problem.
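To make the bound concrete, here is a back-of-the-envelope sketch (not from the paper; the cost model and numeric values are illustrative assumptions) that counts the K-cache elements loaded versus the dot-product operations performed in the ⟨q, K⟩ step of one decoding iteration, using the paper's notation b, m, h, k, g:

```python
# Illustrative sketch: memory IO vs. compute for the <q, K> step of one
# incremental-decoding iteration. Notation follows the paper:
# b = batch size, m = context length, h = query heads, k = head dimension,
# g = number of key/value groups (g = h is multi-head, g = 1 is multi-query).

def attention_qk_cost(b, m, h, k, g):
    d = h * k                 # hidden dimension
    kv_io = b * m * g * k     # elements of the K cache loaded from memory
    ops = b * m * d           # dot-product operations between q and the cached K
    return kv_io, ops

# Hypothetical sizes, chosen only for illustration.
for name, g in [("multi-head ", 32), ("multi-query", 1)]:
    io, ops = attention_qk_cost(b=32, m=2048, h=32, k=128, g=g)
    print(f"{name}: K-cache IO = {io:.3e}, ops = {ops:.3e}, ops per IO = {ops / io:.0f}")
```

With g = h the two quantities coincide, i.e., roughly one memory access per operation, which is exactly the memory-IO-bound regime described above; shrinking g raises the arithmetic intensity of this step.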

D.2. Model FLOPs

The scaling laws of Kaplan et al. (2020) show that the model-related FLOPs during the forward pass are 2N, where N is the number of parameters (excluding the embeddings). We show that this holds for a general multi-group model as well. The only difference between the multi-group and the multi-head case is the projections PK and PV, which are of size dgk instead of dhk. Since these are linear layers, the forward-pass FLOPs for any input remain proportional to the projection size. Therefore, it follows that for any multi-group attention, including multi-head, the forward FLOPs are 2N, where N is the respective number of parameters.
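As a rough illustration of this point (a sketch with assumed layer shapes, not the paper's implementation), the per-layer parameter count below changes only through the dgk-sized key/value projections, and the forward FLOPs per token track 2N in every case:

```python
# Sketch with assumed layer shapes: parameters of one multi-group transformer
# layer and the resulting forward FLOPs per token (~2N, Kaplan et al., 2020).
# d = hidden dim, h = query heads, k = head dim, g = K/V groups, d_ff = FFN dim.

def layer_params(d, h, k, g, d_ff):
    p_q = d * h * k        # query projection
    p_kv = 2 * d * g * k   # key and value projections (d*g*k each)
    p_out = h * k * d      # attention output projection
    p_ffn = 2 * d * d_ff   # feed-forward up and down projections
    return p_q + p_kv + p_out + p_ffn

d, h, k, d_ff = 4096, 32, 128, 16384   # hypothetical sizes
for name, g in [("multi-head ", h), ("multi-query", 1)]:
    n = layer_params(d, h, k, g, d_ff)
    print(f"{name}: N = {n:,} params/layer, forward FLOPs/token ~ {2 * n:,}")
```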

D.3. Comparing Capabilities-Equivalent Models

This section analyzes how latency changes when we switch from a multi-head (MH) model to a multi-group (MG) model that is F times larger.
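As a rough preview of the kind of comparison developed in the subsections below (an assumed cost model with hypothetical sizes, not the paper's derivation), the per-step memory IO during incremental decoding can be split into weight loading, which grows by the factor F, and KV-cache loading, which shrinks with g:

```python
# Assumed cost model, for illustration only: per-step memory IO during
# incremental decoding = weights read once per step + K/V caches read for attention.

def decode_step_io(n_params, b, m, g, k):
    weight_io = n_params             # every parameter is loaded once per step
    kv_cache_io = 2 * b * m * g * k  # K and V caches for b sequences of length m
    return weight_io + kv_cache_io

n_mh = 7e9   # hypothetical MH model size
F = 1.1      # MG model is F times larger
b, m, h, k = 32, 4096, 32, 128

io_mh = decode_step_io(n_mh, b, m, g=h, k=k)       # MH baseline (g = h)
io_mg = decode_step_io(F * n_mh, b, m, g=1, k=k)   # larger multi-query model (g = 1)

# Ratio below 1 means the KV-cache savings outweigh the extra weight IO.
print(f"per-step IO ratio MG/MH = {io_mg / io_mh:.2f}")
```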

D.3.1. CONTEXT ENCODING

D.3.2. INCREMENTAL DECODING

This paper is available on arXiv under a CC BY 4.0 DEED license.

