paint-brush
改进长短期记忆网络:面向下一代 AI 的 XLSTM经过@aibites
1,701 讀數
1,701 讀數

改进长短期记忆网络:面向下一代 AI 的 XLSTM

经过 Shrinivasan Sankar7m2024/05/23
Read on Terminal Reader

太長; 讀書

XLSTM 试图让那些被 Transformers 夺走荣耀的 LSTM 重返舞台。那么,XLSTM 会如约而至吗?还是只是炒作?让我们在本文中一探究竟。
featured image - 改进长短期记忆网络:面向下一代 AI 的 XLSTM
Shrinivasan Sankar HackerNoon profile picture
0-item

LSTM(长短期记忆网络)已经存在很长时间了。它们已被应用于相当多与序列相关的任务,例如文本生成和翻译,甚至生成图像字幕。


它们的缺点是无法并行化以利用强大的现代 GPU。这一限制为利用 GPU 进行大规模并行训练和推理的 Transformer 的出现铺平了道路。


如果我们现在尝试改进和并行化 LSTM,它们能成为构建下一代LLM的工具吗?


这正是论文“ XLSM——扩展长短期记忆网络”所回答的问题, XLSM 代表“扩展”长短期记忆。他们通过在架构中提出两个新模块,即 sLSTM 和 mLSTM 来实现这一点。

因此,让我们深入研究本文提出的 sLSTM 和 mLSTM 块,看看如何将它们堆叠在一起以开发 XLSTM 架构。

视觉解释

如果您和我一样,希望以直观的方式解释 XLSTM,请查看本文的 YouTube 视频:

LSTM 复习

最早设计用于处理序列数据的网络之一是循环神经网络。

循环神经网络

它在架构中使用循环连接,以x作为输入,以o作为输出。如果我们将其展开,我们可以将其可视化为在时间戳t-1、tt+1发生的一系列操作。RNN 的一个主要缺点是梯度消失问题,即当我们将太多块堆叠在一起时,梯度会变为零。


提出了 LSTM(长短期记忆网络)来通过在网络中引入单元状态和门控机制来克服梯度消失问题。

LSTM 的简化图示

细胞状态c是跨多个时间戳的长期记忆。隐藏状态h是从一个时间步骤传递到另一个时间步骤的短期记忆。当然,我们有来自输入序列的输入z


三个门具有S 形函数。遗忘门使用 S 形函数来决定在长期记忆中忘记哪些信息。输入门也使用 S 形函数来处理输入并将其添加到遗忘门的输出。在 XLSTM 论文和学术文献中,这种加法运算有一个花哨的术语,称为恒定误差旋转。这种加法运算解决了 RNN 中发现的梯度消失问题。然后,输出c_t输出门处理,输出门通常是 tanh 函数,导致隐藏状态输出 h_t 传递到下一步。


通过这些操作,我们剖析了 LSTM 的两个主要方程,即c_th_t

缺点 1——修改存储决策

LSTM 的主要缺点之一是无法修改存储决策。这意味着,随着序列长度的增加,模型应该能够决定是否将过去的信息保留在内存中。


例如,如果我们将这句话“汤姆去商店了。他买了一些饮料”,并将其与“汤姆去商店买了一些杂货,包括胡萝卜、洋葱、香蕉、苹果、橙子、咖啡和面包。他还买了一些饮料”进行比较。对于每个新词,例如香蕉或苹果,模型必须不断修改是否应该将过去的单词“汤姆”保留在记忆中。这对 LSTM 来说是一个巨大的挑战,它源于其遗忘门的 S 形函数。

S 型函数与指数函数。S 型函数在末端趋于平坦,而指数函数则只会不断增加。

因此,如果我们采用遗忘门,它由 S 形函数组成,该函数具有 S 形曲线,曲线在末端趋于平缓。这表明,随着输入值越来越高,决定忘记什么和保留什么变得非常具有挑战性。但如果我们用指数函数代替它,那么游戏就会发生变化,随着输入值越来越高,输出范围会越来越广。这反过来表明 LSTM 可以更好地修改存储决策。

解决方案 1 — sLSTM

因此,本文提出的解决方案是系统块。如果我们回到表示细胞状态的经典 LSTM 方程,如前所述,它是遗忘门和输入门的函数。

这些门又由 S 型函数组成。那么,如果我们用指数函数替换这些 S 型函数会怎么样?新的门f_ti_t现在变成了exp(f_t)exp(i_t),这几乎就是创建 sLSTM 块的主要修改。


与将输入压缩在固定范围内的 S 型函数不同,指数函数的值往往会随着输入的增加而激增,并且它不会像 S 型函数那样自然地将输出标准化为介于 0 和 1 之间。


