paint-brush
LLM 対 Leetcode (パート 1 と 2): アルゴリズムの問題に対する Transformer のソリューションを理解する@boluben
20,299 測定値
20,299 測定値

LLM 対 Leetcode (パート 1 と 2): アルゴリズムの問題に対する Transformer のソリューションを理解する

Bolu Ben-Adeola16m2024/04/16
Read on Terminal Reader

長すぎる; 読むには

この記事シリーズでは、Transformer モデルの解釈可能性について掘り下げ、有効な括弧の問題に取り組むことで、Transformer モデルがどのようにアルゴリズムを学習するかを調査します。データ生成、モデル トレーニングを取り上げ、パート 3 では注目パターンとメカニズムの理解について詳しく説明します。
featured image - LLM 対 Leetcode (パート 1 と 2): アルゴリズムの問題に対する Transformer のソリューションを理解する
Bolu Ben-Adeola HackerNoon profile picture
0-item
1-item

ニューラル ネットワーク機械的解釈可能性の課題の精神に沿って、この投稿 (およびおそらくシリーズの他の投稿) では、限定された技術的タスク (「有効な括弧」Leetcode 問題の修正版) に取り組むためにトランスフォーマー モデルによって学習された「アルゴリズム」を調査します。


このタスクの有用性は、LLM で期待されるより一般的な次のトークンの予測よりも範囲がはるかに控えめですが、この演習は、モデルが何をしているのか (そして、それがどのようにわかるのか) を知るために通常展開される初期の直感、調査ツール、および一般的な認識論的方法論のいくつかを調査するのに役立ちます。


ARENA Mechinterp の月例チャレンジはこの投稿に大きな影響を与えており、最初の問題セットはそこから生まれます。(ぜひプログラムをチェックしてください。)


シリーズ構成:

  1. Leetcode の問題をタスクとして選択します。(パート 1)
  2. これに最小限の実行可能な Transformer モデルをトレーニングします。(パート 2)
  3. モデルが何を学習したかを調べます。(パート 3)


パート1: 問題

Leetcode で見られる有効な括弧の問題:



タスクで使用する問題に対するいくつかの変更された制約:


  • 使用できる文字は「(」と「)」のみです。
    • これにより、「([)]」のようなケースを処理する必要がなくなります。


  • 入力シーケンスの最大長は 40 文字です。
    • 迅速な反復のためにモデルを小さく保つため。



“(((())))” → 有効

