paint-brush
LLMs vs. Leetcode (Teil 1 und 2): Transformer-Lösungen für algorithmische Probleme verstehenvon@boluben
20,263 Lesungen
20,263 Lesungen

LLMs vs. Leetcode (Teil 1 und 2): Transformer-Lösungen für algorithmische Probleme verstehen

von Bolu Ben-Adeola16m2024/04/16
Read on Terminal Reader

Zu lang; Lesen

Diese Artikelserie befasst sich mit der Interpretierbarkeit von Transformer-Modellen und untersucht, wie sie Algorithmen lernen, indem sie das Problem der gültigen Klammern lösen. Sie behandelt Datengenerierung und Modelltraining und verspricht in Teil 3 einen detaillierten Einblick in Aufmerksamkeitsmuster und mechanistisches Verständnis.
featured image - LLMs vs. Leetcode (Teil 1 und 2): Transformer-Lösungen für algorithmische Probleme verstehen
Bolu Ben-Adeola HackerNoon profile picture
0-item
1-item

Im Geiste der mechanistischen Interpretierbarkeitsagenda für neuronale Netzwerke untersucht dieser Beitrag (und möglicherweise weitere, die in einer Reihe folgen) die von einem Transformer-Modell erlernten „Algorithmen“ zum Bewältigen einer engen technischen Aufgabe – einer modifizierten Version des Leetcode-Problems „Gültige Klammern“.


Obwohl der Nutzen dieser Aufgabe im Vergleich zu der allgemeineren Vorhersage des nächsten Tokens, die man in einem LLM erwarten würde, viel geringer ist, hilft uns die Übung dabei, einige der frühen Intuitionen, Untersuchungswerkzeuge und allgemeinen erkenntnistheoretischen Methoden zu erkunden, die typischerweise eingesetzt werden, um herauszufinden, was Modelle tun (und woher wir das wissen).


Die monatlichen ARENA Mechinterp-Herausforderungen hatten großen Einfluss auf diesen Beitrag, und die ersten Aufgaben werden sich daraus ergeben. (Sie sollten sich das Programm unbedingt ansehen.)


Serienstruktur:

  1. Wählen Sie ein Leetcode-Problem als Aufgabe. (Teil 1)
  2. Trainieren Sie darauf ein minimal funktionsfähiges Transformer-Modell. (Teil 2)
  3. Untersuchen Sie, was das Modell gelernt hat. (Teil 3)


Teil 1: Problem

Das Problem mit gültigen Klammern, wie es bei Leetcode auftritt:



Einige modifizierte Einschränkungen des Problems, die wir für die Aufgabe verwenden werden:


  • Die einzigen zulässigen Zeichen sind „(“ und „)“.
    • Dadurch entfällt die Notwendigkeit, Fälle wie "([)]" zu behandeln.


  • Die maximale Eingabesequenz ist 40 Zeichen lang.
    • Um unser Modell für schnelle Iterationen klein zu halten.



Beispiele

„(((())))“ → Gültig

