著者:
(1)カーネギーメロン大学機械学習学部のアルバート・グ氏と同等の貢献
(2)プリンストン大学コンピュータサイエンス学部のTri Dao氏と同等の貢献。
3 選択的状態空間モデルと3.1 動機: 圧縮手段としての選択
畳み込み (Krizhevsky、Sutskever、Hinton 2012) やトランスフォーマー (Vaswani et al. 2017) などのハードウェアフレンドリーなアーキテクチャは、幅広く応用されています。ここでは、選択的 SSM を最新のハードウェア (GPU) でも効率的にすることを目指しています。選択メカニズムは非常に自然であり、以前の研究では、再帰型 SSM で ∆ を時間の経過とともに変化させるなど、選択の特殊なケースを組み込む試みがなされました (Gu、Dao、et al. 2020)。ただし、前述のように、SSM の使用における主な制限は計算効率であり、これが S4 およびすべての派生モデルが LTI (非選択的) モデル、最も一般的にはグローバル畳み込みの形式を使用した理由です。
3.3.1 先行モデルの動機
まず、この動機を再検討し、従来の方法の限界を克服するためのアプローチを概観します。
• 大まかに言えば、SSM などの再帰モデルは常に表現力と速度のトレードオフのバランスをとっています。セクション 3.1 で説明したように、隠れ状態の次元が大きいモデルはより効果的ですが、速度は遅くなります。したがって、速度とメモリのコストを支払わずに隠れ状態の次元を最大化する必要があります。
• 再帰モードは畳み込みモードよりも柔軟であることに注意してください。これは、後者 (3) が前者 (2) の拡張から派生しているためです (Gu、Goel、および Ré 2022; Gu、Johnson、Goel、et al. 2021)。ただし、これには、形状 (B、L、D、N) の潜在状態 ℎ を計算して実現する必要があり、これは形状 (B、L、D) の入力 x と出力 y よりもはるかに大きくなります (N 倍、SSM 状態次元)。そのため、状態の計算をバイパスして (B、L、D) のみの畳み込みカーネル (3a) を実現できる、より効率的な畳み込みモードが導入されました。
• 従来のLTI SSMは、二重の再帰畳み込み形式を活用して、効率の低下を招くことなく、有効状態の次元を従来のRNNよりもはるかに大きいNx(≈10-100)倍に増加させます。
3.3.2 選択スキャンの概要: ハードウェア認識型状態拡張
選択メカニズムは、LTI モデルの限界を克服するように設計されています。そのため、同時に、SSM の計算問題を再検討する必要があります。私たちは、カーネル融合、並列スキャン、再計算という 3 つの古典的な手法でこれに対処します。私たちは、主に 2 つの観察を行います。
• 単純な再帰計算では O(BLDN) FLOP を使用し、畳み込み計算では O(BLD log(L)) FLOP を使用しますが、前者の方が定数係数が低くなります。したがって、長いシーケンスとそれほど大きくない状態次元 N の場合、再帰モードでは実際にはより少ない FLOP を使用できます。
• 2 つの課題は、再帰の連続性と大量のメモリ使用量です。後者に対処するには、畳み込みモードと同様に、完全な状態 ℎ を実際に実現しないようにすることができます。
主なアイデアは、最新のアクセラレータ (GPU) の特性を活用して、状態 ℎ をメモリ階層のより効率的なレベルでのみ実現することです。特に、ほとんどの操作 (行列乗算を除く) はメモリ帯域幅によって制限されます (Dao、Fu、Ermon 他 2022 年、Ivanov 他 2021 年、Williams、Waterman、Patterson 2009 年)。これにはスキャン操作も含まれ、カーネル フュージョンを使用してメモリ IO の量を削減することで、標準の実装と比較して大幅な高速化を実現しています。
順次的な再帰を回避するために、線形ではないにもかかわらず、作業効率の高い並列スキャン アルゴリズムを使用して並列化できることがわかります (Blelloch 1990、Martin および Cundy 2018、Smith、Warrington、および Linderman 2023)。
最後に、バックプロパゲーションに必要な中間状態を保存しないようにする必要があります。メモリ要件を削減するために、再計算の古典的な手法を慎重に適用します。中間状態は保存されず、入力が HBM から SRAM にロードされるときに逆方向パスで再計算されます。その結果、融合された選択スキャン レイヤーのメモリ要件は、FlashAttention を使用した最適化されたトランスフォーマー実装と同じになります。
融合カーネルと再計算の詳細については付録 D を参照してください。完全な選択的 SSM レイヤーとアルゴリズムは図 1 に示されています。
この論文は、CC BY 4.0 DEED ライセンスの下でarxiv で公開されています。