신경망 에 대한 기계론적 해석 가능성 의제의 정신에 따라 이 게시물(및 시리즈의 다른 게시물)은 좁은 기술적 작업("유효한 괄호"의 수정된 버전)을 다루기 위해 변환기 모델이 학습한 "알고리즘"을 조사합니다. 리트코드 문제입니다.
작업의 유용성은 LLM에서 기대할 수 있는 보다 일반적인 다음 토큰 예측보다 범위가 훨씬 적지만, 이 연습은 일반적으로 배포되는 초기 직관, 조사 도구 및 일반적인 인식론적 방법론 중 일부를 탐색하는 데 도움이 될 것입니다. 모델이 무엇을 하고 있는지(그리고 우리가 모델이 무엇인지 어떻게 아는지) 알아보세요.
ARENA Mechinterp 월간 챌린지는 이 게시물에 큰 영향을 미쳤으며 첫 번째 문제 세트는 거기에서 나올 것입니다. ( 프로그램을 꼭 확인하셔야 합니다.)
Leetcode에서 볼 수 있는 유효한 괄호 문제:
작업에 사용할 문제에 대한 일부 수정된 제약 조건은 다음과 같습니다.
이렇게 하면 "([)]"와 같은 사례를 처리할 필요가 없습니다.
예
“(((())))” → 유효
“()()()(” → 유효하지 않음
“)()()()(” → 유효하지 않음
def isValid(self, s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0
실패 사례에 대한 참고 사항:
시퀀스 끝에서nesting_length ≠ 0
“()()()((” → 유효하지 않음
이를 위해 마지막으로 열린 괄호에 동반 괄호가 없는 것을 볼 때 마지막까지 잘못된 것이 무엇인지 분명하지 않습니다. 주목해야 할 점은 시퀀스가 끝날 때까지 무언가 잘못되었음을 알 수 있는 충분한 정보가 있는 지점이 없다는 것입니다.
시퀀스의 어느 지점에서나 중첩_깊이 < 0
예: “())()()(” → 유효하지 않음
반면에 이 경우에는 시퀀스의 유효성이 복구 불가능하다는 것을 알 수 있는 세 번째 위치의 정보가 충분하므로 조기 종료라고 할 수 있습니다.
주목해야 할 점은 이 예제는 마지막의 nesting_depth
0이었기 때문에 첫 번째 실패 테스트를 통과했다는 것입니다. 따라서 이 테스트 사례는 우리가 일찍 중지하는 데 도움이 될 뿐만 아니라 매우 중요합니다. 테스트 2를 통과한 첫 번째 실패 사례에도 동일하게 적용됩니다.
이제 우리는 자동 회귀 변환기 모델이 문제를 똑같은 방식으로 해결할 것이라고 기대하지 않습니다. 아키텍처가 시퀀스를 한 번 반복하고 모든 것이 올바른지 확인하는 것과 약간 다른 메커니즘을 제공하기 때문입니다. 그러나 우리는 변환기 아키텍처(및 기타 시퀀스 처리 아키텍처)가 최소한 시퀀스의 모든 요소에 대한 정보를 검색 하고 처리 할 수 있다는 것을 확실히 알고 있습니다. 해결책은 다르게 보일 수 있지만 문제의 구조는 동일하며 알려진 것과 시퀀스의 어디에 있는지에 대한 엄격한 경계는 루프, if 문 또는 자기 집합이든 계속해서 사실이라는 점을 기억하는 것이 중요합니다. -주의 집중 및 MLP 비선형성.
흥미로운 질문은 이 아키텍처가 이 정보를 어떻게 활용하는지, 그리고 기존 도구로 쉽게 식별할 수 있는지입니다. 어떤 아키텍처에서든 충분히 성능이 뛰어난 솔루션이 적어도 위의 두 가지 실패 사례를 테스트하지 않는 것은 불가피 하기 때문입니다.
이것은 장난감 문제의 장점 중 하나입니다. 우리는 곧 보게 될 조사에 정보를 제공하는 데 도움이 될 수 있는 이러한 엄격한 보증을 통해 충분히 이해된 좁은 작업을 수행합니다.
데이터 생성을 통해 우리가 추구하는 몇 가지 목표 특성은 다음과 같습니다.
같은 수의 밸런스 스트링과 언밸런스 스트링.
홀수 길이의 스트링은 분명히 불균형이기 때문에 스트링의 길이는 짝수입니다. 이는 모델이 학습하기에 매우 흥미로운 경험적 방법이 아닐 것입니다.
모든 문자열 길이(2-40)는 확률이 동일해야 합니다.
주어진 문자열 길이에 대해 모든 잠재적인 괄호 중첩 깊이는 동일해야 합니다.
공통 주제는 명백합니다. 우리는 주어진 방향에서 편향을 줄이고 견고성을 보장하며 모델 옵션으로 명백한 빠른 승리 휴리스틱을 거부하기 위해 생각할 수 있는 모든 분포 통계를 동등하게 만들려고 노력하고 있습니다. 실패 사례를 생성하기 위해 먼저 위에 나열된 보장을 사용하여 유효한 괄호를 생성한 다음 그 중 절반을 변경하여 불균형을 얻습니다.
from random import randint, randrange, sample from typing import List, Tuple, Union, Optional, Callable, Dict from jaxtyping import Float, Int import torch as t from torch import Tensor import plotly.express as px import einops from dataclasses import dataclass import math
def isValid(s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0
assert isValid('()()((((()())())))') == True assert isValid(')()((((()())()))(') == False
괄호 생성의 첫 번째 시도는 무작위 이동을 수행합니다. 그러나 아래 그림에서 볼 수 있듯이 불균형 괄호의 부분 공간은 균형 괄호보다 훨씬 큽니다. 그래서 우리는 확률론을 다르게 도입해야 할 것입니다.
PARENS = ['(', ')'] def get_random_walk_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: range_start, range_end = length_range random_parens = [ # Add 1 to make passed range_end inclusive ''.join(PARENS[randint(0, 1)] for _ in range(randrange(range_start, range_end + 1, 2))) for _ in range(parens_num) ] return random_parens
random_parens = get_random_walk_parens(1000, (2, 10))
random_parens[:10] # output [')(', '(((())()', ')(((()()))', '))))))', '))())()(', '))', '(())', ')()(()()()', ')()())))((', '()']
is_valid_evals = [str(isValid(random_paren)) for random_paren in random_parens] len_evals = [len(random_paren) for random_paren in random_parens]
fig = px.histogram(is_valid_evals, title="Count of is-balanced for random walk parentheses strings") fig.show()
균형 잡힌 괄호 문자열의 구성을 중첩된 괄호의 개별 단위로 분해할 수 있습니다. 이 탐욕적 구성의 경우 문자열 생성 프로세스의 각 단계에서 실행 가능한 깊이 바구니에서 중첩 깊이가 선택됩니다(대상 문자열 길이를 고려하기 위해).
예를 들어 대상 길이가 6
인 경우 다음과 같은 고유한 중첩 분해가 가능합니다.
-> [2, 1], [1, 2], [1,1,1] or [3]
Corresponding to:
-> (())(), ()(()), ()()(), ((()))
def get_balanced_parens(nest_depth: int) -> str: """Generate parentheses at the required nesting depth.""" return (PARENS[0] * nest_depth) + (PARENS[1] * nest_depth) assert get_balanced_parens(3) == '((()))'
def get_balanced_sequence_parens(nest_depth_sequence: List[int]) -> str: """Return a parentheses string following the nesting depth sequence from a given list.""" return ''.join(get_balanced_parens(nest_depth) for nest_depth in nest_depth_sequence) assert get_balanced_sequence_parens([1,1,2,3]) == '()()(())((()))'
def get_random_depth_sequence(target_paren_len: int) -> List[int]: depth_sequence = [] while target_paren_len > 0: depth = randint(1, target_paren_len / 2) depth_sequence.append(depth) target_paren_len -= 2 * depth return depth_sequence rand_depth_seq = get_random_depth_sequence(10) print(rand_depth_seq) # Example output: '[3, 1, 1]' assert sum([2 * depth for depth in rand_depth_seq]) == 10
def get_random_sequence_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: random_depth_sequences = [get_random_depth_sequence( randrange(*length_range, 2) ) for _ in range(parens_num)] random_parens = [ get_balanced_sequence_parens(random_depth_sequence) for random_depth_sequence in random_depth_sequences ] return random_parens, random_depth_sequences
균형 잡힌 파렌을 얻으세요
random_seq_parens, depth_sequences = get_random_sequence_parens(100000, (2, 11)) is_valid_evals = [str(isValid(random_paren)) for random_paren in random_seq_parens] len_evals = [len(random_paren) for random_paren in random_seq_parens]
중첩 깊이의 빈도를 살펴보겠습니다.
depth_freq = {} for seq in depth_sequences: for depth in seq: depth_freq.setdefault(depth, 0) depth_freq[depth] += 1 depth_freq # output -> {2: 39814, 1: 100088, 3: 20127, 4: 9908, 5: 4012}
depth_seq_hist = px.histogram(depth_sequences, title="Frequence of nesting depths in 'Random Nesting Depth Sequence' Output") depth_seq_hist.show()
이제 길이 주파수를 살펴보겠습니다.
paren_len_hist = px.histogram(len_evals, title="Frequency of string lengths") paren_len_hist.show()
데이터 배포의 다음과 같은 잠재적 속성 사이에는 긴장이 있습니다.
이는 위의 플롯에 표시된 것처럼 낮은 중첩 깊이 하위 시퀀스가 주어진 무작위 중첩 시퀀스에 나타날 더 많은 기회를 갖기 때문입니다.
순전히 무작위 시퀀스의 이러한 자연스러운 경향에 대응하기 위해 주어진 괄호의 하위 문자열을 생성할 때 왜곡된 분포에서 샘플링하여 더 깊은 중첩 값을 더 가능성 있게 만들 수 있습니다.
이것은 훈련의 첫 번째 통과 후에 다시 논의될 것입니다.
px.histogram(random_seq_parens, title="Frequency of balanced Parentheses").show()
우리 데이터 세트는 균형 잡힌 괄호만 가질 수 없습니다. 따라서 균형 잡힌 데이터 세트에서 불균형 문자열을 파생시키는 데이터 생성 전략을 만들 수 있습니다.
def _flip_idx(idx): return (idx + 1) % 2 assert _flip_idx(0) == 1 assert _flip_idx(1) == 0
def make_parens_unbalanced(paren: str) -> str: """Take balanced-parentheses and randomly mutate it till it's unbalanced. Both the number of mutations and indices are chosen at random. """ paren_idx_dict = {'(': 0, ')': 1} paren_list = list(paren) num_flipped_positions = randint(1, len(paren)) while isValid(''.join(paren_list)): flip_points = sample(range(len(paren)), num_flipped_positions) for flip_idx in flip_points: idx_char = paren_idx_dict[paren_list[flip_idx]] flipped_idx = _flip_idx(idx_char) paren_list[flip_idx] = PARENS[flipped_idx] return ''.join(paren_list) assert not isValid(make_parens_unbalanced('((()))'))
불균형한 Parens 데이터세트 가져오기
unbal_random_seq_parens = [make_parens_unbalanced(paren) for paren in random_seq_parens]
이제 데이터 세트가 있으므로 재미삼아 Transformer 아키텍처를 처음부터 작성하겠습니다.
먼저 일부 구성
@dataclass class Config: context_len = 12 d_vocab: int = 5 d_out_vocab: int = 2 d_model: int = 56 d_head = 28 d_mlp = 56 * 4 causal_attention = False num_heads = 2 num_layers = 3 init_range: float = 1 PAD_TOKEN_IDX = 1
그런 다음 입력 구문 분석을 위한 토크나이저:
class Tokenizer: def __init__(self, vocab: str, context_width: Int, enforce_context: bool=False): self.START_TOKEN, START_TOKEN_IDX = "<start>", 0 self.PAD_TOKEN, PAD_TOKEN_IDX = "<pad>", 1 self.END_TOKEN, END_TOKEN_IDX = "<end>", 2 util_tokens_t_to_i = {self.START_TOKEN: START_TOKEN_IDX, self.PAD_TOKEN: PAD_TOKEN_IDX, self.END_TOKEN: END_TOKEN_IDX} util_tokens_i_to_t = {START_TOKEN_IDX: self.START_TOKEN, PAD_TOKEN_IDX: self.PAD_TOKEN, END_TOKEN_IDX: self.END_TOKEN} self.enforce_context = enforce_context self.context_width = context_width self.vocab = vocab self.t_to_i = {**util_tokens_t_to_i, **{token: token_id + 3 for token_id, token in enumerate(self.vocab)}} self.i_to_t = {**util_tokens_i_to_t, **{token_id + 3: token for token_id, token in enumerate(self.vocab)}} @staticmethod def pad_sequence(sequence: str, end_token: str, pad_token: str, max_length: Int, enforce_context: bool) -> List[str]: if not enforce_context: # Truncate if sequence length is greater sequence = sequence[:max_length] else: assert len(sequence) <= max_length, f"Sequence length is greater than the max allowed data length: {max_length}" return list(sequence) + [end_token] + [pad_token] * (max_length - len(sequence)) def tokenize(self, data: Union[str, List[str]]) -> Int[Tensor, "batch seq"]: if isinstance(data, str): data = [data] def _list_tokens_to_id(tokens: List[str]) -> List[Int]: return [self.t_to_i[token] for token in tokens] # to leave room for start and end tokens max_seq_len = self.context_width - 2 data_as_tokens = [ _list_tokens_to_id([ self.START_TOKEN, *self.pad_sequence(seq, self.END_TOKEN, self.PAD_TOKEN, max_seq_len, self.enforce_context), ]) for seq in data ] return t.tensor(data_as_tokens)
(비)임베딩
class EmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_E = t.nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model)) t.nn.init.normal_(self.W_E, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq"]) -> Int[Tensor, "batch seq d_model"]: return self.W_E[x] class UnEmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_U = t.nn.Parameter(t.empty(cfg.d_model, cfg.d_out_vocab)) t.nn.init.normal_(self.W_U, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq d_model"]) -> Int[Tensor, "batch seq d_out_vocab"]: return x @ self.W_U class PositionalEmbedding(t.nn.Module): def __init__(self, cfg: Config): super().__init__() denom = t.exp( t.arange(0, cfg.d_model, 2) * -(math.log(10000.0) / cfg.d_model) ) pos = t.arange(0, cfg.context_len).unsqueeze(1) param = pos * denom P_E = t.zeros(cfg.context_len, cfg.d_model) P_E[:, 0::2] = t.sin(param) P_E[:, 1::2] = t.cos(param) P_E = P_E.unsqueeze(0) self.register_buffer("P_E", P_E) def forward(self, x): _batch, seq_len, d_model = x.shape x = x + self.P_E[..., :seq_len, :d_model].requires_grad_(False) return x
핸디 레이어 표준
class LayerNorm(t.nn.Module): def __init__(self, cfg): super().__init__() self.scale = t.nn.Parameter(t.ones(cfg.d_model)) self.bias = t.nn.Parameter(t.zeros(cfg.d_model)) def forward(self, x): mean = t.mean(x, dim=2, keepdim=True) var = t.var(x, dim=2, keepdim=True, unbiased=False) y = (x - mean) / (var + 0.00001).sqrt() return (y * self.scale) + self.bias
그리고 마지막으로 주목!
class AttentionLayer(t.nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32)) self.cfg = cfg self.W_Q = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_K = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_V = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_O = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_head, cfg.d_model)) self.b_Q = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_K = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_V = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_O = t.nn.Parameter(t.zeros(cfg.d_model)) t.nn.init.normal_(self.W_Q, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_K, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_V, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_O, mean=0.0, std=cfg.init_range) def forward(self, params): #TODO: revisit implementing pad_mask with hooks x, pad_mask = params Q = einops.einsum(x, self.W_Q, 'bs dm, h dm dh -> bsh dh') + self.b_Q K = einops.einsum(x, self.W_K, 'bs dm, h dm dh -> bsh dh') + self.b_K V = einops.einsum(x, self.W_V, 'bs dm, h dm dh -> bsh dh') + self.b_V attention_scores = einops.einsum(Q, K, 'b s_q h dh, b s_k h dh -> bh s_q s_k') scaled_attention_scores = attention_scores / (self.cfg.d_head ** 0.5) if self.cfg.causal_attention: scaled_attention_scores = self.apply_causal_mask(scaled_attention_scores) scaled_attention_scores = self.apply_padding_mask(scaled_attention_scores, pad_mask) attention_patterns = t.nn.Softmax(dim=-1)(scaled_attention_scores) post_attention_values = einops.einsum( attention_patterns, V, 'bh s_q s_k, b s_k h dh -> b s_q h dh' ) out = einops.einsum( post_attention_values, self.W_O, 'b s_q h dh, h dh dm -> b s_q dm' ) + self.b_O return out def apply_causal_mask(self, attention_scores): b, h, s_q, s_k = attention_scores.shape mask = t.tril(t.ones(s_q,s_k)).bool() return t.where(mask, attention_scores, self.IGNORE) def apply_padding_mask(self, attention_scores, pad_mask): return t.where(pad_mask, attention_scores, self.IGNORE)
MLP 레이어
class LinearLayer(t.nn.Module): def __init__(self, in_dim, out_dim, include_bias=True): super().__init__() self.include_bias = include_bias self.W = t.nn.Parameter(t.empty(in_dim, out_dim)) t.nn.init.normal_(self.W, mean=0.0, std=cfg.init_range) self.b = None if include_bias: self.b = t.zeros(out_dim) def forward(self, x: Int[Tensor, "batch seq in_dim"]) -> Int[Tensor, "batch seq out_dim"]: out = x @ self.W if self.include_bias: out = out + self.b return out class MLP(t.nn.Module): def __init__(self, cfg): super().__init__() self.in_layer = LinearLayer(cfg.d_model, cfg.d_mlp) self.out_layer = LinearLayer(cfg.d_mlp, cfg.d_model) self.non_linearity = t.nn.ReLU() def forward(self, x): post_W_in = self.in_layer(x) post_non_lin = self.non_linearity(post_W_in) return self.out_layer(post_non_lin)
이를 트랜스포머에 합치면
class TransformerBlock(t.nn.Module): def __init__(self, cfg): super().__init__() self.ln1 = LayerNorm(cfg) self.attention = AttentionLayer(cfg) self.ln2 = LayerNorm(cfg) self.mlp = MLP(cfg) def forward(self, params): x, pad_mask = params resid_mid = self.attention((self.ln1(x), pad_mask)) + x resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid return resid_post, pad_mask
class Transformer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.embed = EmbedLayer(cfg) self.pos_embed = PositionalEmbedding(cfg) self.final_ln = LayerNorm(cfg) self.unembed = UnEmbedLayer(cfg) self.blocks = t.nn.Sequential(*([TransformerBlock(cfg)] * cfg.num_layers)) def forward(self, x): #TODO: revisit implementing pad_mask with hooks pad_mask = self.get_pad_mask(x) res_post_pos_embed = self.pos_embed(self.embed(x)) post_blocks, _ = self.blocks((res_post_pos_embed, pad_mask)) logits = self.unembed(self.final_ln(post_blocks)) return logits def get_pad_mask(self, x): batch, seq = x.shape return einops.repeat(x != self.cfg.PAD_TOKEN_IDX, 'batch seq -> batch 1 seq_q seq', seq_q=seq)
훈련 유틸리티
def cross_entropy_loss(output, targets): log_probs = output.log_softmax(dim=-1) predictions = log_probs[:, 0] batch, out_dim = predictions.shape true_output = predictions[range(batch), targets] return -true_output.sum() / batch def test(model, data, loss_func): inputs, targets = data with t.no_grad(): output = model(inputs) loss = loss_func(output, targets) return loss def train(model, data, optimizer, loss_func): inputs, targets = data optimizer.zero_grad() output = model(inputs) loss = loss_func(output, targets) loss.backward() optimizer.step() return loss
훈련 구성
cfg = Config() tokenizer = Tokenizer('()', 12, True) inputs = tokenizer.tokenize([*unbal_random_seq_parens, *random_seq_parens]) targets = t.tensor([*([0] * len(unbal_random_seq_parens)), *([1] * len(random_seq_parens))]) rand_indices = t.randperm(targets.shape[0]) rand_inputs = inputs[rand_indices, :] rand_targets = targets[rand_indices] model = Transformer(cfg) adamW = t.optim.AdamW(model.parameters(), lr=0.01)
실제 훈련
batch_size = 10000 train_size = int(0.7 * batch_size) epochs = 15 for epoch in range(epochs): for batch_id in range(0, rand_inputs.shape[0], batch_size): rand_inputs_batch, rand_targets_batch = rand_inputs[batch_id : batch_id + batch_size], rand_targets[batch_id : batch_id + batch_size] train_input, train_target = rand_inputs_batch[:train_size, :], rand_targets_batch[:train_size] test_input, test_target = rand_inputs_batch[train_size:, :], rand_targets_batch[train_size:] train(model, (train_input, train_target), adamW, cross_entropy_loss) test_loss = test(model, (test_input, test_target), cross_entropy_loss) print(f'Loss: {test_loss} on epoch: {epoch}/{epochs}')
3부에서는 훈련된 네트워크의 내부를 조사합니다. 우리는 주의 패턴을 살펴보고 활성화 패치와 같은 기계적 해석 가능성의 일부 진단 도구를 적용하여 네트워크가 이 작업을 어떻게 해결했는지 이해하는 기계적 모델을 구축함으로써 이를 수행할 것입니다.
여기까지 읽어주셔서 감사합니다. 곧 3부에서 만나보실 수 있습니다!