「()()()(」 → 無効

“)()()()(” → 無効



バニラソリューション

def isValid(self, s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


失敗事例に関する注意事項:


  1. シーケンスの最後で nesting_depth ≠ 0

    「()()()((」 → 無効


    この場合、最後の開いた括弧に括弧が付いていないことが最後までわかるまで、何かが間違っていることは明らかではありません。注目すべき点は、最後まで、何かが間違っているとわかる十分な情報があった箇所がシーケンス内になかったことです。


  2. シーケンスのどの時点でも nesting_depth < 0

    例: “())()()(” → 無効


    一方、この場合は、3 番目の位置までにシーケンスの有効性が回復不可能であることがわかる十分な情報があるため、早期に終了することができます。


    注目すべき点は、最後のnesting_depthが 0 に等しくなるため、この例では最初の失敗テストに合格するということです。したがって、このテスト ケースは早期停止に役立つだけでなく、非常に重要です。同じことが最初の失敗ケースの例にも当てはまり、テスト 2 に合格します。



さて、自己回帰トランスフォーマー モデルがまったく同じ方法で問題を解決するとは期待していません。そのアーキテクチャでは、シーケンスを 1 回ループしてすべてが正常かどうかを確認するよりもわずかに異なるメカニズムが提供されるためです。ただし、トランスフォーマー アーキテクチャ (およびその他のシーケンス処理アーキテクチャ) は、少なくともシーケンス内のすべての要素に関する情報を検出し処理できることは確かです。ソリューションは異なって見えるかもしれませんが、問題の構造は同じであり、ループと if ステートメントであれ、自己注意スイープと MLP 非線形性のアンサンブルであれ、シーケンス内の既知の情報と場所に関する厳密な境界は引き続き当てはまることを覚えておくことが重要です。


興味深い疑問は、このアーキテクチャがこの情報をどのように活用し、既存のツールで簡単に識別できるかどうかです。なぜなら、どのアーキテクチャでも、十分にパフォーマンスの高いソリューションでは、少なくとも上記の 2 つの障害ケースをテストしないことは避けられないからです。


これは、おもちゃの問題の利点の 1 つです。これらの厳格な保証によって、十分に理解された狭いタスクが得られ、すぐにわかるように、調査に役立つ情報が得られます。


パート2: データとモデル

トレーニングデータの準備

データ生成で目指すターゲット特性は次のとおりです。


  • バランス弦とアンバランス弦が同数。

  • 文字列は偶数の長さになります。奇数の長さの文字列は明らかに不均衡であるため、モデルが学習するヒューリスティックとしてはあまり興味深いものではありません。

  • すべての文字列の長さ (2 ~ 40) は、同じ確率で発生します。

  • 特定の文字列の長さに対して、すべての潜在的な括弧のネストの深さは、同等の確率で発生します。


共通のテーマは明らかです。考えられるすべての分布統計を、特定の方向への偏りを減らす可能性が等しくなれるようにし、堅牢性を確保し、モデルのオプションとして明らかな即効性のあるヒューリスティックを否定しようとしています。失敗ケースを生成するために、まず上記の保証で有効な括弧を生成し、次にその半分を変更して不均衡にします。


 from random import randint, randrange, sample from typing import List, Tuple, Union, Optional, Callable, Dict from jaxtyping import Float, Int import torch as t from torch import Tensor import plotly.express as px import einops from dataclasses import dataclass import math



 def isValid(s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


 assert isValid('()()((((()())())))') == True assert isValid(')()((((()())()))(') == False


データ生成スキーム #1: ランダムウォーク

括弧生成の最初の試みは、ランダムウォークを実行するだけです。しかし、下のグラフからわかるように、不均衡な括弧の部分空間は均衡のとれた括弧の部分空間よりもはるかに大きいため、確率性を別の方法で導入する必要があります。


 PARENS = ['(', ')'] def get_random_walk_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: range_start, range_end = length_range random_parens = [ # Add 1 to make passed range_end inclusive ''.join(PARENS[randint(0, 1)] for _ in range(randrange(range_start, range_end + 1, 2))) for _ in range(parens_num) ] return random_parens



 random_parens = get_random_walk_parens(1000, (2, 10))



 random_parens[:10] # output [')(', '(((())()', ')(((()()))', '))))))', '))())()(', '))', '(())', ')()(()()()', ')()())))((', '()']



 is_valid_evals = [str(isValid(random_paren)) for random_paren in random_parens] len_evals = [len(random_paren) for random_paren in random_parens]



 fig = px.histogram(is_valid_evals, title="Count of is-balanced for random walk parentheses strings") fig.show() 






データ生成スキーム #2: 貪欲ランダムネストシーケンス

バランスのとれた括弧文字列の構築を、ネストされた括弧の個別の単位に分解することができます。この貪欲な構築では、文字列を生成するプロセスの各ステップで、ネスト深度が実行可能な深度のバスケットから選択されます (ターゲット文字列の長さを尊重するため)。


たとえば、ターゲットの長さが6場合、次の一意のネスト分解が可能です。


-> [2, 1], [1, 2], [1,1,1] or [3]

Corresponding to:

-> (())(), ()(()), ()()(), ((()))



 def get_balanced_parens(nest_depth: int) -> str: """Generate parentheses at the required nesting depth.""" return (PARENS[0] * nest_depth) + (PARENS[1] * nest_depth) assert get_balanced_parens(3) == '((()))'



 def get_balanced_sequence_parens(nest_depth_sequence: List[int]) -> str: """Return a parentheses string following the nesting depth sequence from a given list.""" return ''.join(get_balanced_parens(nest_depth) for nest_depth in nest_depth_sequence) assert get_balanced_sequence_parens([1,1,2,3]) == '()()(())((()))'



 def get_random_depth_sequence(target_paren_len: int) -> List[int]: depth_sequence = [] while target_paren_len > 0: depth = randint(1, target_paren_len / 2) depth_sequence.append(depth) target_paren_len -= 2 * depth return depth_sequence rand_depth_seq = get_random_depth_sequence(10) print(rand_depth_seq) # Example output: '[3, 1, 1]' assert sum([2 * depth for depth in rand_depth_seq]) == 10



 def get_random_sequence_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: random_depth_sequences = [get_random_depth_sequence( randrange(*length_range, 2) ) for _ in range(parens_num)] random_parens = [ get_balanced_sequence_parens(random_depth_sequence) for random_depth_sequence in random_depth_sequences ] return random_parens, random_depth_sequences



バランスのとれた親子関係を築く

random_seq_parens, depth_sequences = get_random_sequence_parens(100000, (2, 11)) is_valid_evals = [str(isValid(random_paren)) for random_paren in random_seq_parens] len_evals = [len(random_paren) for random_paren in random_seq_parens]


巣の深さの頻度を見てみましょう


depth_freq = {} for seq in depth_sequences: for depth in seq: depth_freq.setdefault(depth, 0) depth_freq[depth] += 1 depth_freq # output -> {2: 39814, 1: 100088, 3: 20127, 4: 9908, 5: 4012}



 depth_seq_hist = px.histogram(depth_sequences, title="Frequence of nesting depths in 'Random Nesting Depth Sequence' Output") depth_seq_hist.show() 


歪んだ深度周波数




さて、長さの周波数を見てみましょう。


 paren_len_hist = px.histogram(len_evals, title="Frequency of string lengths") paren_len_hist.show() 


弦長周波数はかなり平坦


データ生成に関する注意

データ分布の次の潜在的な特性の間には緊張関係があることに注意してください。


  1. すべての文字列の長さは等しく可能です。
  2. すべてのネストの深さの部分文字列は、すべての文字列にわたって等しく発生します。


これは、上のグラフに示すように、ネストの深さが低いサブシーケンスの方が、特定のランダムなネスト シーケンスに現れる機会が多くなるためです。


純粋にランダムなシーケンスのこの自然な傾向に対抗するには、括弧の特定の部分文字列を生成するときに、より深いネストの値がより起こりやすくなるように偏った分布からサンプリングすることができます。

これは、最初のトレーニングの後に再度検討されます。


 px.histogram(random_seq_parens, title="Frequency of balanced Parentheses").show() 




不均衡な括弧の作成

データセットにはバランスの取れた括弧だけを含めることはできません。そのため、バランスの取れたデータセットから不均衡な文字列を導出するデータ生成戦略を作成できます。


 def _flip_idx(idx): return (idx + 1) % 2 assert _flip_idx(0) == 1 assert _flip_idx(1) == 0



 def make_parens_unbalanced(paren: str) -> str: """Take balanced-parentheses and randomly mutate it till it's unbalanced. Both the number of mutations and indices are chosen at random. """ paren_idx_dict = {'(': 0, ')': 1} paren_list = list(paren) num_flipped_positions = randint(1, len(paren)) while isValid(''.join(paren_list)): flip_points = sample(range(len(paren)), num_flipped_positions) for flip_idx in flip_points: idx_char = paren_idx_dict[paren_list[flip_idx]] flipped_idx = _flip_idx(idx_char) paren_list[flip_idx] = PARENS[flipped_idx] return ''.join(paren_list) assert not isValid(make_parens_unbalanced('((()))'))


不均衡な括弧データセットを取得する


unbal_random_seq_parens = [make_parens_unbalanced(paren) for paren in random_seq_parens]



モデルトレーニング

これでデータセットができました。楽しみのために、Transformer アーキテクチャをゼロから記述します。


まずは設定


@dataclass class Config: context_len = 12 d_vocab: int = 5 d_out_vocab: int = 2 d_model: int = 56 d_head = 28 d_mlp = 56 * 4 causal_attention = False num_heads = 2 num_layers = 3 init_range: float = 1 PAD_TOKEN_IDX = 1


次に、入力を解析するためのトークナイザーです。


 class Tokenizer: def __init__(self, vocab: str, context_width: Int, enforce_context: bool=False): self.START_TOKEN, START_TOKEN_IDX = "<start>", 0 self.PAD_TOKEN, PAD_TOKEN_IDX = "<pad>", 1 self.END_TOKEN, END_TOKEN_IDX = "<end>", 2 util_tokens_t_to_i = {self.START_TOKEN: START_TOKEN_IDX, self.PAD_TOKEN: PAD_TOKEN_IDX, self.END_TOKEN: END_TOKEN_IDX} util_tokens_i_to_t = {START_TOKEN_IDX: self.START_TOKEN, PAD_TOKEN_IDX: self.PAD_TOKEN, END_TOKEN_IDX: self.END_TOKEN} self.enforce_context = enforce_context self.context_width = context_width self.vocab = vocab self.t_to_i = {**util_tokens_t_to_i, **{token: token_id + 3 for token_id, token in enumerate(self.vocab)}} self.i_to_t = {**util_tokens_i_to_t, **{token_id + 3: token for token_id, token in enumerate(self.vocab)}} @staticmethod def pad_sequence(sequence: str, end_token: str, pad_token: str, max_length: Int, enforce_context: bool) -> List[str]: if not enforce_context: # Truncate if sequence length is greater sequence = sequence[:max_length] else: assert len(sequence) <= max_length, f"Sequence length is greater than the max allowed data length: {max_length}" return list(sequence) + [end_token] + [pad_token] * (max_length - len(sequence)) def tokenize(self, data: Union[str, List[str]]) -> Int[Tensor, "batch seq"]: if isinstance(data, str): data = [data] def _list_tokens_to_id(tokens: List[str]) -> List[Int]: return [self.t_to_i[token] for token in tokens] # to leave room for start and end tokens max_seq_len = self.context_width - 2 data_as_tokens = [ _list_tokens_to_id([ self.START_TOKEN, *self.pad_sequence(seq, self.END_TOKEN, self.PAD_TOKEN, max_seq_len, self.enforce_context), ]) for seq in data ] return t.tensor(data_as_tokens)


(非)埋め込み


class EmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_E = t.nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model)) t.nn.init.normal_(self.W_E, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq"]) -> Int[Tensor, "batch seq d_model"]: return self.W_E[x] class UnEmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_U = t.nn.Parameter(t.empty(cfg.d_model, cfg.d_out_vocab)) t.nn.init.normal_(self.W_U, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq d_model"]) -> Int[Tensor, "batch seq d_out_vocab"]: return x @ self.W_U class PositionalEmbedding(t.nn.Module): def __init__(self, cfg: Config): super().__init__() denom = t.exp( t.arange(0, cfg.d_model, 2) * -(math.log(10000.0) / cfg.d_model) ) pos = t.arange(0, cfg.context_len).unsqueeze(1) param = pos * denom P_E = t.zeros(cfg.context_len, cfg.d_model) P_E[:, 0::2] = t.sin(param) P_E[:, 1::2] = t.cos(param) P_E = P_E.unsqueeze(0) self.register_buffer("P_E", P_E) def forward(self, x): _batch, seq_len, d_model = x.shape x = x + self.P_E[..., :seq_len, :d_model].requires_grad_(False) return x


便利なレイヤーノルム


class LayerNorm(t.nn.Module): def __init__(self, cfg): super().__init__() self.scale = t.nn.Parameter(t.ones(cfg.d_model)) self.bias = t.nn.Parameter(t.zeros(cfg.d_model)) def forward(self, x): mean = t.mean(x, dim=2, keepdim=True) var = t.var(x, dim=2, keepdim=True, unbiased=False) y = (x - mean) / (var + 0.00001).sqrt() return (y * self.scale) + self.bias


そして最後に注意!


 class AttentionLayer(t.nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32)) self.cfg = cfg self.W_Q = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_K = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_V = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_O = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_head, cfg.d_model)) self.b_Q = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_K = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_V = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_O = t.nn.Parameter(t.zeros(cfg.d_model)) t.nn.init.normal_(self.W_Q, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_K, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_V, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_O, mean=0.0, std=cfg.init_range) def forward(self, params): #TODO: revisit implementing pad_mask with hooks x, pad_mask = params Q = einops.einsum(x, self.W_Q, 'bs dm, h dm dh -> bsh dh') + self.b_Q K = einops.einsum(x, self.W_K, 'bs dm, h dm dh -> bsh dh') + self.b_K V = einops.einsum(x, self.W_V, 'bs dm, h dm dh -> bsh dh') + self.b_V attention_scores = einops.einsum(Q, K, 'b s_q h dh, b s_k h dh -> bh s_q s_k') scaled_attention_scores = attention_scores / (self.cfg.d_head ** 0.5) if self.cfg.causal_attention: scaled_attention_scores = self.apply_causal_mask(scaled_attention_scores) scaled_attention_scores = self.apply_padding_mask(scaled_attention_scores, pad_mask) attention_patterns = t.nn.Softmax(dim=-1)(scaled_attention_scores) post_attention_values = einops.einsum( attention_patterns, V, 'bh s_q s_k, b s_k h dh -> b s_q h dh' ) out = einops.einsum( post_attention_values, self.W_O, 'b s_q h dh, h dh dm -> b s_q dm' ) + self.b_O return out def apply_causal_mask(self, attention_scores): b, h, s_q, s_k = attention_scores.shape mask = t.tril(t.ones(s_q,s_k)).bool() return t.where(mask, attention_scores, self.IGNORE) def apply_padding_mask(self, attention_scores, pad_mask): return t.where(pad_mask, attention_scores, self.IGNORE)



