paint-brush
Yüksek Lisans ve Leetcode (Bölüm 1 ve 2): Transformers'ın Algoritmik Sorunlara Çözümlerini Anlamakile@boluben
1,458 okumalar
1,458 okumalar

Yüksek Lisans ve Leetcode (Bölüm 1 ve 2): Transformers'ın Algoritmik Sorunlara Çözümlerini Anlamak

ile Boluwatife Ben-Adeola16m2024/04/16
Read on Terminal Reader

Çok uzun; Okumak

Bu makale serisi, Transformer modellerinin yorumlanabilirliğini derinlemesine inceliyor ve Geçerli Parantez sorununu çözerek algoritmaları nasıl öğrendiklerini araştırıyor. Veri oluşturmayı, model eğitimini kapsıyor ve Bölüm 3'te dikkat kalıplarına ve mekanik anlayışa derinlemesine bir bakış vaat ediyor.
featured image - Yüksek Lisans ve Leetcode (Bölüm 1 ve 2): Transformers'ın Algoritmik Sorunlara Çözümlerini Anlamak
Boluwatife Ben-Adeola HackerNoon profile picture
0-item
1-item

Sinir ağları için Mekanistik yorumlanabilirlik gündemi ruhuna uygun olarak, bu yazı (ve belki bir seri halinde takip edilecek diğerleri), dar bir teknik görevin üstesinden gelmek için bir transformatör modeli tarafından öğrenilen "algoritmaları" ("Geçerli Parantezlerin" değiştirilmiş bir versiyonu) araştırıyor. Leet kodu sorunu.


Görevin faydası, bir Yüksek Lisans'da beklediğiniz daha genel bir sonraki belirteç tahmininden kapsam olarak çok daha mütevazı olsa da, bu alıştırma, tipik olarak kullanılan ilk sezgileri, araştırma araçlarını ve genel epistemolojik metodolojileri keşfetmemize yardımcı olacaktır. modellerin ne yaptığını biliyoruz (ve olduklarını nasıl biliyoruz).


ARENA Mechinterp'in aylık mücadeleleri bu gönderi üzerinde büyük bir etkiye sahipti ve ilk sorunlar dizisi oradan gelecektir. ( Programa mutlaka göz atmalısınız.)


Seri yapısı:

  1. Görev olarak bir Leetcode problemi seçin. (Bölüm 1)
  2. Üzerinde minimum geçerli Transformer modelini eğitin. (Bölüm 2)
  3. Modelin ne öğrendiğini araştırın. (Bölüm 3)


Bölüm 1: Sorun

Leetcode'da görülen Geçerli Parantez sorunu:



Görev için kullanacağımız soruna ilişkin bazı değiştirilmiş kısıtlamalar:


  • Kabul edilebilir karakterler yalnızca “(” ve “)”dir
    • Bu, "([)]" gibi durumların ele alınması ihtiyacını ortadan kaldırır.


  • Maksimum giriş sırası 40 karakter uzunluğundadır.
    • Hızlı yinelemeler için modelimizi küçük tutmaya yardımcı olmak.



Örnekler