„()()()(“ → Ungültig

“)()()()(” → Ungültig



Vanillelösung

 def isValid(self, s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


Hinweise zu Fehlerfällen:


  1. nesting_depth ≠ 0 am Ende der Sequenz

    „()()()((“ → Ungültig


    Aus diesem Grund ist es bis zum Schluss nicht offensichtlich, dass etwas nicht stimmt, wenn wir sehen, dass die zuletzt geöffneten Klammern keine begleitende Klammer haben. Zu beachten ist, dass es bis zum Schluss keinen Punkt in der Sequenz gibt, an dem wir genug Informationen hatten, um zu wissen, dass etwas nicht stimmt.


  2. nesting_depth < 0 an jedem Punkt in der Sequenz

    Beispiel: “())()()(” → Ungültig


    In diesem Fall hingegen liegen an der dritten Stelle genügend Informationen vor, um zu wissen, dass die Gültigkeit der Sequenz nicht wiederherstellbar ist, sodass wir vorzeitig aufgeben können.


    Zu beachten ist, dass dieses Beispiel den ersten Fehlertest bestanden hätte, da die nesting_depth am Ende gleich 0 gewesen wäre. Dieser Testfall hilft uns also nicht nur, frühzeitig aufzuhören, er ist lebenswichtig. Dasselbe gilt für das erste Fehlerbeispiel, bei dem Test 2 bestanden worden wäre.



Nun erwarten wir nicht, dass ein autoregressives Transformer-Modell das Problem auf genau dieselbe Weise löst, da seine Architektur leicht andere Mechanismen bietet als das einmalige Durchlaufen der Sequenz und die Überprüfung, ob alles in Ordnung ist. Wir wissen jedoch mit Sicherheit, dass die Transformer-Architektur (und andere Sequenzverarbeitungsarchitekturen) zumindest in der Lage sind, Informationen über alle Elemente einer Sequenz zu ermitteln und zu verarbeiten . Es ist wichtig, sich daran zu erinnern, dass die Lösung zwar anders aussehen mag, die Struktur des Problems jedoch dieselbe ist und die harten Grenzen dessen, was bekannt ist und wo in der Sequenz sie sich befindet, weiterhin gelten, unabhängig davon, ob es sich um eine Schleife und if-Anweisungen oder ein Ensemble von Self-Attention-Sweeps und MLP-Nichtlinearitäten handelt.


Die interessante Frage ist dann, wie diese Architektur diese Informationen nutzt und ob sie mit den vorhandenen Werkzeugen leicht erkennbar sind. Denn es lässt sich bei einer ausreichend leistungsfähigen Lösung jeder Architektur nicht vermeiden , dass sie nicht zumindest auf die beiden oben genannten Fehlerfälle testet.


Dies ist einer der Vorteile von Spielzeugproblemen. Mit diesen harten Garantien erhalten wir eine ausreichend verständliche, enge Aufgabe, die, wie wir bald sehen werden, als Grundlage für die Untersuchung dienen kann.


Teil 2: Daten und Modell

Vorbereitung der Trainingsdaten

Hier sind einige Zielmerkmale, die wir mit der Datengenerierung anstreben:


  • Eine gleiche Anzahl symmetrischer und unsymmetrischer Saiten.

  • Zeichenfolgen haben eine gerade Länge, da Zeichenfolgen mit ungerader Länge offensichtlich unausgeglichen sind. Für das Modell wäre dies keine sehr interessante Heuristik zum Erlernen.

  • Alle Zeichenfolgenlängen (2–40) sollten gleich wahrscheinlich sein.

  • Bei einer gegebenen Zeichenfolgelänge sollten alle möglichen Verschachtelungstiefen der Klammern gleich wahrscheinlich sein.


Ein gemeinsames Thema ist offensichtlich: Wir versuchen, jede denkbare Verteilungsstatistik mit gleicher Wahrscheinlichkeit zu erstellen, um die Verzerrung in jede beliebige Richtung zu reduzieren, Robustheit zu gewährleisten und offensichtliche Heuristiken für schnelle Erfolge als Option für das Modell auszuschließen. Um Fehlerfälle zu generieren, werden wir zunächst gültige Klammern mit den oben aufgeführten Garantien generieren und dann die Hälfte davon mutieren, um sie unausgeglichen zu machen.


 from random import randint, randrange, sample from typing import List, Tuple, Union, Optional, Callable, Dict from jaxtyping import Float, Int import torch as t from torch import Tensor import plotly.express as px import einops from dataclasses import dataclass import math



 def isValid(s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


 assert isValid('()()((((()())())))') == True assert isValid(')()((((()())()))(') == False


Datengenerierungsschema Nr. 1: Random Walk

Der erste Versuch der Klammergenerierung führt lediglich einen Zufallsgang durch. Aber wie Sie in den Diagrammen unten sehen können, ist der Unterraum unausgeglichener Klammern viel größer als der für ausgeglichene; daher müssen wir die Stochastik anders einführen.


 PARENS = ['(', ')'] def get_random_walk_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: range_start, range_end = length_range random_parens = [ # Add 1 to make passed range_end inclusive ''.join(PARENS[randint(0, 1)] for _ in range(randrange(range_start, range_end + 1, 2))) for _ in range(parens_num) ] return random_parens



 random_parens = get_random_walk_parens(1000, (2, 10))



 random_parens[:10] # output [')(', '(((())()', ')(((()()))', '))))))', '))())()(', '))', '(())', ')()(()()()', ')()())))((', '()']



 is_valid_evals = [str(isValid(random_paren)) for random_paren in random_parens] len_evals = [len(random_paren) for random_paren in random_parens]



 fig = px.histogram(is_valid_evals, title="Count of is-balanced for random walk parentheses strings") fig.show() 






Datengenerierungsschema Nr. 2: Greedy Random Nesting Sequence

Wir können die Konstruktion einer Zeichenfolge mit balancierten Klammern in diskrete Einheiten verschachtelter Klammern aufteilen. Für diese gierige Konstruktion wird bei jedem Schritt im Prozess der Zeichenfolgengenerierung eine Verschachtelungstiefe aus einem Korb möglicher Tiefen ausgewählt (um die Zielzeichenfolgenlänge einzuhalten).


Beispielsweise sind für die Ziellänge 6 die folgenden eindeutigen Verschachtelungszerlegungen möglich:


-> [2, 1], [1, 2], [1,1,1] or [3]

Corresponding to:

-> (())(), ()(()), ()()(), ((()))



 def get_balanced_parens(nest_depth: int) -> str: """Generate parentheses at the required nesting depth.""" return (PARENS[0] * nest_depth) + (PARENS[1] * nest_depth) assert get_balanced_parens(3) == '((()))'



 def get_balanced_sequence_parens(nest_depth_sequence: List[int]) -> str: """Return a parentheses string following the nesting depth sequence from a given list.""" return ''.join(get_balanced_parens(nest_depth) for nest_depth in nest_depth_sequence) assert get_balanced_sequence_parens([1,1,2,3]) == '()()(())((()))'



 def get_random_depth_sequence(target_paren_len: int) -> List[int]: depth_sequence = [] while target_paren_len > 0: depth = randint(1, target_paren_len / 2) depth_sequence.append(depth) target_paren_len -= 2 * depth return depth_sequence rand_depth_seq = get_random_depth_sequence(10) print(rand_depth_seq) # Example output: '[3, 1, 1]' assert sum([2 * depth for depth in rand_depth_seq]) == 10



 def get_random_sequence_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: random_depth_sequences = [get_random_depth_sequence( randrange(*length_range, 2) ) for _ in range(parens_num)] random_parens = [ get_balanced_sequence_parens(random_depth_sequence) for random_depth_sequence in random_depth_sequences ] return random_parens, random_depth_sequences



Holen Sie sich ausgeglichene Eltern

 random_seq_parens, depth_sequences = get_random_sequence_parens(100000, (2, 11)) is_valid_evals = [str(isValid(random_paren)) for random_paren in random_seq_parens] len_evals = [len(random_paren) for random_paren in random_seq_parens]


Schauen wir uns die Häufigkeit der Verschachtelungstiefen an


 depth_freq = {} for seq in depth_sequences: for depth in seq: depth_freq.setdefault(depth, 0) depth_freq[depth] += 1 depth_freq # output -> {2: 39814, 1: 100088, 3: 20127, 4: 9908, 5: 4012}



 depth_seq_hist = px.histogram(depth_sequences, title="Frequence of nesting depths in 'Random Nesting Depth Sequence' Output") depth_seq_hist.show() 


Verzerrte Tiefenfrequenzen




Und nun sehen wir uns die Längenfrequenzen an.


 paren_len_hist = px.histogram(len_evals, title="Frequency of string lengths") paren_len_hist.show() 


Ziemlich flache Saitenlängenfrequenzen


Hinweis zur Datengenerierung

Beachten Sie, dass zwischen den folgenden potenziellen Eigenschaften unserer Datenverteilung eine Spannung besteht.


  1. Jede Zeichenfolgenlänge ist gleich wahrscheinlich.
  2. Jede Teilzeichenfolge der Verschachtelungstiefe ist in allen Zeichenfolgen gleich wahrscheinlich.


Dies liegt daran, dass Untersequenzen mit geringer Verschachtelungstiefe häufiger in einer gegebenen zufälligen Verschachtelungssequenz auftauchen, wie in den Diagrammen oben gezeigt.


Um dieser natürlichen Tendenz einer rein zufälligen Folge entgegenzuwirken, könnten wir bei der Generierung einer gegebenen Teilzeichenfolge aus Klammern eine Stichprobe aus einer Verteilung ziehen, die so verzerrt ist, dass tiefer verschachtelte Werte wahrscheinlicher werden.

Dies wird nach einem ersten Trainingsdurchgang erneut geprüft.


 px.histogram(random_seq_parens, title="Frequency of balanced Parentheses").show() 




Erstellen ungleichmäßiger Klammern

Unser Datensatz kann nicht nur ausgeglichene Klammern enthalten. Daher können wir eine Datengenerierungsstrategie entwickeln, um aus unserem ausgeglichenen Datensatz unausgeglichene Zeichenfolgen abzuleiten.


 def _flip_idx(idx): return (idx + 1) % 2 assert _flip_idx(0) == 1 assert _flip_idx(1) == 0



 def make_parens_unbalanced(paren: str) -> str: """Take balanced-parentheses and randomly mutate it till it's unbalanced. Both the number of mutations and indices are chosen at random. """ paren_idx_dict = {'(': 0, ')': 1} paren_list = list(paren) num_flipped_positions = randint(1, len(paren)) while isValid(''.join(paren_list)): flip_points = sample(range(len(paren)), num_flipped_positions) for flip_idx in flip_points: idx_char = paren_idx_dict[paren_list[flip_idx]] flipped_idx = _flip_idx(idx_char) paren_list[flip_idx] = PARENS[flipped_idx] return ''.join(paren_list) assert not isValid(make_parens_unbalanced('((()))'))


UnBalanced Parens-Datensatz abrufen


 unbal_random_seq_parens = [make_parens_unbalanced(paren) for paren in random_seq_parens]



Modelltraining

Jetzt haben wir unsere Datensätze und werden aus Spaß unsere Transformer-Architektur von Grund auf neu schreiben.


Zuerst einige Konfigurationen


 @dataclass class Config: context_len = 12 d_vocab: int = 5 d_out_vocab: int = 2 d_model: int = 56 d_head = 28 d_mlp = 56 * 4 causal_attention = False num_heads = 2 num_layers = 3 init_range: float = 1 PAD_TOKEN_IDX = 1


Dann unser Tokenizer zum Parsen von Eingaben:


 class Tokenizer: def __init__(self, vocab: str, context_width: Int, enforce_context: bool=False): self.START_TOKEN, START_TOKEN_IDX = "<start>", 0 self.PAD_TOKEN, PAD_TOKEN_IDX = "<pad>", 1 self.END_TOKEN, END_TOKEN_IDX = "<end>", 2 util_tokens_t_to_i = {self.START_TOKEN: START_TOKEN_IDX, self.PAD_TOKEN: PAD_TOKEN_IDX, self.END_TOKEN: END_TOKEN_IDX} util_tokens_i_to_t = {START_TOKEN_IDX: self.START_TOKEN, PAD_TOKEN_IDX: self.PAD_TOKEN, END_TOKEN_IDX: self.END_TOKEN} self.enforce_context = enforce_context self.context_width = context_width self.vocab = vocab self.t_to_i = {**util_tokens_t_to_i, **{token: token_id + 3 for token_id, token in enumerate(self.vocab)}} self.i_to_t = {**util_tokens_i_to_t, **{token_id + 3: token for token_id, token in enumerate(self.vocab)}} @staticmethod def pad_sequence(sequence: str, end_token: str, pad_token: str, max_length: Int, enforce_context: bool) -> List[str]: if not enforce_context: # Truncate if sequence length is greater sequence = sequence[:max_length] else: assert len(sequence) <= max_length, f"Sequence length is greater than the max allowed data length: {max_length}" return list(sequence) + [end_token] + [pad_token] * (max_length - len(sequence)) def tokenize(self, data: Union[str, List[str]]) -> Int[Tensor, "batch seq"]: if isinstance(data, str): data = [data] def _list_tokens_to_id(tokens: List[str]) -> List[Int]: return [self.t_to_i[token] for token in tokens] # to leave room for start and end tokens max_seq_len = self.context_width - 2 data_as_tokens = [ _list_tokens_to_id([ self.START_TOKEN, *self.pad_sequence(seq, self.END_TOKEN, self.PAD_TOKEN, max_seq_len, self.enforce_context), ]) for seq in data ] return t.tensor(data_as_tokens)


(Un)Einbettungen


 class EmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_E = t.nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model)) t.nn.init.normal_(self.W_E, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq"]) -> Int[Tensor, "batch seq d_model"]: return self.W_E[x] class UnEmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_U = t.nn.Parameter(t.empty(cfg.d_model, cfg.d_out_vocab)) t.nn.init.normal_(self.W_U, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq d_model"]) -> Int[Tensor, "batch seq d_out_vocab"]: return x @ self.W_U class PositionalEmbedding(t.nn.Module): def __init__(self, cfg: Config): super().__init__() denom = t.exp( t.arange(0, cfg.d_model, 2) * -(math.log(10000.0) / cfg.d_model) ) pos = t.arange(0, cfg.context_len).unsqueeze(1) param = pos * denom P_E = t.zeros(cfg.context_len, cfg.d_model) P_E[:, 0::2] = t.sin(param) P_E[:, 1::2] = t.cos(param) P_E = P_E.unsqueeze(0) self.register_buffer("P_E", P_E) def forward(self, x): _batch, seq_len, d_model = x.shape x = x + self.P_E[..., :seq_len, :d_model].requires_grad_(False) return x


Handy Layer Norm


 class LayerNorm(t.nn.Module): def __init__(self, cfg): super().__init__() self.scale = t.nn.Parameter(t.ones(cfg.d_model)) self.bias = t.nn.Parameter(t.zeros(cfg.d_model)) def forward(self, x): mean = t.mean(x, dim=2, keepdim=True) var = t.var(x, dim=2, keepdim=True, unbiased=False) y = (x - mean) / (var + 0.00001).sqrt() return (y * self.scale) + self.bias


Und zum Schluss: Achtung!


 class AttentionLayer(t.nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32)) self.cfg = cfg self.W_Q = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_K = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_V = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_O = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_head, cfg.d_model)) self.b_Q = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_K = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_V = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_O = t.nn.Parameter(t.zeros(cfg.d_model)) t.nn.init.normal_(self.W_Q, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_K, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_V, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_O, mean=0.0, std=cfg.init_range) def forward(self, params): #TODO: revisit implementing pad_mask with hooks x, pad_mask = params Q = einops.einsum(x, self.W_Q, 'bs dm, h dm dh -> bsh dh') + self.b_Q K = einops.einsum(x, self.W_K, 'bs dm, h dm dh -> bsh dh') + self.b_K V = einops.einsum(x, self.W_V, 'bs dm, h dm dh -> bsh dh') + self.b_V attention_scores = einops.einsum(Q, K, 'b s_q h dh, b s_k h dh -> bh s_q s_k') scaled_attention_scores = attention_scores / (self.cfg.d_head ** 0.5) if self.cfg.causal_attention: scaled_attention_scores = self.apply_causal_mask(scaled_attention_scores) scaled_attention_scores = self.apply_padding_mask(scaled_attention_scores, pad_mask) attention_patterns = t.nn.Softmax(dim=-1)(scaled_attention_scores) post_attention_values = einops.einsum( attention_patterns, V, 'bh s_q s_k, b s_k h dh -> b s_q h dh' ) out = einops.einsum( post_attention_values, self.W_O, 'b s_q h dh, h dh dm -> b s_q dm' ) + self.b_O return out def apply_causal_mask(self, attention_scores): b, h, s_q, s_k = attention_scores.shape mask = t.tril(t.ones(s_q,s_k)).bool() return t.where(mask, attention_scores, self.IGNORE) def apply_padding_mask(self, attention_scores, pad_mask): return t.where(pad_mask, attention_scores, self.IGNORE)