MLP レイヤー


class LinearLayer(t.nn.Module): def __init__(self, in_dim, out_dim, include_bias=True): super().__init__() self.include_bias = include_bias self.W = t.nn.Parameter(t.empty(in_dim, out_dim)) t.nn.init.normal_(self.W, mean=0.0, std=cfg.init_range) self.b = None if include_bias: self.b = t.zeros(out_dim) def forward(self, x: Int[Tensor, "batch seq in_dim"]) -> Int[Tensor, "batch seq out_dim"]: out = x @ self.W if self.include_bias: out = out + self.b return out class MLP(t.nn.Module): def __init__(self, cfg): super().__init__() self.in_layer = LinearLayer(cfg.d_model, cfg.d_mlp) self.out_layer = LinearLayer(cfg.d_mlp, cfg.d_model) self.non_linearity = t.nn.ReLU() def forward(self, x): post_W_in = self.in_layer(x) post_non_lin = self.non_linearity(post_W_in) return self.out_layer(post_non_lin)



トランスフォーマーに組み込む


class TransformerBlock(t.nn.Module): def __init__(self, cfg): super().__init__() self.ln1 = LayerNorm(cfg) self.attention = AttentionLayer(cfg) self.ln2 = LayerNorm(cfg) self.mlp = MLP(cfg) def forward(self, params): x, pad_mask = params resid_mid = self.attention((self.ln1(x), pad_mask)) + x resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid return resid_post, pad_mask


 class Transformer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.embed = EmbedLayer(cfg) self.pos_embed = PositionalEmbedding(cfg) self.final_ln = LayerNorm(cfg) self.unembed = UnEmbedLayer(cfg) self.blocks = t.nn.Sequential(*([TransformerBlock(cfg)] * cfg.num_layers)) def forward(self, x): #TODO: revisit implementing pad_mask with hooks pad_mask = self.get_pad_mask(x) res_post_pos_embed = self.pos_embed(self.embed(x)) post_blocks, _ = self.blocks((res_post_pos_embed, pad_mask)) logits = self.unembed(self.final_ln(post_blocks)) return logits def get_pad_mask(self, x): batch, seq = x.shape return einops.repeat(x != self.cfg.PAD_TOKEN_IDX, 'batch seq -> batch 1 seq_q seq', seq_q=seq)


