
How to Speed Up Your AI Models—Without Frying Your Memory

by Batching, February 24th, 2025

Too Long; Didn't Read

Bifurcated attention is a novel method for optimizing large language model inference by reducing memory IO costs. It divides the attention mechanism into two GEMM operations—handling prefill KV cache separately from decoding. This approach maintains computational efficiency while lowering latency, enabling larger batch sizes and real-time AI applications.


Authors:

(1) Ben Athiwaratkun, AWS AI Labs;

(2) Sujan Kumar Gonugondla, AWS AI Labs;

(3) Sanjay Krishna Gouda, AWS AI Labs;

(4) Haifeng Qian, AWS AI Labs;

(5) Hantian Ding, AWS AI Labs;

(6) Qing Sun, AWS AI Labs;

(7) Jun Wang, AWS AI Labs;

(8) Jiacheng Guo, AWS AI Labs;

(9) Liangfu Chen, AWS AI Labs;

(10) Parminder Bhatia, GE HealthCare (work done at AWS);

(11) Ramesh Nallapati, Amazon AGI (work done at AWS);

(12) Sudipta Sengupta, AWS AI Labs;

(13) Bing Xiang, Goldman Sachs (work done at AWS).

Abstract and 1 Introduction

2. Related Work

3. Background

3.1. Notation and 3.2. Language Model Inference

3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention

4. Context-Aware Bifurcated Attention and 4.1. Motivation

4.2. Formulation and 4.3. Memory IO Complexity

5. Experiments

5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention

5.2. Latencies of Capabilities-Equivalent Models

5.3. Applications

6. Conclusion and References


A. FAQs

B. Related Work

C. Setup

D. Multi-Group Attention Family

E. Context-Aware Bifurcated Attention

F. Applications: Additional Results

G. Compatibility with Speculative Decoding and Fast Decoding techniques

Abstract

In our study, we present bifurcated attention, a method developed for language model inference in single-context batch sampling scenarios. This approach aims to reduce redundant memory IO costs, a significant factor in latency at high batch sizes and long context lengths. Bifurcated attention achieves this by dividing the attention mechanism during incremental decoding into two distinct GEMM operations: one over the KV cache from the prefill phase and one over the KV cache produced during decoding. The method yields exactly the same results and the same computational load (FLOPs) as standard attention, but with reduced memory IO. Bifurcated attention is also compatible with the multi-query attention mechanism, which is known for its reduced KV cache memory IO, further enabling higher batch sizes and longer context lengths. The resulting efficiency leads to lower latency and better suitability for real-time applications, e.g., enabling massively parallel answer generation without substantially increasing latency, which enhances performance when combined with postprocessing techniques such as reranking.

1. Introduction

The advent of large language models (LLMs) has ushered in a new era of machine learning, exhibiting remarkable performance on a wide array of tasks (Brown et al., 2020; OpenAI, 2023; Chowdhery et al., 2022; Touvron et al., 2023; Chen et al., 2021; Hoffmann et al., 2022; Li et al., 2022; Microsoft; Amazon, 2022; Nijkamp et al., 2023). Despite their impressive capabilities, the deployment of these large-scale models in practical applications poses significant challenges, particularly in terms of inference latency and efficiency. Enhancing these aspects is critical, as they directly influence the computational resources required to generate predictions and enable the practical implementation of these advanced models across various industries.


A particularly demanding inference scenario is single-context batch sampling, where the goal is to generate multiple completions from a single context. This task is commonly encountered in numerous applications, such as code-editing IDE tools that provide multiple recommendations, or in cases where ranking among many generations is needed for optimal performance (via ranking metrics like mean log probability, majority voting, etc.). The incremental decoding in such a sampling scenario is memory IO intensive, which becomes a latency bottleneck at high batch sizes and long context lengths.
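To make the bottleneck concrete, here is a back-of-the-envelope sketch in Python. The helper function and the model shapes (34 layers, 24 heads of dimension 256, fp16) are our own illustrative assumptions, not figures from the paper; the point is simply that the KV cache read per decoding step grows with batch size times context length when every sample carries its own copy of the context.

```python
# Rough KV cache memory-IO estimate for one incremental decoding step.
# All concrete numbers below are illustrative assumptions, not values from the paper.

def kv_cache_bytes_per_step(batch, ctx_len, gen_len, n_layers, n_kv_heads,
                            head_dim, dtype_bytes=2):
    """Bytes of KV cache read in one decoding step (keys + values, all layers)."""
    tokens_per_sample = ctx_len + gen_len
    return 2 * batch * tokens_per_sample * n_layers * n_kv_heads * head_dim * dtype_bytes

# Hypothetical 16B-class multi-head model: 34 layers, 24 heads of dim 256, fp16.
naive = kv_cache_bytes_per_step(batch=128, ctx_len=2048, gen_len=256,
                                n_layers=34, n_kv_heads=24, head_dim=256)
print(f"per-step KV cache IO: {naive / 1e9:.1f} GB")  # re-read for every generated token
```

With these assumed shapes, each generated token forces hundreds of gigabytes of KV cache reads, which is why memory IO rather than FLOPs dominates latency in this regime.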


In this study, we investigate two compatible strategies to address the memory IO challenges in transformer inference: (1) an analysis of multi-query attention and its trade-offs, and (2) a novel technique called context-aware bifurcated attention.


Our investigation begins with an analysis of the generalized multi-query attention (Ainslie et al., 2023), which includes multi-query attention (Shazeer, 2019) as well as the established multi-head attention mechanism (Vaswani et al., 2017), in terms of the performance and latency trade-off. Our findings show smooth performance scaling with increasing model size for a fixed number of attention groups g in generalized multi-query attention[1]. Lowering g shifts the validation-loss-versus-model-size scaling curve upward. The consistent relationship between cache compression, model size, and validation loss allows us to trade off inference efficiency against model size, i.e., it enables us to select higher compression for use cases requiring high efficiency while still matching the performance of multi-head attention by compensating with a larger model size.
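For reference, the following minimal PyTorch-style sketch (our own illustration; the function name and tensor shapes are assumptions, not code from the paper) shows how the number of KV groups g interpolates between multi-head (g = h) and multi-query (g = 1) attention, and why smaller g means a smaller KV cache to read.

```python
import torch

def generalized_multi_query_attention(q, k, v, g):
    """Minimal grouped (generalized multi-query) attention for one decoding step.

    q:    [batch, h, 1, d]    query for the current token, h query heads
    k, v: [batch, g, seq, d]  KV cache kept with only g heads (g must divide h)
    g = h -> multi-head attention; g = 1 -> multi-query attention;
    1 < g < h -> the intermediate multi-group settings.
    """
    b, h, _, d = q.shape
    # Each group of h // g query heads shares one KV head.
    k = k.repeat_interleave(h // g, dim=1)           # [b, h, seq, d]
    v = v.repeat_interleave(h // g, dim=1)
    scores = (q @ k.transpose(-1, -2)) / d ** 0.5    # [b, h, 1, seq]
    return torch.softmax(scores, dim=-1) @ v         # [b, h, 1, d]

# The stored KV cache scales with g, so g = 1 reads h times fewer bytes than g = h.
```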


Second, we introduce context-aware bifurcated attention, a technique that bifurcates any attention in the generalized multi-query family into context and decoding components during incremental decoding. The bifurcation involves the same number of FLOPs and yields identical results compared to the original attention, but can significantly reduce memory IO cost and thus latency in high-batch, long-context scenarios. This approach allows the generation of multiple real-time completions without incurring much additional latency, or enables much higher batch sizes, leading to improved ranking performance. For instance, for the CodeGen 16B multi-head model (Nijkamp et al., 2022) with a 2k context length, we are able to increase the batch size to 128 with bifurcated attention, compared to a batch size of only 5 without it, resulting in pass@k (Chen et al., 2021) increasing from 59.0% to 84.6%, or pass@top3 via mean log-p increasing from 55.2% to 58.1%.
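To make the bifurcation concrete, here is a minimal PyTorch sketch (our own illustration under assumed tensor shapes, not the paper's reference kernel): the attention scores over the single shared prefix KV cache and over the per-sample decoded KV cache are computed as two separate GEMM-like operations and combined under one softmax, so the result matches standard attention while the prefix is read from memory only once rather than once per batch element.

```python
import torch

def bifurcated_attention(q, k_prefix, v_prefix, k_dec, v_dec):
    """One decoding step of single-context batch sampling (illustrative sketch).

    q:                   [batch, h, 1, d]  current-token queries
    k_prefix, v_prefix:  [h, m, d]         shared prefill KV cache, stored once
    k_dec, v_dec:        [batch, h, t, d]  per-sample KV of already generated tokens
    """
    d = q.shape[-1]
    # GEMM 1: scores against the single shared prefix (no batch copy of the context).
    s_prefix = torch.einsum("bhqd,hmd->bhqm", q, k_prefix)
    # GEMM 2: scores against each sample's own decoded tokens.
    s_dec = torch.einsum("bhqd,bhtd->bhqt", q, k_dec)
    # One softmax over the concatenated context, exactly as in standard attention.
    w = torch.softmax(torch.cat([s_prefix, s_dec], dim=-1) / d ** 0.5, dim=-1)
    w_prefix, w_dec = w.split([k_prefix.shape[1], k_dec.shape[2]], dim=-1)
    # Weighted values, again reading the prefix values only once.
    out = torch.einsum("bhqm,hmd->bhqd", w_prefix, v_prefix) \
        + torch.einsum("bhqt,bhtd->bhqd", w_dec, v_dec)
    return out  # mathematically identical to attention over the full concatenated KV
```

Because the prefix tensors carry no batch dimension, their memory IO is independent of the number of sampled completions, which is what allows the batch size to grow without a proportional increase in latency.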


This paper is available on arXiv under a CC BY 4.0 DEED license.


[1] Lower values of attention groups g lead to higher compression of the key-value tensors, as in the multi-query case where g = 1, hence improving inference efficiency and latency due to reduced KV cache compared to the multi-head case where g = h, the number of query attention heads.