MLP-Schichten


 class LinearLayer(t.nn.Module): def __init__(self, in_dim, out_dim, include_bias=True): super().__init__() self.include_bias = include_bias self.W = t.nn.Parameter(t.empty(in_dim, out_dim)) t.nn.init.normal_(self.W, mean=0.0, std=cfg.init_range) self.b = None if include_bias: self.b = t.zeros(out_dim) def forward(self, x: Int[Tensor, "batch seq in_dim"]) -> Int[Tensor, "batch seq out_dim"]: out = x @ self.W if self.include_bias: out = out + self.b return out class MLP(t.nn.Module): def __init__(self, cfg): super().__init__() self.in_layer = LinearLayer(cfg.d_model, cfg.d_mlp) self.out_layer = LinearLayer(cfg.d_mlp, cfg.d_model) self.non_linearity = t.nn.ReLU() def forward(self, x): post_W_in = self.in_layer(x) post_non_lin = self.non_linearity(post_W_in) return self.out_layer(post_non_lin)



Zusammenbau zu einem Transformator


 class TransformerBlock(t.nn.Module): def __init__(self, cfg): super().__init__() self.ln1 = LayerNorm(cfg) self.attention = AttentionLayer(cfg) self.ln2 = LayerNorm(cfg) self.mlp = MLP(cfg) def forward(self, params): x, pad_mask = params resid_mid = self.attention((self.ln1(x), pad_mask)) + x resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid return resid_post, pad_mask


 class Transformer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.embed = EmbedLayer(cfg) self.pos_embed = PositionalEmbedding(cfg) self.final_ln = LayerNorm(cfg) self.unembed = UnEmbedLayer(cfg) self.blocks = t.nn.Sequential(*([TransformerBlock(cfg)] * cfg.num_layers)) def forward(self, x): #TODO: revisit implementing pad_mask with hooks pad_mask = self.get_pad_mask(x) res_post_pos_embed = self.pos_embed(self.embed(x)) post_blocks, _ = self.blocks((res_post_pos_embed, pad_mask)) logits = self.unembed(self.final_ln(post_blocks)) return logits def get_pad_mask(self, x): batch, seq = x.shape return einops.repeat(x != self.cfg.PAD_TOKEN_IDX, 'batch seq -> batch 1 seq_q seq', seq_q=seq)


