Authors:
(1) Siqi Kou, Shanghai Jiao Tong University (equal contribution);
(2) Lanxiang Hu, University of California, San Diego (equal contribution);
(3) Zhezhi He, Shanghai Jiao Tong University;
(4) Zhijie Deng, Shanghai Jiao Tong University;
(5) Hao Zhang, University of California, San Diego.
Table of Links
3. Methodology and 3.1. Preliminary: Jacobi Decoding
3.2. Consistency Large Language Models (CLLMs)
3.3. Acceleration Mechanisms in CLLMs
4. Experiments
4.2. Acceleration Mechanisms in CLLMs
4.4. Limitations and Discussion
5. Conclusion, Impact Statement, and References
A. Illustration of Consistency Loss Learning Objectives
B. Comparison with Baseline Algorithms
C. Pseudo Code for Jacobi Decoding with KV Cache
3.2. Consistency Large Language Models (CLLMs)
Despite the promise, the speedup effect of Jacobi decoding for vanilla LLMs is minimal in practice (Santilli et al., 2023; Fu et al., 2024). The reason is that AR-trained LLMs can usually generate only one correct token in each Jacobi iteration as such models can rarely yield a correct token when there are incorrect preceding tokens. To address this, we propose to adapt pre-trained LLMs to consistently map any point y on the Jacobi trajectory J to the fixed point y∗. Surprisingly, such an objective is analogous to that of consistency models (Song et al., 2023; Song & Dhariwal, 2023), a leading acceleration approach for diffusion models (Ho et al., 2020; Song et al., 2021b).
This section first delineates our data preparation procedure for tuning CLLMs, then elaborates on the CLLM training procedure. Lastly, we discuss possible sources of CLLMs' acceleration.
3.2.1. JACOBI TRAJECTORY COLLECTION
Let p denote the target LLM we aim to adapt. Let qθ(·|x) denote the CLLM with parameters θ initialized with those of p. To realize the aforementioned adaptation, we collect a set of Jacobi trajectories by running the Jacobi decoding algorithm with the target LLM p on prompts from a certain domain of interest, forming an original training set D. We summarize the algorithm for dataset generation in Algorithm 1. Note that to generate a lengthy response l of N (N ≫ n) tokens, we can sequentially perform Jacobi decoding for every truncation of n tokens to avoid slow model evaluation on lengthy input. Consequently, l amounts to the concatenation of a set of consecutive fixed points.
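For concreteness, the sketch below illustrates the trajectory-collection loop in the spirit of Algorithm 1, assuming a Hugging Face-style causal LM whose forward pass returns logits; the function name, random initialization, and greedy argmax update are illustrative choices, not the authors' exact implementation.

```python
import torch

@torch.no_grad()
def collect_jacobi_trajectory(model, prompt_ids, n):
    """Run Jacobi decoding on one n-token block with the target LLM and
    record every intermediate state (the Jacobi trajectory J) plus the
    fixed point y*. Assumes a Hugging Face-style causal LM."""
    # Initialize the n-token guess randomly; Jacobi iteration converges
    # to the same fixed point regardless of the initialization.
    y = torch.randint(
        0, model.config.vocab_size, (1, n), device=prompt_ids.device
    )
    trajectory = [y.clone()]
    prompt_len = prompt_ids.size(1)
    for _ in range(n):  # convergence is guaranteed within n iterations
        logits = model(torch.cat([prompt_ids, y], dim=1)).logits
        # Greedy update of all n positions in parallel: token i is
        # re-predicted from the prompt plus the previous guess y_{<i}.
        y_next = logits[:, prompt_len - 1 : -1, :].argmax(dim=-1)
        if torch.equal(y_next, y):  # fixed point y* reached
            break
        y = y_next
        trajectory.append(y.clone())
    return trajectory, y  # intermediate states and the fixed point y*
```

A long response of N tokens would then be produced by calling this routine repeatedly on consecutive n-token truncations, as described above.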
Data augmentation. In a typical Jacobi iteration process, the correct tokens often appear one after another, and n-token sequences usually exhibit a “correct, correct, wrong, wrong, wrong” pattern. In comparison, patterns like “correct, correct, wrong, correct, wrong” can be rare. To enhance the learning and generalization capabilities of CLLMs, we augment the dataset D by randomly correcting erroneously predicted tokens within the samples, as sketched below.
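As a minimal illustration of this augmentation, the helper below flips a random subset of wrong tokens in a trajectory state to their fixed-point values; the probability p_correct is a hypothetical knob, not a value from the paper.

```python
import random

def augment_state(state, fixed_point, p_correct=0.3):
    """Randomly replace erroneous tokens in a trajectory state with the
    correct tokens from the fixed point y*, producing rarer patterns
    such as "correct, correct, wrong, correct, wrong"."""
    augmented = list(state)
    for i, (tok, tok_star) in enumerate(zip(state, fixed_point)):
        if tok != tok_star and random.random() < p_correct:
            augmented[i] = tok_star  # flip this wrong token to correct
    return augmented
```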
Data post-processing. Since the target LLM itself can make errors on some prompts, the Jacobi trajectories often contain low-quality generations. We find that training a CLLM on n-token sequences with token-level (Holtzman et al., 2019) or sentence-level repetitions (Polišenská et al., 2015) often results in repetitive content generation and noticeably degrades performance. Recognizing the significance of high-quality datasets for training LLMs (Zhou et al., 2023a), we post-process the training dataset D with a rule-based detector to eliminate low-quality samples.
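The paper does not spell out the detector's rules; one plausible minimal version flags samples whose generation contains heavily repeated n-grams, as sketched below (the thresholds are illustrative assumptions).

```python
def has_repetition(token_ids, ngram=4, max_repeats=3):
    """Flag a generation as low quality if any n-gram occurs more than
    `max_repeats` times; a simple stand-in for the rule-based detector."""
    counts = {}
    for i in range(len(token_ids) - ngram + 1):
        key = tuple(token_ids[i : i + ngram])
        counts[key] = counts.get(key, 0) + 1
        if counts[key] > max_repeats:
            return True
    return False

# Keep only clean samples when building the training set D:
# D = [s for s in D if not has_repetition(s["fixed_point"])]
```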
3.2.2. TRAINING
We jointly optimize two losses for tuning CLLMs: one guarantees the prediction of multiple tokens at once, and the other prevents the CLLM from deviating from the target LLM so as to maintain generation quality.
Consistency loss. For a prompt x with Jacobi trajectory J, let y and y* denote a random state on the trajectory and the fixed point, respectively. We can directly push the CLLM to output y* with y as the input by minimizing the following loss:

$$\mathcal{L}_{\text{consistency}} = \mathbb{E}_{(\mathbf{x},\mathcal{J})\sim\mathcal{D},\,\mathbf{y}\sim\mathcal{J}}\Big[\sum_{i=1}^{n} D\big(q_{\theta^{-}}(\cdot \mid \mathbf{y}^{*}_{<i}, \mathbf{x}) \,\big\|\, q_{\theta}(\cdot \mid \mathbf{y}_{<i}, \mathbf{x})\big)\Big],$$

where θ⁻ = stopgrad(θ) and we abuse notation to denote uniform sampling of (x, J) from the dataset and of y from the trajectory. D(·‖·) denotes the distance between two distributions, with forward KL, reverse KL, and their mixture (i.e., the Jensen–Shannon divergence) as popular examples (Agarwal et al., 2023). We primarily experiment with the forward KL.
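In code, the forward-KL variant with a stop-gradient teacher can be sketched as follows (a minimal single-example version, again assuming a Hugging Face-style model; the real objective also averages over sampled trajectory states):

```python
import torch
import torch.nn.functional as F

def consistency_loss(model, x_ids, y_state, y_star):
    """Forward-KL consistency loss: push q_theta(.|y, x) toward the
    stop-gradient distribution q_{theta^-}(.|y*, x) at every position
    of the n-token block."""
    prompt_len = x_ids.size(1)
    # Teacher theta^-: condition on the fixed point y*, detached so no
    # gradient flows through it (stopgrad).
    with torch.no_grad():
        t_logits = model(torch.cat([x_ids, y_star], dim=1)).logits
        t_logits = t_logits[:, prompt_len - 1 : -1, :]
    # Student theta: condition on the intermediate state y.
    s_logits = model(torch.cat([x_ids, y_state], dim=1)).logits
    s_logits = s_logits[:, prompt_len - 1 : -1, :]
    # Forward KL D(q_{theta^-} || q_theta), summed over the n positions.
    return F.kl_div(
        F.log_softmax(s_logits, dim=-1),
        F.softmax(t_logits, dim=-1),
        reduction="batchmean",
    )
```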
AR loss. To prevent the CLLM from deviating from the distribution of the target LLM p, we also incorporate the standard AR loss on the target LLM's generation l:

$$\mathcal{L}_{\text{AR}} = \mathbb{E}_{(\mathbf{x},\, l)\sim\mathcal{D}}\Big[-\sum_{i} \log q_{\theta}(l_i \mid l_{<i}, \mathbf{x})\Big].$$

This term contributes substantially to maintaining generation quality (see Table 6).
Consequently, the total loss for training a CLLM is:

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{consistency}} + w\,\mathcal{L}_{\text{AR}},$$

where w is a weighting hyperparameter.
The training procedure is detailed in Algorithm 2.
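Putting the pieces together, one training step in the spirit of Algorithm 2 might look like the sketch below, reusing consistency_loss from above; the batch keys and the weight w are illustrative assumptions, not the authors' exact code.

```python
import torch
import torch.nn.functional as F

def training_step(model, optimizer, batch, w=1.0):
    """One CLLM update: consistency loss on a sampled trajectory state
    plus the AR loss on the target model's generation l."""
    x_ids = batch["prompt"]          # prompt token ids
    y_state = batch["state"]         # random state y on the trajectory
    y_star = batch["fixed_point"]    # fixed point y*
    l_ids = batch["generation"]      # full target-model generation l
    loss_c = consistency_loss(model, x_ids, y_state, y_star)
    # Standard AR (next-token cross-entropy) loss on l given the prompt.
    logits = model(torch.cat([x_ids, l_ids], dim=1)).logits
    logits = logits[:, x_ids.size(1) - 1 : -1, :]
    loss_ar = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), l_ids.reshape(-1)
    )
    loss = loss_c + w * loss_ar      # total loss L(theta)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```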
This paper is available on arXiv under the CC0 1.0 Universal license.