“((((())))” → Geçerli

“()()()(” → Geçersiz

“)()()()(” → Geçersiz



Vanilya Çözümü

 def isValid(self, s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


Arıza durumlarına ilişkin notlar:


  1. yuvalama_derinliği ≠ 0 dizinin sonunda

    “()()()((” → Geçersiz


    Bunun için son açılan parantezlerin yanında bir parantez olmadığını gördüğümüzde sonuna kadar bir sorun olduğu belli olmuyor. Unutulmaması gereken nokta, dizide en sonuna kadar bir şeylerin yanlış olduğunu bilmek için yeterli bilgiye sahip olduğumuz hiçbir noktanın olmamasıdır.


  2. yuvalama_derinliği < 0 dizinin herhangi bir noktasında

    örnek: “())()()(” → Geçersiz


    Bu durumda ise üçüncü pozisyona göre dizinin geçerliliğinin kurtarılamaz olduğunu bilmek için yeterli bilgi vardır, dolayısıyla erkenden vazgeçebiliriz.


    Unutulmaması gereken bir nokta da, bu örneğin, sondaki nesting_depth 0'a eşit olacağı için ilk başarısızlık testini geçmiş olmasıdır. Yani bu test durumu sadece erken durmamıza yardımcı olmakla kalmıyor, aynı zamanda hayati önem taşıyor. Aynı durum, test 2'yi geçeceği ilk arıza durumu örneği için de geçerlidir.



Şimdi, mimarisinin dizi boyunca bir kez döngü yapıp her şeyin yolunda olup olmadığını kontrol etmekten biraz farklı mekanizmalar sağladığı göz önüne alındığında, bir otoregresif transformatör modelinin sorunu tam olarak aynı şekilde çözmesini beklemiyoruz. Bununla birlikte, transformatör mimarisinin (ve diğer dizi işleme mimarilerinin) en azından bir dizideki tüm öğeler hakkındaki bilgileri keşfedip işleyebildiğinden eminiz. Çözüm farklı görünse de problemin yapısının aynı olduğunu ve ister bir döngü, ister if-ifadeleri veya bir öz topluluğu olsun, bilinenlere ve sıranın neresinde olduğuna dair katı sınırların doğru olmaya devam ettiğini hatırlamak önemlidir. -dikkat taramaları ve MLP doğrusalsızlıkları.


O zaman ilginç olan soru, bu mimarinin bu bilgiyi nasıl kullandığı ve bunun mevcut araçlarla kolayca fark edilip edilemeyeceğidir; çünkü herhangi bir mimarinin yeterince performanslı bir çözümünün en azından yukarıdaki iki arıza durumunu test etmemesi kaçınılmazdır .


Oyuncak problemlerinin avantajlarından biri de budur; Yakında göreceğimiz gibi, soruşturmanın bilgilendirilmesine yardımcı olabilecek bu katı garantilerle, yeterince anlaşılmış dar bir görevle karşı karşıyayız.


Bölüm 2: Veri ve Model

Eğitim Verilerinin Hazırlanması

Veri oluşturmayla hedeflediğimiz bazı hedef özellikler şunlardır:


  • Eşit sayıda dengeli ve dengesiz dizi.

  • Tek uzunluktaki bir dize açıkça dengesiz olduğundan, dizeler eşit uzunlukta olacaktır; bu modelin öğrenmesi için çok ilginç bir buluşsal yöntem olmayacaktır.

  • Tüm dize uzunlukları (2-40) eşit derecede muhtemel olmalıdır.

  • Belirli bir dize uzunluğu için, tüm potansiyel parantezlerin iç içe geçme derinlikleri eşit derecede muhtemel olmalıdır.


Ortak bir tema ortada: Düşünülebilir her dağıtım istatistiğini, herhangi bir yöndeki önyargıyı azaltmak, sağlamlığı sağlamak ve model için bir seçenek olarak bariz hızlı kazanma buluşsal yöntemini reddetmek için eşit derecede muhtemel hale getirmeye çalışıyoruz. Başarısızlık durumlarını oluşturmak için, önce yukarıda listelenen garantilerle geçerli parantezleri oluşturacağız ve ardından bunların yarısını dengesiz hale getirecek şekilde değiştireceğiz.


 from random import randint, randrange, sample from typing import List, Tuple, Union, Optional, Callable, Dict from jaxtyping import Float, Int import torch as t from torch import Tensor import plotly.express as px import einops from dataclasses import dataclass import math



 def isValid(s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


 assert isValid('()()((((()())())))') == True assert isValid(')()((((()())()))(') == False


Veri Üretim Şeması #1: Rastgele Yürüyüş

Parantez oluşturmadaki ilk girişim sadece rastgele bir yürüyüş yapar. Ancak aşağıdaki grafikte görebileceğiniz gibi, dengesiz parantezlerin alt uzayı, dengeli olanlardan çok daha büyüktür; bu yüzden stokastisiteyi farklı bir şekilde tanıtmamız gerekecek.


 PARENS = ['(', ')'] def get_random_walk_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: range_start, range_end = length_range random_parens = [ # Add 1 to make passed range_end inclusive ''.join(PARENS[randint(0, 1)] for _ in range(randrange(range_start, range_end + 1, 2))) for _ in range(parens_num) ] return random_parens



 random_parens = get_random_walk_parens(1000, (2, 10))



 random_parens[:10] # output [')(', '(((())()', ')(((()()))', '))))))', '))())()(', '))', '(())', ')()(()()()', ')()())))((', '()']



 is_valid_evals = [str(isValid(random_paren)) for random_paren in random_parens] len_evals = [len(random_paren) for random_paren in random_parens]



 fig = px.histogram(is_valid_evals, title="Count of is-balanced for random walk parentheses strings") fig.show() 






Veri Üretim Şeması #2: Açgözlü Rastgele Yerleştirme Dizisi

Dengeli bir parantez dizisinin yapısını iç içe geçmiş parantezlerin ayrı birimlerine ayırabiliriz. Bu açgözlü yapı için, bir tel oluşturma sürecinin her adımında, geçerli derinliklerden oluşan bir sepetten bir yuvalama derinliği seçilir (hedef telin uzunluğuna saygı göstermek için).


Örneğin hedef uzunluğu 6 için aşağıdaki benzersiz yuvalama ayrıştırmaları mümkündür:


-> [2, 1], [1, 2], [1,1,1] or [3]

Corresponding to:

-> (())(), ()(()), ()()(), ((()))



 def get_balanced_parens(nest_depth: int) -> str: """Generate parentheses at the required nesting depth.""" return (PARENS[0] * nest_depth) + (PARENS[1] * nest_depth) assert get_balanced_parens(3) == '((()))'



 def get_balanced_sequence_parens(nest_depth_sequence: List[int]) -> str: """Return a parentheses string following the nesting depth sequence from a given list.""" return ''.join(get_balanced_parens(nest_depth) for nest_depth in nest_depth_sequence) assert get_balanced_sequence_parens([1,1,2,3]) == '()()(())((()))'



 def get_random_depth_sequence(target_paren_len: int) -> List[int]: depth_sequence = [] while target_paren_len > 0: depth = randint(1, target_paren_len / 2) depth_sequence.append(depth) target_paren_len -= 2 * depth return depth_sequence rand_depth_seq = get_random_depth_sequence(10) print(rand_depth_seq) # Example output: '[3, 1, 1]' assert sum([2 * depth for depth in rand_depth_seq]) == 10



 def get_random_sequence_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: random_depth_sequences = [get_random_depth_sequence( randrange(*length_range, 2) ) for _ in range(parens_num)] random_parens = [ get_balanced_sequence_parens(random_depth_sequence) for random_depth_sequence in random_depth_sequences ] return random_parens, random_depth_sequences



Dengeli Ebeveynler Alın

 random_seq_parens, depth_sequences = get_random_sequence_parens(100000, (2, 11)) is_valid_evals = [str(isValid(random_paren)) for random_paren in random_seq_parens] len_evals = [len(random_paren) for random_paren in random_seq_parens]


Yuvalama derinliklerinin frekanslarını görelim


 depth_freq = {} for seq in depth_sequences: for depth in seq: depth_freq.setdefault(depth, 0) depth_freq[depth] += 1 depth_freq # output -> {2: 39814, 1: 100088, 3: 20127, 4: 9908, 5: 4012}



 depth_seq_hist = px.histogram(depth_sequences, title="Frequence of nesting depths in 'Random Nesting Depth Sequence' Output") depth_seq_hist.show() 


Çarpık derinlik frekansları




Şimdi uzunluk frekanslarını görelim.


 paren_len_hist = px.histogram(len_evals, title="Frequency of string lengths") paren_len_hist.show() 


Oldukça düz dizi uzunluğu frekansları


Veri Oluşturma Notu

Veri dağıtımımızın aşağıdaki potansiyel özellikleri arasında bir gerilim olduğunu unutmayın.


  1. Her dize uzunluğu eşit derecede olasıdır.
  2. Her yuvalama derinliği alt dizisi, tüm dizelerde eşit derecede olasıdır.


Bunun nedeni, düşük yuvalama derinliğine sahip alt dizilerin, yukarıdaki grafikte gösterildiği gibi, belirli bir rastgele yuvalama dizisinde ortaya çıkma konusunda daha fazla fırsata sahip olmasıdır.


Tamamen rastgele dizilimin bu doğal eğilimine karşı koymak için, parantezlerin belirli bir altdizisini oluştururken, daha derin yuva değerlerini daha olası hale getirmek için çarpık bir dağılımdan örnek alabiliriz.

Bu, antrenmandaki ilk geçişten sonra tekrar ele alınacaktır.


 px.histogram(random_seq_parens, title="Frequency of balanced Parentheses").show() 




Dengesiz Parantez Oluşturma

Veri setimiz yalnızca dengeli parantezlerden oluşamaz. Böylece dengeli veri kümemizden dengesiz dizeler türetmek için bir veri oluşturma stratejisi oluşturabiliriz.


 def _flip_idx(idx): return (idx + 1) % 2 assert _flip_idx(0) == 1 assert _flip_idx(1) == 0



 def make_parens_unbalanced(paren: str) -> str: """Take balanced-parentheses and randomly mutate it till it's unbalanced. Both the number of mutations and indices are chosen at random. """ paren_idx_dict = {'(': 0, ')': 1} paren_list = list(paren) num_flipped_positions = randint(1, len(paren)) while isValid(''.join(paren_list)): flip_points = sample(range(len(paren)), num_flipped_positions) for flip_idx in flip_points: idx_char = paren_idx_dict[paren_list[flip_idx]] flipped_idx = _flip_idx(idx_char) paren_list[flip_idx] = PARENS[flipped_idx] return ''.join(paren_list) assert not isValid(make_parens_unbalanced('((()))'))


Dengesiz Ebeveyn Veri Kümesini Alın


 unbal_random_seq_parens = [make_parens_unbalanced(paren) for paren in random_seq_parens]



Model Eğitimi

Artık veri setlerimiz elimizde, eğlence olsun diye Transformer mimarimizi sıfırdan yazacağız.


İlk önce bazı yapılandırmalar


 @dataclass class Config: context_len = 12 d_vocab: int = 5 d_out_vocab: int = 2 d_model: int = 56 d_head = 28 d_mlp = 56 * 4 causal_attention = False num_heads = 2 num_layers = 3 init_range: float = 1 PAD_TOKEN_IDX = 1


Daha sonra girdileri ayrıştırmak için belirtecimiz:


 class Tokenizer: def __init__(self, vocab: str, context_width: Int, enforce_context: bool=False): self.START_TOKEN, START_TOKEN_IDX = "<start>", 0 self.PAD_TOKEN, PAD_TOKEN_IDX = "<pad>", 1 self.END_TOKEN, END_TOKEN_IDX = "<end>", 2 util_tokens_t_to_i = {self.START_TOKEN: START_TOKEN_IDX, self.PAD_TOKEN: PAD_TOKEN_IDX, self.END_TOKEN: END_TOKEN_IDX} util_tokens_i_to_t = {START_TOKEN_IDX: self.START_TOKEN, PAD_TOKEN_IDX: self.PAD_TOKEN, END_TOKEN_IDX: self.END_TOKEN} self.enforce_context = enforce_context self.context_width = context_width self.vocab = vocab self.t_to_i = {**util_tokens_t_to_i, **{token: token_id + 3 for token_id, token in enumerate(self.vocab)}} self.i_to_t = {**util_tokens_i_to_t, **{token_id + 3: token for token_id, token in enumerate(self.vocab)}} @staticmethod def pad_sequence(sequence: str, end_token: str, pad_token: str, max_length: Int, enforce_context: bool) -> List[str]: if not enforce_context: # Truncate if sequence length is greater sequence = sequence[:max_length] else: assert len(sequence) <= max_length, f"Sequence length is greater than the max allowed data length: {max_length}" return list(sequence) + [end_token] + [pad_token] * (max_length - len(sequence)) def tokenize(self, data: Union[str, List[str]]) -> Int[Tensor, "batch seq"]: if isinstance(data, str): data = [data] def _list_tokens_to_id(tokens: List[str]) -> List[Int]: return [self.t_to_i[token] for token in tokens] # to leave room for start and end tokens max_seq_len = self.context_width - 2 data_as_tokens = [ _list_tokens_to_id([ self.START_TOKEN, *self.pad_sequence(seq, self.END_TOKEN, self.PAD_TOKEN, max_seq_len, self.enforce_context), ]) for seq in data ] return t.tensor(data_as_tokens)


(Un)Gömmeler


 class EmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_E = t.nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model)) t.nn.init.normal_(self.W_E, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq"]) -> Int[Tensor, "batch seq d_model"]: return self.W_E[x] class UnEmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_U = t.nn.Parameter(t.empty(cfg.d_model, cfg.d_out_vocab)) t.nn.init.normal_(self.W_U, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq d_model"]) -> Int[Tensor, "batch seq d_out_vocab"]: return x @ self.W_U class PositionalEmbedding(t.nn.Module): def __init__(self, cfg: Config): super().__init__() denom = t.exp( t.arange(0, cfg.d_model, 2) * -(math.log(10000.0) / cfg.d_model) ) pos = t.arange(0, cfg.context_len).unsqueeze(1) param = pos * denom P_E = t.zeros(cfg.context_len, cfg.d_model) P_E[:, 0::2] = t.sin(param) P_E[:, 1::2] = t.cos(param) P_E = P_E.unsqueeze(0) self.register_buffer("P_E", P_E) def forward(self, x): _batch, seq_len, d_model = x.shape x = x + self.P_E[..., :seq_len, :d_model].requires_grad_(False) return x


Kullanışlı Katman Normu


 class LayerNorm(t.nn.Module): def __init__(self, cfg): super().__init__() self.scale = t.nn.Parameter(t.ones(cfg.d_model)) self.bias = t.nn.Parameter(t.zeros(cfg.d_model)) def forward(self, x): mean = t.mean(x, dim=2, keepdim=True) var = t.var(x, dim=2, keepdim=True, unbiased=False) y = (x - mean) / (var + 0.00001).sqrt() return (y * self.scale) + self.bias


Ve son olarak Dikkat!


 class AttentionLayer(t.nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32)) self.cfg = cfg self.W_Q = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_K = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_V = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_O = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_head, cfg.d_model)) self.b_Q = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_K = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_V = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_O = t.nn.Parameter(t.zeros(cfg.d_model)) t.nn.init.normal_(self.W_Q, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_K, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_V, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_O, mean=0.0, std=cfg.init_range) def forward(self, params): #TODO: revisit implementing pad_mask with hooks x, pad_mask = params Q = einops.einsum(x, self.W_Q, 'bs dm, h dm dh -> bsh dh') + self.b_Q K = einops.einsum(x, self.W_K, 'bs dm, h dm dh -> bsh dh') + self.b_K V = einops.einsum(x, self.W_V, 'bs dm, h dm dh -> bsh dh') + self.b_V attention_scores = einops.einsum(Q, K, 'b s_q h dh, b s_k h dh -> bh s_q s_k') scaled_attention_scores = attention_scores / (self.cfg.d_head ** 0.5) if self.cfg.causal_attention: scaled_attention_scores = self.apply_causal_mask(scaled_attention_scores) scaled_attention_scores = self.apply_padding_mask(scaled_attention_scores, pad_mask) attention_patterns = t.nn.Softmax(dim=-1)(scaled_attention_scores) post_attention_values = einops.einsum( attention_patterns, V, 'bh s_q s_k, b s_k h dh -> b s_q h dh' ) out = einops.einsum( post_attention_values, self.W_O, 'b s_q h dh, h dh dm -> b s_q dm' ) + self.b_O return out def apply_causal_mask(self, attention_scores): b, h, s_q, s_k = attention_scores.shape mask = t.tril(t.ones(s_q,s_k)).bool() return t.where(mask, attention_scores, self.IGNORE) def apply_padding_mask(self, attention_scores, pad_mask): return t.where(pad_mask, attention_scores, self.IGNORE)



MLP Katmanları


 class LinearLayer(t.nn.Module): def __init__(self, in_dim, out_dim, include_bias=True): super().__init__() self.include_bias = include_bias self.W = t.nn.Parameter(t.empty(in_dim, out_dim)) t.nn.init.normal_(self.W, mean=0.0, std=cfg.init_range) self.b = None if include_bias: self.b = t.zeros(out_dim) def forward(self, x: Int[Tensor, "batch seq in_dim"]) -> Int[Tensor, "batch seq out_dim"]: out = x @ self.W if self.include_bias: out = out + self.b return out class MLP(t.nn.Module): def __init__(self, cfg): super().__init__() self.in_layer = LinearLayer(cfg.d_model, cfg.d_mlp) self.out_layer = LinearLayer(cfg.d_mlp, cfg.d_model) self.non_linearity = t.nn.ReLU() def forward(self, x): post_W_in = self.in_layer(x) post_non_lin = self.non_linearity(post_W_in) return self.out_layer(post_non_lin)



Bunu bir Transformatörde bir araya getirmek


 class TransformerBlock(t.nn.Module): def __init__(self, cfg): super().__init__() self.ln1 = LayerNorm(cfg) self.attention = AttentionLayer(cfg) self.ln2 = LayerNorm(cfg) self.mlp = MLP(cfg) def forward(self, params): x, pad_mask = params resid_mid = self.attention((self.ln1(x), pad_mask)) + x resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid return resid_post, pad_mask


 class Transformer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.embed = EmbedLayer(cfg) self.pos_embed = PositionalEmbedding(cfg) self.final_ln = LayerNorm(cfg) self.unembed = UnEmbedLayer(cfg) self.blocks = t.nn.Sequential(*([TransformerBlock(cfg)] * cfg.num_layers)) def forward(self, x): #TODO: revisit implementing pad_mask with hooks pad_mask = self.get_pad_mask(x) res_post_pos_embed = self.pos_embed(self.embed(x)) post_blocks, _ = self.blocks((res_post_pos_embed, pad_mask)) logits = self.unembed(self.final_ln(post_blocks)) return logits def get_pad_mask(self, x): batch, seq = x.shape return einops.repeat(x != self.cfg.PAD_TOKEN_IDX, 'batch seq -> batch 1 seq_q seq', seq_q=seq)


Eğitim araçları


 def cross_entropy_loss(output, targets): log_probs = output.log_softmax(dim=-1) predictions = log_probs[:, 0] batch, out_dim = predictions.shape true_output = predictions[range(batch), targets] return -true_output.sum() / batch def test(model, data, loss_func): inputs, targets = data with t.no_grad(): output = model(inputs) loss = loss_func(output, targets) return loss def train(model, data, optimizer, loss_func): inputs, targets = data optimizer.zero_grad() output = model(inputs) loss = loss_func(output, targets) loss.backward() optimizer.step() return loss



Eğitim yapılandırması


 cfg = Config() tokenizer = Tokenizer('()', 12, True) inputs = tokenizer.tokenize([*unbal_random_seq_parens, *random_seq_parens]) targets = t.tensor([*([0] * len(unbal_random_seq_parens)), *([1] * len(random_seq_parens))]) rand_indices = t.randperm(targets.shape[0]) rand_inputs = inputs[rand_indices, :] rand_targets = targets[rand_indices] model = Transformer(cfg) adamW = t.optim.AdamW(model.parameters(), lr=0.01)


Gerçek Eğitim


 batch_size = 10000 train_size = int(0.7 * batch_size) epochs = 15 for epoch in range(epochs): for batch_id in range(0, rand_inputs.shape[0], batch_size): rand_inputs_batch, rand_targets_batch = rand_inputs[batch_id : batch_id + batch_size], rand_targets[batch_id : batch_id + batch_size] train_input, train_target = rand_inputs_batch[:train_size, :], rand_targets_batch[:train_size] test_input, test_target = rand_inputs_batch[train_size:, :], rand_targets_batch[train_size:] train(model, (train_input, train_target), adamW, cross_entropy_loss) test_loss = test(model, (test_input, test_target), cross_entropy_loss) print(f'Loss: {test_loss} on epoch: {epoch}/{epochs}') 


Eğitim Doyurucu




3. Bölümde bu eğitimli ağın iç kısımlarını inceleyeceğiz. Bunu, dikkat modellerine bakarak ve ağın bu görevi nasıl çözdüğünü anlamak için mekanik bir model oluşturmak üzere aktivasyon yaması gibi Mekanik yorumlanabilirliğin bazı teşhis araçlarını uygulayarak yapacağız.


Buraya kadar okuduğunuz için teşekkürler. Yakında Bölüm 3'te görüşürüz!