Trainingsutensilien


 def cross_entropy_loss(output, targets): log_probs = output.log_softmax(dim=-1) predictions = log_probs[:, 0] batch, out_dim = predictions.shape true_output = predictions[range(batch), targets] return -true_output.sum() / batch def test(model, data, loss_func): inputs, targets = data with t.no_grad(): output = model(inputs) loss = loss_func(output, targets) return loss def train(model, data, optimizer, loss_func): inputs, targets = data optimizer.zero_grad() output = model(inputs) loss = loss_func(output, targets) loss.backward() optimizer.step() return loss



Trainingskonfiguration


 cfg = Config() tokenizer = Tokenizer('()', 12, True) inputs = tokenizer.tokenize([*unbal_random_seq_parens, *random_seq_parens]) targets = t.tensor([*([0] * len(unbal_random_seq_parens)), *([1] * len(random_seq_parens))]) rand_indices = t.randperm(targets.shape[0]) rand_inputs = inputs[rand_indices, :] rand_targets = targets[rand_indices] model = Transformer(cfg) adamW = t.optim.AdamW(model.parameters(), lr=0.01)


Aktuelles Training


 batch_size = 10000 train_size = int(0.7 * batch_size) epochs = 15 for epoch in range(epochs): for batch_id in range(0, rand_inputs.shape[0], batch_size): rand_inputs_batch, rand_targets_batch = rand_inputs[batch_id : batch_id + batch_size], rand_targets[batch_id : batch_id + batch_size] train_input, train_target = rand_inputs_batch[:train_size, :], rand_targets_batch[:train_size] test_input, test_target = rand_inputs_batch[train_size:, :], rand_targets_batch[train_size:] train(model, (train_input, train_target), adamW, cross_entropy_loss) test_loss = test(model, (test_input, test_target), cross_entropy_loss) print(f'Loss: {test_loss} on epoch: {epoch}/{epochs}') 


Trainingssättigung




In Teil 3 untersuchen wir die internen Vorgänge dieses trainierten Netzwerks. Dazu betrachten wir Aufmerksamkeitsmuster und wenden einige der Diagnosetools der mechanistischen Interpretierbarkeit an, wie etwa Aktivierungspatching, um ein mechanistisches Modell zum Verständnis zu erstellen, wie das Netzwerk diese Aufgabe gelöst hat.


Danke, dass Sie bis hierher gelesen haben, und bis bald in Teil 3!