トレーニングユーティリティ


def cross_entropy_loss(output, targets): log_probs = output.log_softmax(dim=-1) predictions = log_probs[:, 0] batch, out_dim = predictions.shape true_output = predictions[range(batch), targets] return -true_output.sum() / batch def test(model, data, loss_func): inputs, targets = data with t.no_grad(): output = model(inputs) loss = loss_func(output, targets) return loss def train(model, data, optimizer, loss_func): inputs, targets = data optimizer.zero_grad() output = model(inputs) loss = loss_func(output, targets) loss.backward() optimizer.step() return loss



トレーニング設定


cfg = Config() tokenizer = Tokenizer('()', 12, True) inputs = tokenizer.tokenize([*unbal_random_seq_parens, *random_seq_parens]) targets = t.tensor([*([0] * len(unbal_random_seq_parens)), *([1] * len(random_seq_parens))]) rand_indices = t.randperm(targets.shape[0]) rand_inputs = inputs[rand_indices, :] rand_targets = targets[rand_indices] model = Transformer(cfg) adamW = t.optim.AdamW(model.parameters(), lr=0.01)


実際のトレーニング


batch_size = 10000 train_size = int(0.7 * batch_size) epochs = 15 for epoch in range(epochs): for batch_id in range(0, rand_inputs.shape[0], batch_size): rand_inputs_batch, rand_targets_batch = rand_inputs[batch_id : batch_id + batch_size], rand_targets[batch_id : batch_id + batch_size] train_input, train_target = rand_inputs_batch[:train_size, :], rand_targets_batch[:train_size] test_input, test_target = rand_inputs_batch[train_size:, :], rand_targets_batch[train_size:] train(model, (train_input, train_target), adamW, cross_entropy_loss) test_loss = test(model, (test_input, test_target), cross_entropy_loss) print(f'Loss: {test_loss} on epoch: {epoch}/{epochs}') 


トレーニング飽和




パート 3 では、このトレーニングされたネットワークの内部を調査します。注意パターンを調べ、アクティベーション パッチングなどのメカニズム解釈可能性の診断ツールのいくつかを適用して、ネットワークがこのタスクをどのように解決したかを理解するためのメカニズム モデルを構築します。


ここまで読んでいただきありがとうございました。パート 3 でお会いしましょう!