LSTM(Long Short-Term Memory Network)은 오랫동안 사용되어 왔습니다. 이는 텍스트 생성 및 번역, 심지어 이미지 캡션 생성과 같은 꽤 많은 시퀀스 관련 작업에 적용되었습니다.
이들의 단점은 강력한 최신 GPU를 활용하기 위해 병렬화할 수 없다는 것입니다. 이러한 제한으로 인해 훈련 및 추론의 대규모 병렬화를 위해 GPU를 활용하는 변환기가 출현할 수 있는 길이 열렸습니다.
이제 LSTM을 개선하고 병렬화하려고 시도하면 차세대 LLM을 구축하기 위한 도구가 될 수 있습니까?
이것은 " 확장된 " 장기 단기 기억을 의미하는 " XLSM - Extended Long Short-term Memory Networks " 논문에서 답변한 정확한 질문입니다. 그들은 아키텍처에서 sLSTM과 mLSTM이라는 두 가지 새로운 블록을 제안함으로써 이를 수행합니다.
이제 본 백서에서 제안된 sLSTM 및 mLSTM 블록을 자세히 살펴보고 이를 함께 쌓아서 XLSTM 아키텍처를 개발할 수 있는 방법을 살펴보겠습니다.
나와 같은 사람이고 XLSTM을 시각적으로 설명하고 싶다면 이 기사의 YouTube 비디오를 확인하십시오.
순차 데이터를 처리하도록 설계된 최초의 네트워크 중 하나는 순환 신경망(Recurrent Neural Network)입니다.
x를 입력으로, o를 출력으로 사용하여 아키텍처에서 반복 연결을 사용합니다. 이를 펼치면 타임스탬프 t-1, t 및 t+1 에서 발생하는 일련의 작업으로 시각화할 수 있습니다. RNN의 주요 단점은 너무 많은 블록을 쌓을 때 그래디언트가 0이 되는 Vanishing Gradient 문제였습니다.
LSTM(Long Short-term Memory Network)은 셀 상태와 게이팅 메커니즘을 네트워크에 도입하여 소실 기울기를 극복하기 위해 제안되었습니다.
셀 상태 c 는 여러 타임 스탬프에 걸쳐 존재하는 장기 기억입니다. 숨겨진 상태 h 는 한 단계에서 다른 단계로 전달되는 단기 기억입니다. 그리고 물론 입력 시퀀스의 입력 z도 있습니다.
세 개의 게이트에는 S자 모양의 기능이 있습니다. 망각 게이트는 시그모이드 함수를 사용하여 장기 기억에서 어떤 정보를 잊어버릴지 결정합니다. 입력 게이트는 또한 시그모이드 함수를 사용하여 입력을 처리하고 이를 망각 게이트의 출력에 추가합니다. 이 추가 작업에는 XLSTM 논문과 학술 문헌에 지속적인 오류 캐러셀(constant error carousal )이라는 멋진 용어가 있습니다. 이 추가 작업은 RNN에서 발견되는 Vanishing Gradient 문제를 해결하는 것입니다. 그런 다음 출력 c_t는 출력 게이트 에 의해 처리됩니다. 이는 일반적으로 다음 단계로 전달되는 숨겨진 상태 출력 h_t로 이어지는 tanh 함수입니다.
이러한 작업을 통해 우리는 LSTM의 두 가지 주요 방정식인 c_t 및 h_t를 분석했습니다.
LSTM의 주요 단점 중 하나는 스토리지 결정을 수정할 수 없다는 것입니다. 이것이 의미하는 바는 시퀀스 길이가 증가함에 따라 모델이 과거 정보를 메모리에 유지할지 여부를 결정할 수 있어야 한다는 것입니다.
예를 들어, “Tom은 가게에 갔다. 그는 음료수를 샀습니다.” 그리고 이를 “Tom은 당근, 양파, 바나나, 사과, 오렌지, 커피, 빵이 포함된 식료품을 사러 가게에 갔습니다. 그 사람은 술도 좀 샀어요.” 바나나나 사과와 같은 모든 새로운 단어에 대해 모델은 과거 단어 "Tom"을 메모리에 유지해야 하는지 여부를 지속적으로 수정해야 합니다. 이는 LSTM에 대한 큰 도전이며 이는 망각 게이트의 시그모이드 기능에서 비롯됩니다.
시그모이드 함수와 지수 함수. 시그모이드는 끝으로 갈수록 평평해지지만 지수는 계속해서 증가합니다.
따라서 망각 게이트를 사용하면 끝으로 갈수록 편평해지는 S자형 곡선을 갖는 시그모이드 함수로 구성됩니다. 이는 입력 값이 높아질수록 무엇을 잊어야 하고 무엇을 기억에 남겨야 할지 결정하는 것이 상당히 어려워진다는 것을 의미합니다. 그러나 그 자리에 지수 함수를 사용하면 게임이 바뀌고 입력 값이 높아질수록 출력 범위가 넓어집니다. 이는 결국 LSTM이 스토리지 결정을 더 잘 수정할 수 있음을 나타냅니다.
따라서 본 논문에서 제안하는 솔루션은 시스템 블록이다. 앞서 본 것처럼 셀 상태를 나타내는 고전적인 LSTM 방정식으로 돌아가면 이는 망각 게이트와 입력 게이트의 함수입니다.
이 게이트는 시그모이드 함수로 구성됩니다. 그렇다면 이러한 시그모이드 함수를 지수 함수로 대체하면 어떻게 될까요? 새로운 게이트 f_t 및 i_t는 이제 exp(f_t) 및 exp(i_t) 가 되며 이는 sLSTM 블록을 생성하기 위한 주요 수정 사항입니다.
입력을 고정된 범위에 두는 시그모이드 함수와 달리 지수 함수는 입력이 증가함에 따라 값이 폭발하는 경향이 있으며 시그모이드처럼 출력을 0과 1 사이에 있도록 자연스럽게 정규화하지 않습니다. 기능.
따라서 망각 게이트와 입력 게이트의 기능인 새로운 정규화 상태를 도입해야 합니다. 이를 정규화 값의 실행 평균으로 생각할 수 있습니다.
계산된 정규화 값을 사용하여 출력 또는 새로운 숨겨진 상태를 정규화합니다.
정규화가 숨겨진 상태를 처리하는 동안 망각 및 입력 게이트가 폭발하는 것을 방지하기 위해 안정기를 도입해야 합니다. 지수 효과에 대응하고 안정성을 도입하기 위해 로그 함수 형태로 제공됩니다. 따라서 스태빌라이저 상태는 망각 게이트 로그와 입력 게이트 출력의 최대값입니다. 입력에서 이러한 안정기 값을 빼고 이를 안정화하기 위해 게이트를 잊어버립니다.
LSTM의 두 번째 단점은 병렬화가 부족하다는 것입니다. LSTM은 순차 데이터를 처리하도록 설계되었습니다. 즉, 시퀀스의 현재 입력을 처리하려면 시퀀스의 이전 입력을 처리하는 출력이 필요하다는 의미입니다. 이 특별한 결점은 병렬화를 방해하고 트랜스포머 시대를 여는 주범이었습니다.
본 논문에서 제안하는 솔루션은 새로운 mLSTM 블록이다. 그럼 다음에 그것들을 살펴보겠습니다.
XLSTM의 다음 빌딩 블록은 mLSTM 블록입니다. 여기서 m은 메모리를 나타냅니다. 다시 고전적인 LSTM 방정식으로 돌아가서 단점이 무엇인지 살펴보겠습니다. 셀 상태 c_t가 스칼라임을 알 수 있습니다. 이는 최소 12GB의 메모리를 갖춘 최신 GPU를 사용할 때 한 번에 하나의 숫자만 처리한다는 의미입니다.
mLSTM 블록은 셀 상태에 대한 스칼라 대신 행렬을 도입합니다. LSTM의 고전 방정식으로 돌아가면, c_t를 행렬 C *_t*로 대체하여 셀 상태가 이제 행렬을 나타내는 대문자 C *_t*가 되고 셀 상태는 게이트 i_t 뿐만 아니라 벡터인 키-값 쌍을 저장하여 동일한 차원의 벡터인 쿼리로 검색할 수 있는 값입니다.
변환기의 용어에 친숙하게 들리도록 하기 위해 여기에 키와 값을 도입하여 이 매트릭스를 형성했습니다.
sLSTM 및 mLSTM에 대한 정보를 바탕으로 XLSTM의 세부 아키텍처를 살펴보겠습니다.
sLSTM의 경우 사후 예측을 사용합니다. 따라서 입력은 먼저 스위시 활성화 함수를 사용하여 인과 컨볼루션 레이어를 통과합니다. 그런 다음 이러한 레이어의 출력은 4개의 대각선 블록 또는 "헤드"가 있는 블록-대각선 선형 레이어를 통해 공급됩니다. 그런 다음 이들의 출력은 4개의 헤드가 있는 sLSTM 블록을 통해 공급됩니다. 마지막으로 출력은 GeLU 활성화가 포함된 게이트 MLP 계층을 사용하여 상향 투영되고 게이트 MLP 기능을 사용하여 하향 투영됩니다.
mLSTM 블록의 세부 사항으로 이동하여 사전 예측을 사용합니다. 이는 입력이 먼저 투영 계수 2로 상향 투영된다는 의미입니다. 투영 출력 중 하나는 mLSTM으로 이동하고 다른 하나는 출력 게이트로 이동합니다. mLSTM 블록에 대한 입력은 인과 컨벌루션을 거친 다음 블록 크기 4의 블록 대각선 투영 행렬을 통해 mLSTM 블록에서 쉽게 사용되는 쿼리, 키 및 값을 출력합니다.
마지막으로 두 가지 유형의 블록을 쌓아 확장된 LSTM 아키텍처를 형성할 수 있습니다. 따라서 어두운 회색 블록은 mLSTM 블록이고 밝은 회색 블록은 sLSTM 블록입니다.
장점 측면에서, 이 논문에서는 XLSTM 네트워크가 선형 계산 복잡성과 시퀀스 길이와 관련된 지속적인 메모리 복잡성을 가지고 있다고 언급합니다.
저자는 SlimPajama 데이터 세트를 훈련하여 LLAMA와 같은 다른 변환기 기반 방법 및 MAMBA와 같은 상태 공간 기반 방법과 비교했습니다. 그들은 xLSTM a:b라는 표기법을 사용했습니다. 여기서 a 는 mLSTM 블록의 수이고 b는 스택에 있는 sLSTM 블록의 수입니다.
정확도 측면에서 0과 1 사이의 정확도를 스케일링하여 상대 정확도를 보고합니다. 여기서 0은 무작위이고 1은 완벽합니다.
평가에 따르면 XLSTM은 Parity와 같은 작업에서 더 나은 성능을 보이는 반면 Llama 및 Mamba는 성능이 좋지 않은 것으로 나타났습니다.
결과에서 특히 흥미로운 것은 변환기나 상태공간 모델이 메모리 혼합이나 상태 추적 없이 어려움을 겪는 경향이 있는 패리티 작업입니다. 이런 종류의 작업에서 sLSTM과 mLSTM 블록을 함께 사용할 때 xLSTM의 정확도는 1에 도달하는 것을 볼 수 있습니다.
그들은 또한 XLSTM의 견고성을 보여주기 위해 몇 가지 절제 연구를 수행했습니다. 그들은 종이를 통해 이해하기 쉽습니다. 게다가 이 기사는 XLSTM의 구조적 혁신에 관한 것이므로 여기서는 실험 결과를 다루지 않습니다.
이 기사가 마음에 드셨다면 저를 팔로우해 보세요.
또한 내 구독도 부탁드립니다.
이 기사를 통해 XLSTM 아키텍처, 이것이 필요한 이유, 가까운 미래에 잠재적으로 변압기를 능가할 수 있는 방법에 대한 이해가 단순화되고 쉬워졌기를 바랍니다.
그들이 무엇을 준비하고 있는지 기다려 봅시다. 다음 시간에 뵙겠습니다…