LSTM (Long Short-Term Memory Networks) は、かなり以前から存在しています。テキストの生成や翻訳、さらには画像キャプションの生成など、シーケンス関連のタスクに数多く適用されてきました。
それらの欠点は、強力な最新の GPU を利用するために並列化できなかったことです。この制限により、トレーニングと推論の大規模な並列化に GPU を活用するトランスフォーマーの登場への道が開かれました。
今、LSTM を改良して並列化しようとすれば、それらは次世代のLLMを構築するためのツールになるでしょうか?
これはまさに論文「 XLSM — Extended Long Short-term Memory Networks 」で答えられている質問です。これは「拡張された」Long short-term memory の略です。論文では、アーキテクチャに sLSTM と mLSTM という 2 つの新しいブロックを提案することで答えています。
それでは、この論文で提案されている sLSTM ブロックと mLSTM ブロックを詳しく調べ、それらをどのように組み合わせて XLSTM アーキテクチャを開発できるかを見てみましょう。
私のように XLSTM を視覚的に説明してほしい場合は、この記事の YouTube ビデオを確認してください。
連続データを処理するように設計された最も初期のネットワークの 1 つが、リカレント ニューラル ネットワークです。
RNN のアーキテクチャでは、 xを入力、 oを出力とする再帰接続を使用します。展開すると、タイムスタンプt-1、t、 t+1で発生する一連の操作として視覚化できます。RNN の主な欠点は、ブロックを積み重ねすぎると勾配がゼロになる勾配消失問題でした。
LSTM (Long short-term memory networks) は、ネットワークにセル状態とゲーティング メカニズムを導入することで消失勾配を克服するために提案されました。
セル状態cは、複数のタイムスタンプにまたがって存続する長期記憶です。隠れ状態hは、あるタイムステップから別のタイムステップに渡される短期記憶です。そして、もちろん、入力シーケンスからの入力zがあります。
3 つのゲートにはS 字型の関数があります。忘却ゲートはシグモイド関数を使用して、長期記憶でどの情報を忘れるかを決定します。入力ゲートもシグモイド関数を使用して入力を処理し、それを忘却ゲートの出力に追加します。この追加操作は、XLSTM 論文および学術文献では定数エラーカルーサルと呼ばれる専門用語で呼ばれています。この追加操作は、RNN で見られる勾配消失問題に対処するものです。出力c_tは次に出力ゲートによって処理されます。これは通常、次のステップに渡される隠れ状態出力 h_t につながる tanh 関数です。
これらの操作により、 LSTM の 2 つの主要な方程式であるc_tとh_tを分析しました。
LSTM の主な欠点の 1 つは、保存の決定を修正できないことです。つまり、シーケンスの長さが長くなると、モデルは過去の情報をメモリに保持するかどうかを決定できる必要があります。
たとえば、「トムは店に行きました。彼は飲み物を何杯か買いました」という文を、「トムは店に行って、ニンジン、玉ねぎ、バナナ、リンゴ、オレンジ、コーヒー、パンなどの食料品を買いました。彼は飲み物も何杯か買いました」という文と比較すると、バナナやリンゴなどの新しい単語ごとに、モデルは過去の単語「トム」をメモリに保持するかどうかを常に修正する必要があります。これは LSTM にとって大きな課題であり、忘却ゲートのシグモイド関数に起因しています。
シグモイド関数と指数関数。シグモイド関数は端に向かって平らになりますが、指数関数は増加し続けます。
したがって、忘却ゲートはシグモイド関数で構成され、S 字型の曲線が末尾に向かって平坦になります。これは、入力の値が大きくなるにつれて、何を忘れて何をメモリに保持するかの決定が非常に困難になることを示しています。ただし、代わりに指数関数を使用すると状況が変わり、入力の値が大きくなるにつれて、出力の範囲が広くなります。これは、LSTM がストレージの決定を修正する能力が向上することを示しています。
そこで、この論文で提案されている解決策はシステム ブロックです。セルの状態を表す古典的な LSTM 方程式に戻ると、前に見たように、それは忘却ゲートと入力ゲートの関数です。
これらのゲートは、シグモイド関数で構成されています。では、これらのシグモイド関数を指数関数に置き換えたらどうなるでしょうか。新しいゲートf_tとi_t は、 exp(f_t)とexp(i_t) になり、これが sLSTM ブロックを作成するための主な変更点です。
入力を固定範囲に収めるシグモイド関数とは異なり、指数関数は入力が増加するにつれて値が急激に増加する傾向があり、シグモイド関数のように、出力が 0 と 1 の間に自然に正規化されることはありません。
したがって、忘却ゲートと入力ゲートの関数である新しい正規化状態を導入する必要があります。これは、正規化値の実行平均と考えることができます。
計算された正規化値を使用して、出力または新しい隠し状態を正規化します。
正規化によって隠れ状態が処理される一方で、指数関数が忘却ゲートと入力ゲートを爆発させないように制御するために、安定器を導入する必要があります。これは、指数関数の影響を打ち消し、安定性を導入するための対数関数の形で提供されます。したがって、安定器の状態は、忘却ゲートと入力ゲートの出力の対数の最大値です。入力ゲートと忘却ゲートからこれらの安定器の値を減算して、ゲートを安定させます。
LSTM の 2 つ目の欠点は、並列化ができないことです。LSTM は連続データを処理するように設計されているため、シーケンス内の現在の入力を処理するには、シーケンス内の前の入力の処理の出力が必要です。この欠点により並列化が妨げられ、Transformers 時代の幕開けにつながった原因となりました。
この論文で提案されている解決策は、新しい mLSTM ブロックです。それでは次に、それについて見ていきましょう。
XLSTM の次の構成要素は mLSTM ブロックです。ここで、m はメモリを表します。古典的な LSTM 方程式に戻って、その欠点を見てみましょう。セル状態c_t はスカラーであることがわかります。つまり、少なくとも 12 GB のメモリを備えた現代の GPU を利用できる場合、一度に処理できる数値は 1 つだけです。
mLSTM ブロックは、セル状態のスカラーの代わりに行列を導入します。 LSTM の従来の方程式に戻ると、 c_t を行列C *_t* に置き換えて、セル状態が行列を示す大文字のC *_t* になり、セル状態はゲートi_tだけでなく、ベクトルであるキーと値のペアを保存することによって取得できるようになります。その値は、同じ次元のベクトルであるクエリによって取得できます。
トランスフォーマーの用語に馴染みやすいように、このマトリックスを形成するために、ここではキーと値が導入されています。
sLSTM と mLSTM に関する情報を踏まえて、XLSTM の詳細なアーキテクチャについて詳しく見ていきましょう。
sLSTM に関しては、ポストアップ投影を使用します。したがって、入力は最初に、スウィッシュ活性化関数を持つ因果畳み込み層を通過します。これらの層からの出力は、4 つの対角ブロックまたは「ヘッド」を持つブロック対角線形層を通過します。これらの出力は、4 つのヘッドを持つ sLSTM ブロックを通過します。最後に、出力は、GeLU 活性化を持つゲート MLP 層を使用してアップ投影され、ゲート MLP 関数を使用してダウン投影されます。
mLSTM ブロックの詳細に移ると、事前アップ投影を使用します。つまり、入力は最初に投影係数 2 でアップ投影されます。投影出力の 1 つは mLSTM に送られ、もう 1 つは出力ゲートに送られます。mLSTM ブロックへの入力は因果畳み込みを通過し、次にブロック サイズ 4 のブロック対角投影行列を通過し、mLSTM ブロックですぐに使用できるクエリ、キー、および値が出力されます。
最後に、2 種類のブロックを積み重ねて、拡張 LSTM アーキテクチャを形成します。濃い灰色のブロックは mLSTM ブロックで、薄い灰色のブロックは sLSTM ブロックです。
利点に関して、論文では、XLSTM ネットワークはシーケンスの長さに関して線形の計算複雑度と一定のメモリ複雑度を持つと述べています。
著者らは SlimPajama データセットでトレーニングを行い、それを LLAMA などの他のトランスフォーマーベースの方法や MAMBA などの状態空間ベースの方法と比較しました。著者らは xLSTM a:b という表記を使用しました。ここで、 aは mLSTM ブロックの数、 b はスタック内の sLSTM ブロックの数です。
精度に関しては、0 がランダム、1 が完全として、0 と 1 の間で精度をスケーリングして相対的な精度を報告します。
評価によると、XLSTM は Parity などのタスクで優れたパフォーマンスを発揮しますが、Llama と Mamba のパフォーマンスは低くなります。
結果から特に興味深いのはパリティ タスクです。このタスクでは、メモリ混合や状態追跡なしでは、トランスフォーマーや状態空間モデルが苦戦する傾向があります。この種のタスクでは、sLSTM ブロックと mLSTM ブロックの両方を一緒に使用すると、xLSTM の精度が 1 になることがわかります。
彼らはまた、XLSTM の堅牢性を示すためにアブレーション研究も行いました。論文から簡単に理解できます。また、この記事は主に XLSTM のアーキテクチャ上の新規性に関するものであるため、ここでは実験結果には触れません。
この記事が気に入ったら、ぜひフォローしてください
また、私のチャンネルを購読してください
この記事が、XLSTM アーキテクチャ、それがなぜ必要なのか、そして近い将来にトランスフォーマーに取って代わる可能性がある理由についての理解を簡素化し、容易にしてくれたことを願っています。
彼らが何を用意しているか、楽しみに待ちましょう。次回もお会いしましょう…