FlashDecoding++: Faster Large Language Model Inference on GPUs: Heuristic Dataflow with Hardware, by @textmodels

Too Long; Didn't Read

Due to the versatility of its optimizations, FlashDecoding++ achieves up to 4.86× and 2.18× speedup on NVIDIA and AMD GPUs, respectively, compared to Hugging Face.

This paper is available on arxiv under CC BY-NC-SA 4.0 DEED license.


(1) Ke Hong, Tsinghua University & Infinigence-AI;

(2) Guohao Dai, Shanghai Jiao Tong University & Infinigence-AI;

(3) Jiaming Xu, Shanghai Jiao Tong University & Infinigence-AI;

(4) Qiuli Mao, Tsinghua University & Infinigence-AI;

(5) Xiuhong Li, Peking University;

(6) Jun Liu, Shanghai Jiao Tong University & Infinigence-AI;

(7) Kangdi Chen, Infinigence-AI;

(8) Yuhan Dong, Tsinghua University;

(9) Yu Wang, Tsinghua University.

5 Heuristic Dataflow with Hardware Resource Adaption

Motivation. Although FlashDecoding++ optimizes the flat GEMM operation in Section 4, this optimization does not cover all operations in LLM inference, not even all GEMMs. As shown in Figure 2, GEMM shapes vary across operations and between the two phases. The GEMM workload in LLM inference can therefore be a GEMV (batch size = 1 in the decode phase), a flat GEMM (small batch size in the decode phase, or short sequence length in the prefill phase), or a conventional GEMM (large batch size or long sequence length in the prefill phase). To leverage the computational power of Tensor Cores, current frameworks such as FasterTransformer [33] and DeepSpeed [9] tend to use the highly optimized GEMM implementations from cuBLAS [24] for all of these workloads. However, the Tensor Core implementation performs poorly on GEMV workloads, which previous designs such as FastGEMV [34] optimize using CUDA Cores instead. For a Llama2-7B linear layer in the decode phase, the Tensor Core implementation from cuBLAS achieves only 82.15% of the performance of the CUDA Core implementation from FastGEMV on an NVIDIA A100 GPU. Conversely, using CUDA Cores for the projection on a decoding input with batch size = 4 achieves only 49.75% of the performance of the Tensor Core implementation. Therefore, to approach optimal computational performance, a heuristic dataflow should be exploited to select the implementation for each workload.
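The arithmetic intensity of C = A·B, with A of shape (M × K) and B of shape (K × N), helps explain why GEMV favors CUDA Cores. The intensity is roughly 2MKN FLOPs divided by the bytes needed to read A and B and write C. The sketch below is a simplified roofline-style estimate. It assumes FP16 operands and that each matrix is touched exactly once. It is not the analysis used by FlashDecoding++, and it does not by itself reproduce the percentages above. At M = 1 the intensity is about 1 FLOP per byte, so such a kernel would be bound by memory bandwidth and would leave most Tensor Core throughput unused.

```python
def arithmetic_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for C (M x N) = A (M x K) @ B (K x N).

    Idealized model: A and B are each read once and C is written once;
    bytes_per_elem = 2 assumes FP16 storage.
    """
    flops = 2 * m * k * n  # one multiply and one add per term of each dot product
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


if __name__ == "__main__":
    # K = N = 4096 matches the Llama2-7B O projection shape.
    for m in (1, 4, 8, 64, 1024):
        print(f"M = {m:4d}: {arithmetic_intensity(m, 4096, 4096):7.2f} FLOP/byte")
```

While M is much smaller than K and N, the estimate is close to M FLOP/byte. This fits the observation above that the best implementation changes as M grows.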

Challenge. Although an effective heuristic dataflow may exist across the implementations of different linear workloads, it is challenging to build the mapping from a given workload to its optimal implementation. In LLM inference, several factors influence the performance of linear workload implementations: (a) Input dynamics. Varying batch sizes and input sequence lengths produce dynamic workloads. (b) Model diversity. The linear workload varies with model structure and size. (c) GPU capacities. The relative performance of implementations changes with GPU characteristics such as memory bandwidth, cache size, and computational ability. (d) Engineering effects. Engineering effort also strongly affects kernel performance. Together, these factors form a large search space, making it non-trivial to map each linear workload to its optimal implementation.

Analysis and Insights. Although these factors form a large search space, the homogeneity of the layers in an LLM significantly reduces the search space for operator optimization. Figure 2 shows the four linear GEMV/GEMM operations in the prefill and decode phases: the K, Q, V projection, the O projection, and two feedforward operations. Each GEMV/GEMM operation can be abstracted as a multiplication between an (M × K)-shaped matrix and a (K × N)-shaped matrix. Our key insight is that there are only four [N, K] shapes for a given LLM. Moreover, M depends only on the input sequence length and the batch size in the prefill phase, and only on the batch size in the decode phase. Figure 9(a) shows the limited set of GEMV/GEMM shapes in LLM inference.
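The following sketch illustrates the shape abstraction. The four [N, K] pairs are the Llama2-7B shapes listed in the example at the end of this section. The operation names (`qkv_proj`, `o_proj`, `ffn1`, `ffn2`) and the helpers `gemm_m` and `gemm_shape` are hypothetical. The sketch treats the prefill M as batch size × sequence length. This assumes all prompt tokens are flattened into one activation matrix without padding, which matches the M = 8 case in that example.

```python
from typing import Dict, Tuple

# [N, K] weight shapes of the four linear operations in Llama2-7B,
# taken from the example at the end of this section (names are hypothetical).
LLAMA2_7B_SHAPES: Dict[str, Tuple[int, int]] = {
    "qkv_proj": (12288, 4096),  # K, Q, V projection (12288 = 3 x 4096)
    "o_proj": (4096, 4096),     # O projection
    "ffn1": (11008, 4096),      # first feedforward GEMM
    "ffn2": (4096, 11008),      # second feedforward GEMM
}


def gemm_m(phase: str, batch_size: int, seq_len: int = 1) -> int:
    """Number of rows M of the (M x K) activation matrix."""
    if phase == "prefill":
        # Assumption: all prompt tokens are flattened into one matrix, no padding.
        return batch_size * seq_len
    if phase == "decode":
        # One new token per sequence per decoding step.
        return batch_size
    raise ValueError(f"unknown phase: {phase!r}")


def gemm_shape(op: str, phase: str, batch_size: int, seq_len: int = 1) -> Tuple[int, int, int]:
    """Return (M, K, N) for one linear operation; only M depends on the input."""
    n, k = LLAMA2_7B_SHAPES[op]
    return gemm_m(phase, batch_size, seq_len), k, n


if __name__ == "__main__":
    print(gemm_shape("qkv_proj", "decode", batch_size=1))          # (1, 4096, 12288)
    print(gemm_shape("ffn1", "prefill", batch_size=1, seq_len=8))  # (8, 4096, 11008)
```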

Approach: Decision flow for inflection points. Because only four [N, K] shapes exist for a given LLM, we use three types of implementations for GEMV/GEMM operations as M varies: FastGEMV for GEMV and flat GEMM operations (ImplA), our flat GEMM optimization from Section 4 (ImplB), and the CUTLASS [25] libraries optimized for conventional GEMM (ImplC). The key decisions are therefore whether to apply ImplA or ImplB for a small M, and ImplB or ImplC for a large M. Figure 9(b) shows the decision flow. FlashDecoding++ profiles the performance of ImplA and ImplB for a given M, and increases M to find an inflection point M1 at which ImplB begins to outperform ImplA. A second inflection point M2, at which ImplC begins to outperform ImplB, is found in the same way. Note that each [N, K] shape has its own M1 and M2.

Figure 9: Heuristic dataflow with hardware resource adaption in FlashDecoding++. (a) Only four [N, K] shapes exist for a certain LLM. (b) The decision flow. We traverse all [N, K] selections and profile the performance of three
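The inflection-point search can be sketched as a linear sweep over M. The names `find_inflection`, `decision_flow`, and the `latency` callback are hypothetical. On real hardware, `latency` would time the actual kernels. Here, `toy_latency` is an invented cost model (fixed overhead plus a per-row cost, ignoring the shape) used only so the example runs. The sweep also assumes each pair of implementations crosses over at most once.

```python
from typing import Callable, Optional, Tuple

# latency(impl, m, n, k) -> time; on real hardware this would benchmark a kernel.
LatencyFn = Callable[[str, int, int, int], float]


def find_inflection(challenger: str, incumbent: str, n: int, k: int,
                    latency: LatencyFn, m_start: int, m_max: int) -> Optional[int]:
    """Smallest M in [m_start, m_max] where `challenger` is faster than `incumbent`."""
    for m in range(m_start, m_max + 1):
        if latency(challenger, m, n, k) < latency(incumbent, m, n, k):
            return m
    return None  # no crossover within the profiled range


def decision_flow(n: int, k: int, latency: LatencyFn, m_max: int = 4096) -> Tuple[int, int]:
    """Return (M1, M2) for one [N, K] shape; m_max + 1 means 'never switch'."""
    never = m_max + 1
    m1 = find_inflection("ImplB", "ImplA", n, k, latency, 1, m_max)
    if m1 is None:
        return never, never
    # Start the second sweep at M1 so that M1 <= M2 always holds.
    m2 = find_inflection("ImplC", "ImplB", n, k, latency, m1, m_max)
    return m1, (never if m2 is None else m2)


def toy_latency(impl: str, m: int, n: int, k: int) -> float:
    """Invented cost model (fixed overhead + per-row cost); ignores n and k."""
    overhead, per_row = {"ImplA": (1.0, 0.5), "ImplB": (3.0, 0.1), "ImplC": (10.0, 0.02)}[impl]
    return overhead + per_row * m


if __name__ == "__main__":
    print(decision_flow(4096, 4096, toy_latency))  # (6, 88) under the toy model
```

With these made-up costs, the sweep yields M1 = 6 and M2 = 88. These numbers carry no meaning beyond the toy model. Because profiling happens offline, a step of 1 is affordable. Under the single-crossing assumption, a coarser step or a binary search would trade precision for profiling time.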

Approach: Heuristic dataflow. During runtime LLM inference, FlashDecoding++ adopts ImplA on CUDA Cores when M < M1, ImplB on Tensor Cores when M1 ≤ M < M2, and ImplC on Tensor Cores when M ≥ M2. Because the decision flow is executed offline, it does not affect the performance of runtime LLM inference.
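At runtime, the selection rule reduces to two integer comparisons per operation. The following is a minimal sketch; the function name `select_impl` is hypothetical, and it reuses the implementation labels from this section.

```python
def select_impl(m: int, m1: int, m2: int) -> str:
    """Pick the implementation for one GEMV/GEMM given its offline (M1, M2)."""
    if not m1 <= m2:
        raise ValueError("expected M1 <= M2")
    if m < m1:
        return "ImplA"  # FastGEMV on CUDA Cores
    if m < m2:
        return "ImplB"  # flat GEMM (Section 4) on Tensor Cores
    return "ImplC"      # conventional GEMM (CUTLASS) on Tensor Cores
```

M1 and M2 are fixed per [N, K] shape after profiling, so this check adds only a constant, negligible cost to each linear layer call.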

Example. Figure 9(c) shows an example of applying the heuristic dataflow to the Llama2-7B model. The four [N, K] shapes are [12288, 4096] for the K, Q, V projection, [4096, 4096] for the O projection, and [11008, 4096] and [4096, 11008] for the FFN. For each [N, K], the inflection points are found using the decision flow in Figure 9(b). A lookup table is then formed, and at runtime each GEMV/GEMM operation is executed with the implementation the table specifies. In this example, FastGEMV is adopted for the K, Q, V projection when batch size = 1 (M = 1) in the decode phase, and our flat GEMM optimization is applied for FFN1 when batch size = 1 and input sequence length = 8 (M = 8).
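A lookup table keyed by [N, K] can be sketched with Python's standard `bisect` module. The shapes come from the example above. The (M1, M2) values, however, are hypothetical placeholders, chosen only to agree with the two decisions the example states: FastGEMV for the K, Q, V projection at M = 1, and flat GEMM for FFN1 at M = 8. Actual values would come from offline profiling on the target GPU. `LOOKUP`, `IMPLS`, and `dispatch` are illustrative names.

```python
import bisect
from typing import Dict, Tuple

IMPLS = ("FastGEMV (ImplA)", "flat GEMM (ImplB)", "CUTLASS GEMM (ImplC)")

# HYPOTHETICAL (M1, M2) per [N, K] shape, for illustration only; real values
# would come from offline profiling on the target GPU.
LOOKUP: Dict[Tuple[int, int], Tuple[int, int]] = {
    (12288, 4096): (4, 64),   # K, Q, V projection
    (4096, 4096): (4, 64),    # O projection
    (11008, 4096): (4, 128),  # FFN1
    (4096, 11008): (4, 128),  # FFN2
}


def dispatch(n: int, k: int, m: int) -> str:
    """Look up the implementation for an (M x K) @ (K x N) operation."""
    m1, m2 = LOOKUP[(n, k)]
    # bisect_right gives 0 if m < m1, 1 if m1 <= m < m2, and 2 if m >= m2.
    return IMPLS[bisect.bisect_right((m1, m2), m)]


if __name__ == "__main__":
    print(dispatch(12288, 4096, 1))  # K, Q, V projection at M = 1 -> FastGEMV (ImplA)
    print(dispatch(11008, 4096, 8))  # FFN1 at M = 8 -> flat GEMM (ImplB)
```

`bisect_right` places an M equal to a threshold in the higher bucket. This matches the M1 ≤ M < M2 and M ≥ M2 boundaries of the heuristic dataflow.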