因此,我们需要引入一个新的归一化器状态,它是遗忘门和输入门的函数。我们可以将其视为归一化值的运行平均值。

我们使用计算出的标准化值来标准化输出或新的隐藏状态。


虽然规范化处理了隐藏状态,但为了控制指数函数不破坏遗忘门和输入门,我们需要引入一个稳定器。它以对函数的形式出现,以抵消指数函数的影响并引入稳定性。因此,稳定器状态是遗忘门和输入门输出对数的最大值。我们从输入和遗忘门中减去这些稳定器值以稳定它们。

缺点 2——内存和并行化

LSTM 的第二个缺点是缺乏并行化。LSTM 被设计用于处理顺序数据,这意味着它需要处理序列中前一个输入的输出来处理序列中的当前输入。这个特殊的缺点阻碍了并行化,也是导致 Transformers 时代来临的罪魁祸首。

本文提出的解决方案是新颖的 mLSTM 模块。接下来让我们看看它们。

解决方案——mLSTM

XLSTM 的下一个构建块是 mLSTM 块,其中 m 代表内存。让我们再次回到经典的 LSTM 方程,看看它的缺点是什么。我们可以看到单元状态c_t是一个标量。这意味着当我们拥有至少 12 GB 内存的现代 GPU 时,我们一次只能处理 1 个数字。

mLSTM 块引入矩阵来代替单元状态的标量。回到我们经典的 LSTM 方程,如果我们用矩阵C *_t* 替换c_t ,那么单元状态现在变为大写C *_t* 以表示矩阵,并且单元状态不仅可以通过门i_t来检索,还可以通过存储作为向量的键值对来检索。其值可以通过相同维度的向量查询来检索。


为了让它听起来熟悉 transformer 的术语,他们在这里引入了键和值来形成这个矩阵。

超快STM

了解了 sLSTM 和 mLSTM 的信息后,让我们深入了解 XLSTM 的详细架构。

长短期记忆模型

sLSTM 模块的详细介绍

对于 sLSTM,我们使用后上投影。因此,输入首先通过具有 swish 激活函数的因果卷积层。然后,这些层的输出通过具有四个对角块或“头”的块对角线性层。然后,这些输出通过具有四个头的 sLSTM 块。最后,使用具有 GeLU 激活的门控 MLP 层对输出进行上投影,并使用门控 MLP 函数对输出进行下投影。

扫描隧道显微镜

继续介绍 mLSTM 块的细节,我们使用预上投影。这意味着输入首先以投影因子 2 进行上投影。其中一个投影输出进入 mLSTM,另一个进入输出门。mLSTM 块的输入经过因果卷积,然后经过块大小为 4 的块对角投影矩阵,输出 mLSTM 块随时可用的查询、键和值。

XLSTM 架构

最后,我们可以堆叠这两种类型的块来形成扩展的 LSTM 架构。因此,深灰色块是 mLSTM 块,浅灰色块是 sLSTM 块。

在优点方面,论文提到XLSTM网络具有关于序列长度的线性计算复杂度和恒定的内存复杂度。

评估

作者在 SlimPajama 数据集上进行了训练,以将其与其他基于 Transformer 的方法(如 LLAMA)和基于状态空间的方法(如 MAMBA)进行比较。他们使用了 xLSTM a:b 的表示法,其中a是 mLSTM 块的数量, b是堆栈中的 sLSTM 块的数量。

就准确度而言,他们通过将准确度在 0 到 1 之间缩放来报告相对准确度,其中 0 表示随机,1 表示完美。

评估表明,XLSTM 在 Parity 等任务中表现更好,而 Llama 和 Mamba 表现较差。

从结果来看,特别有趣的是奇偶校验任务,其中转换器或状态空间模型在没有内存混合或状态跟踪的情况下往往会遇到困难。我们可以看到,在这种任务中,当我们同时使用 sLSTM 和 mLSTM 块时,xLSTM 的准确率达到 1。

他们还进行了一些消融研究来展示 XLSTM 的稳健性。从论文中很容易理解。此外,本文更多的是关于 XLSTM 的架构新颖性,所以我不会在这里讨论实验结果。

喊出来

如果你喜欢这篇文章,为什么不关注我呢推特我每周每天都会在哪里分享来自顶尖人工智能实验室的研究更新?

也请订阅我的YouTube 频道我在这里以直观的方式解释人工智能概念和论文。

结论

希望本文能够简化并帮助您理解XLSTM架构,了解我们为什么需要它们,以及它们如何在不久的将来取代Transformer。

让我们拭目以待,看看他们准备了什么。下期再见……