Exploration of model architecture optimizations for Large Language Model (LLM) inference, focusing on Group Query Attention (GQA) and Mixture of Experts (MoE) techniques.
Posts in this series:
Primer on Large Language Model (LLM) Inference Optimizations: 1. Background and Problem Formulation
Primer on Large Language Model (LLM) Inference Optimizations: 3. Model Architecture Optimizations (this post)
In our journey exploring LLM inference optimization, we’ve covered significant ground in the previous two posts. Our first post introduced the fundamentals of LLM inference, detailing the transformer architecture and its key components like attention mechanisms and K-V caching.
We learned about the two critical stages of inference: the prefill stage and the decoding stage, highlighting how these impact performance and resource utilization. The prefill stage has a complexity of O(L·d²) and the decoding stage has a complexity of O(n·d²), where L is the prompt (sequence) length, n is the number of generated tokens, and d is the model dimension (ignoring the complexities of softmax and other operations for simplicity). We also covered metrics to evaluate LLM inference performance, such as time to first token (TTFT), decoding throughput (tokens per second), end-to-end latency, and maximum request rate (aka QPS). I highly recommend reading the first post to get a solid understanding of the challenges and complexities involved in LLM inference.
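As a rough illustration of what these asymptotics mean, the sketch below simply plugs hypothetical values for L, n, and d into the two formulas; it ignores attention over the K-V cache, softmax, and all constant factors, exactly as the formulas above do.

```python
# Back-of-envelope illustration of the prefill vs. decoding cost model above.
# The dimensions below (d, L, n) are made up purely for illustration.

d = 4096   # model (hidden) dimension
L = 2048   # prompt length processed in the prefill stage
n = 256    # number of tokens produced in the decoding stage

prefill_cost = L * d ** 2   # O(L·d^2): all prompt tokens processed in parallel
decode_cost = n * d ** 2    # O(n·d^2): one token at a time, n sequential steps

print(f"prefill ~ {prefill_cost:.2e} ops, decode ~ {decode_cost:.2e} ops")
print(f"prefill/decode ratio ~ {prefill_cost / decode_cost:.1f}x")
```

Although the prefill stage does more total work here, it processes the whole prompt in parallel, whereas the decoding stage spends its n·d² budget one token at a time, which is why the two stages stress latency and throughput so differently.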
Our second post focused on hardware acceleration, introducing AI accelerators as specialized solutions for efficient LLM inference. We explored various types of accelerators (GPUs, TPUs, FPGAs/ASICs) and their key features like parallel processing capabilities, high-bandwidth memory, and support for low-precision arithmetic.
We also discussed different parallelism strategies - data, model (tensor and pipeline), and task parallelism - that enable efficient scaling of LLM workloads.
Building on this foundation, this post will explore model architecture optimizations that can significantly improve inference efficiency. We’ll focus on techniques that modify the transformer architecture to reduce computational complexity and memory requirements while maintaining model performance. Note that these optimizations, as the name suggests, require changes in the model architecture and hence, need to be implemented prior to training the model.
Group Query Attention (GQA) is a model architecture optimization that reduces memory and computational costs during inference while maintaining model quality. As we know, the standard transformer architecture uses Multi-Head Attention (MHA) to capture complex patterns in the data.
However, MHA can be memory-intensive, especially when dealing with large models and sequences. The K-V cache grows linearly with the number of heads, leading to significant memory requirements.
In standard Multi-Head Attention (MHA) (discussed in detail in post 1), each attention head has its own Query (Q), Key (K), and Value (V) projections, so the K-V cache must hold a separate K and V entry for every head.
What if we use a single K and V head shared across all query heads? This is called Multi-Query Attention (MQA) [1]. Because K and V are computed, and cached, only once rather than once per head, it is noticeably cheaper than Multi-Head Attention. However, with only one shared K and V head, MQA is limited in its ability to capture diverse attention patterns and can degrade model quality.
Grouped-Query Attention (GQA) [2] strikes a balance between MHA and MQA by sharing a single K and V head across a small group of query heads. It reduces memory requirements and computational cost while largely maintaining model quality. For example, with a group size of two, each pair of query heads attends using the same K and V head.
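To make the three variants concrete, here is a minimal PyTorch sketch of a grouped-query attention layer; the module and variable names are my own and the dimensions are illustrative. Setting num_kv_heads equal to num_heads recovers MHA, setting it to 1 recovers MQA, and values in between give GQA.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Minimal GQA sketch: num_kv_heads == num_heads -> MHA, == 1 -> MQA."""

    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Each K/V head serves a group of (num_heads // num_kv_heads) query heads.
        group_size = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, -1))

x = torch.randn(1, 16, 512)                                   # (batch, seq_len, d_model)
gqa = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=2)
print(gqa(x).shape)                                           # torch.Size([1, 16, 512])
```

In a real serving stack, only the num_kv_heads K and V tensors would be written to the K-V cache, and the expansion to the full head count would happen inside a fused attention kernel instead of the explicit repeat_interleave used here for clarity.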
For a model with h query heads, head dimension d_head, sequence length L, a group size of 4 (i.e., h/4 K-V heads), and 2-byte (fp16) cache entries:

MHA Memory per Layer: 2 · L · h · d_head · 2 bytes (one K and one V vector per query head)

GQA Memory per Layer: 2 · L · (h/4) · d_head · 2 bytes (one K and one V vector per group)

This represents a 4x reduction in K-V cache memory requirements per layer.
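As a quick sanity check on these formulas, the sketch below plugs in a hypothetical configuration (32 layers, 32 query heads, 8 K-V heads, head dimension 128, fp16 cache entries, a 4096-token sequence); the numbers are illustrative rather than taken from any particular model.

```python
# Illustrative K-V cache sizing for a hypothetical model configuration.

def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # Factor of 2 for K and V; fp16/bf16 entries are 2 bytes each.
    return 2 * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem

seq_len, n_layers, head_dim = 4096, 32, 128
mha = kv_cache_bytes(seq_len, n_layers, n_kv_heads=32, head_dim=head_dim)  # one K/V head per query head
gqa = kv_cache_bytes(seq_len, n_layers, n_kv_heads=8, head_dim=head_dim)   # groups of 4 query heads

print(f"MHA K-V cache: {mha / 2**30:.2f} GiB per sequence")
print(f"GQA K-V cache: {gqa / 2**30:.2f} GiB per sequence")
print(f"reduction: {mha / gqa:.0f}x")
```

Per sequence, the hypothetical MHA configuration needs about 2 GiB of K-V cache while the GQA variant needs about 0.5 GiB, which directly translates into more concurrent requests fitting on a single accelerator.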
As we can see, GQA offers significant memory savings compared to standard MHA, making it an attractive optimization for LLM inference. What makes GQA particularly attractive is that these memory savings come with minimal trade-offs:
Performance Preservation: While using fewer parameters than standard multi-head attention, GQA maintains nearly equivalent model quality. This is because the separate Query projections per head still allow the model to capture diverse attention patterns.
Implementation Advantages: Compared to its cousin Multi-Query Attention (MQA), GQA offers a more practical middle ground. It’s simpler to implement while providing better control over the efficiency-performance trade-off through adjustable group sizes.
Hardware Efficiency: GQA's architecture naturally aligns with modern AI accelerators: the decoding stage is typically memory-bandwidth-bound, so a smaller K-V cache means less data streamed from high-bandwidth memory per generated token, and the freed memory allows larger batch sizes that keep the accelerator's compute units busy.
While GQA offers significant benefits, it's essential to consider a few challenges: the number of K-V groups is an extra hyperparameter that trades efficiency against quality (pushing it all the way toward MQA can hurt accuracy), and because the head layout is baked into the architecture, it must be chosen before pre-training, or an existing MHA checkpoint must be converted and "uptrained" as described in [2].
Mixture of Experts (MoE)[3] represents one of the most innovative architectural optimizations in modern LLMs. Instead of having a single massive neural network, MoE divides the model into multiple specialized sub-networks or “experts,” each focusing on different aspects of language processing.
At its core, MoE transforms the traditional dense Feed-Forward Network (FFN) layer in transformers into a more dynamic and efficient system. It consists of three key components:
Expert Networks: The model contains multiple specialized FFN networks (experts), each potentially good at different tasks (e.g., reasoning, coding, or creative writing - note that these are just learned behaviors during training, not explicitly assigned). Each expert is smaller and more focused, making it easier to train and optimize.
Gating Network: A smart “traffic controller” that decides which experts to consult for each input token or segment. It uses the input features to determine the relevance of each expert for the current context. This gating mechanism is trainable and adapts to the input data, directing the token to the most relevant expert. [4]
Sparse Activation: Only a small number (typically 1-2) of experts are activated for each token. This sparse activation significantly reduces the computational (and activation-memory) cost, as only a fraction of the model is exercised for each inference step (a minimal sketch of such a layer follows the figure below).
Figure: MoE architecture showing the gating network directing input to relevant experts
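The following sketch ties the three components together in PyTorch: a learned gate scores the experts, the top-k are selected per token, and only those experts run. It is a deliberately simplified illustration (no load-balancing loss, no expert-capacity limits, and a plain Python loop instead of a fused dispatch kernel), and all names and sizes are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    """Minimal sketch of a sparsely activated MoE feed-forward layer."""

    def __init__(self, d_model: int, d_ff: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts, bias=False)     # gating network
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        tokens = x.reshape(-1, D)                                   # (B*T, D)

        # 1) Gating network: score every expert, keep only the top-k per token.
        scores = self.gate(tokens)                                  # (B*T, num_experts)
        top_w, top_idx = scores.topk(self.top_k, dim=-1)
        top_w = F.softmax(top_w, dim=-1)                            # renormalize over chosen experts

        # 2) Sparse activation: each expert processes only the tokens routed to it.
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            token_ids, slot = (top_idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue                                            # no tokens routed to this expert
            out[token_ids] += top_w[token_ids, slot].unsqueeze(-1) * expert(tokens[token_ids])

        return out.reshape(B, T, D)

x = torch.randn(2, 16, 512)
moe = MoELayer(d_model=512, d_ff=2048, num_experts=8, top_k=2)
print(moe(x).shape)   # torch.Size([2, 16, 512])
```

Production implementations replace the Python loop with batched expert dispatch and add an auxiliary load-balancing loss during training so that tokens do not collapse onto a few favorite experts.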
MoE offers several compelling advantages for inference optimization:
Only 1-2 experts are activated per token instead of utilizing the entire network, leading to substantial savings in computational cost.
MoE can reduce FLOPs by up to 5-10x compared to a dense model with equivalent total capacity.
The reduced per-token compute and activation footprint translate into lower latency and higher throughput, making MoE highly efficient for serving large models (the full set of expert weights still has to be stored, but only the routed experts are computed per token; a rough back-of-envelope sketch follows this list).
Different experts can implicitly specialize in various domains or tasks, enabling the model to handle diverse queries more effectively.
MoE helps mitigate issues like catastrophic forgetting in multi-task settings, as separate experts can learn distinct patterns without interfering with each other.
This specialization often results in improved quality for domain-specific queries.
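To see where the computational-efficiency claims above come from, the following back-of-envelope sketch compares total versus per-token "active" feed-forward parameters for a dense layer and a hypothetical 8-expert, top-2 MoE layer (all sizes are made up for illustration).

```python
# Total vs. per-token "active" FFN parameters: dense layer vs. top-2 MoE layer.
# The layer sizes and expert count below are hypothetical.

d_model, d_ff = 4096, 14336
num_experts, top_k = 8, 2

ffn_params = 2 * d_model * d_ff         # up-projection + down-projection weights

dense_total = num_experts * ffn_params  # a dense layer with the same total capacity
dense_active = dense_total              # dense: every parameter touches every token

moe_total = num_experts * ffn_params    # all experts must still be stored
moe_active = top_k * ffn_params         # but only top_k experts run per token

print(f"dense FFN: {dense_total/1e6:.0f}M params, {dense_active/1e6:.0f}M active per token")
print(f"MoE   FFN: {moe_total/1e6:.0f}M params, {moe_active/1e6:.0f}M active per token")
print(f"FFN FLOPs per token vs. equal-capacity dense layer: {dense_active/moe_active:.0f}x fewer")
```

With 8 experts and top-2 routing this already yields a 4x reduction in FFN compute per token relative to a dense layer of the same total capacity; more experts and/or top-1 routing push the gap toward the 5-10x range mentioned above.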
MoE has proven its worth in practice: models such as Google's Switch Transformer and Mistral AI's Mixtral 8x7B use sparse expert layers to reach very large total parameter counts while keeping per-token inference cost close to that of a much smaller dense model.
In this post, we explored two key model architecture optimizations - Grouped Query Attention (GQA) and Mixture of Experts (MoE) - that can significantly enhance the efficiency and performance of Large Language Model (LLM) inference. GQA offers a practical way to reduce memory and computational costs while maintaining model quality, making it an attractive optimization for LLMs.
On the other hand, MoE introduces a novel approach to model specialization and efficiency, enabling large models to scale efficiently and handle diverse tasks effectively.
In the next post, we will delve into system-level optimizations that complement these model-level techniques, providing a holistic view of LLM inference optimization strategies.