paint-brush
LLM và Leetcode (Phần 1 & 2): Tìm hiểu giải pháp của Transformers cho các vấn đề thuật toántừ tác giả@boluben
20,299 lượt đọc
20,299 lượt đọc

LLM và Leetcode (Phần 1 & 2): Tìm hiểu giải pháp của Transformers cho các vấn đề thuật toán

từ tác giả Bolu Ben-Adeola16m2024/04/16
Read on Terminal Reader

dài quá đọc không nổi

Loạt bài viết này đi sâu vào khả năng diễn giải của các mô hình Transformer, nghiên cứu cách chúng học các thuật toán bằng cách giải quyết vấn đề Dấu ngoặc đơn hợp lệ. Nó bao gồm việc tạo dữ liệu, đào tạo mô hình và hứa hẹn mang đến cái nhìn sâu sắc về các kiểu chú ý và hiểu biết cơ học trong Phần 3.
featured image - LLM và Leetcode (Phần 1 & 2): Tìm hiểu giải pháp của Transformers cho các vấn đề thuật toán
Bolu Ben-Adeola HackerNoon profile picture
0-item
1-item

Theo tinh thần của chương trình nghị sự về khả năng diễn giải Cơ học cho mạng lưới thần kinh , bài đăng này (và có lẽ những bài khác sẽ tiếp theo trong loạt bài này) nghiên cứu các “thuật toán” mà mô hình máy biến áp đã học để giải quyết một nhiệm vụ kỹ thuật hẹp — một phiên bản sửa đổi của “Dấu ngoặc đơn hợp lệ” Vấn đề về Leetcode.


Mặc dù tiện ích của nhiệm vụ có phạm vi khiêm tốn hơn nhiều so với dự đoán mã thông báo tiếp theo tổng quát hơn mà bạn mong đợi trong LLM, nhưng bài tập này sẽ giúp chúng tôi khám phá một số trực giác ban đầu, công cụ điều tra và các phương pháp nhận thức luận chung thường được triển khai để biết những mô hình đang làm gì (và làm sao chúng ta biết được chúng.)


Các thử thách hàng tháng của ARENA Mechinterp có ảnh hưởng rất lớn đến bài đăng này và loạt vấn đề đầu tiên sẽ xuất phát từ đó. (Bạn chắc chắn nên kiểm tra chương trình .)


Cấu trúc chuỗi:

  1. Chọn một vấn đề Leetcode làm nhiệm vụ. (Phần 1)
  2. Huấn luyện một mô hình Transformer khả thi tối thiểu trên đó. (Phần 2)
  3. Điều tra những gì mô hình đã học được. (Phần 3)


Phần 1: Vấn đề

Vấn đề về Dấu ngoặc đơn hợp lệ như đã thấy trên Leetcode:



Một số ràng buộc đã sửa đổi đối với vấn đề mà chúng tôi sẽ sử dụng cho nhiệm vụ:


  • Các ký tự duy nhất được chấp nhận là “(” và “)”
    • Điều này loại bỏ nhu cầu xử lý các trường hợp như "([)]”.


  • Chuỗi đầu vào tối đa dài 40 ký tự.
    • Để giúp giữ cho mô hình của chúng tôi nhỏ để lặp lại nhanh chóng.



Ví dụ

“(((())))” → Hợp lệ

“()()()(” → Không hợp lệ

“)()()()(” → Không hợp lệ



Dung dịch vani

 def isValid(self, s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


Lưu ý về các trường hợp thất bại:


  1. lồng_độ sâu ≠ 0 ở cuối chuỗi

    “()()()((” → Không hợp lệ


    Đối với điều này, rõ ràng là không có gì sai cho đến cuối cùng khi chúng tôi thấy rằng dấu ngoặc được mở cuối cùng không có dấu ngoặc đi kèm. Điều cần lưu ý là không có điểm nào trong trình tự cho đến cuối cùng mà chúng ta có đủ thông tin để biết có điều gì đó không ổn.


  2. Nesting_Deep < 0 tại bất kỳ điểm nào trong chuỗi

    ví dụ: “())()()(” → Không hợp lệ


    Mặt khác, trong trường hợp này, vị trí thứ ba có đủ thông tin để biết tính hợp lệ của chuỗi là không thể khôi phục được, vì vậy chúng ta có thể gọi nó là kết thúc sớm.


    Điều cần lưu ý là ví dụ này đã vượt qua thử nghiệm thất bại đầu tiên vì nesting_depth ở cuối sẽ bằng 0. Vì vậy, trường hợp thử nghiệm này không chỉ giúp chúng tôi dừng sớm mà nó còn quan trọng. Điều tương tự cũng áp dụng cho ví dụ về trường hợp lỗi đầu tiên mà lẽ ra nó đã vượt qua bài kiểm tra 2.



Hiện tại, chúng tôi không mong đợi một mô hình máy biến áp tự hồi quy sẽ giải quyết vấn đề theo cách giống hệt nhau, do kiến trúc của nó có các cơ chế hơi khác so với việc lặp lại trình tự một lần và kiểm tra xem tất cả có ổn không. Tuy nhiên, chúng tôi biết chắc chắn rằng kiến trúc máy biến áp (và các kiến trúc xử lý trình tự khác) ít nhất có khả năng khám pháxử lý thông tin về tất cả các phần tử trong một trình tự. Điều quan trọng cần nhớ là mặc dù giải pháp có thể trông khác nhau nhưng cấu trúc của vấn đề vẫn giống nhau và các giới hạn cố định về những gì đã biết và vị trí trong chuỗi vẫn tiếp tục đúng cho dù đó là một vòng lặp và các câu lệnh if hay một tập hợp của chính nó. - quét chú ý và phi tuyến tính MLP.


Câu hỏi thú vị sau đó là làm thế nào kiến trúc này tận dụng thông tin này và liệu nó có dễ dàng được nhận thấy bằng công cụ hiện có hay không; bởi vì không thể tránh khỏi việc một giải pháp có đủ hiệu suất của bất kỳ kiến trúc nào lại không kiểm tra được ít nhất hai trường hợp lỗi nêu trên.


Đây là một trong những ưu điểm của bài toán đồ chơi; chúng ta có được một nhiệm vụ hẹp được hiểu đầy đủ với những đảm bảo cứng rắn này có thể giúp cung cấp thông tin cho cuộc điều tra như chúng ta sẽ sớm thấy.


Phần 2: Dữ liệu & Mô hình

Chuẩn bị dữ liệu đào tạo

Dưới đây là một số đặc điểm mục tiêu mà chúng tôi sẽ hướng tới khi tạo dữ liệu:


  • Số lượng dây cân bằng và không cân bằng bằng nhau.

  • Các chuỗi sẽ có độ dài chẵn, vì một chuỗi có độ dài lẻ rõ ràng là không cân bằng; đây sẽ không phải là một phương pháp phỏng đoán thú vị cho mô hình học hỏi.

  • Tất cả độ dài chuỗi (2-40) đều có khả năng như nhau.

  • Đối với độ dài chuỗi nhất định, tất cả độ sâu lồng nhau của dấu ngoặc đơn tiềm năng sẽ có khả năng như nhau.


Một chủ đề chung rất rõ ràng: chúng tôi đang cố gắng làm cho mọi thống kê phân phối có thể suy nghĩ được đều có khả năng giảm thiểu sự thiên vị theo bất kỳ hướng nào, để đảm bảo tính chắc chắn và từ chối các phương pháp phỏng đoán thắng nhanh rõ ràng như một lựa chọn cho mô hình. Để tạo ra các trường hợp lỗi, trước tiên chúng tôi sẽ tạo các dấu ngoặc đơn hợp lệ với các đảm bảo được liệt kê ở trên, sau đó thay đổi một nửa trong số chúng để trở nên mất cân bằng.


 from random import randint, randrange, sample from typing import List, Tuple, Union, Optional, Callable, Dict from jaxtyping import Float, Int import torch as t from torch import Tensor import plotly.express as px import einops from dataclasses import dataclass import math



 def isValid(s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


 assert isValid('()()((((()())())))') == True assert isValid(')()((((()())()))(') == False


Sơ đồ tạo dữ liệu số 1: Đi bộ ngẫu nhiên

Lần thử đầu tiên trong việc tạo dấu ngoặc đơn chỉ thực hiện một bước đi ngẫu nhiên. Nhưng như bạn có thể thấy trong đồ thị bên dưới, không gian con của dấu ngoặc đơn không cân bằng lớn hơn nhiều so với không gian con của dấu ngoặc đơn cân bằng; vì vậy chúng ta sẽ phải giới thiệu tính ngẫu nhiên theo cách khác.


 PARENS = ['(', ')'] def get_random_walk_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: range_start, range_end = length_range random_parens = [ # Add 1 to make passed range_end inclusive ''.join(PARENS[randint(0, 1)] for _ in range(randrange(range_start, range_end + 1, 2))) for _ in range(parens_num) ] return random_parens



 random_parens = get_random_walk_parens(1000, (2, 10))



 random_parens[:10] # output [')(', '(((())()', ')(((()()))', '))))))', '))())()(', '))', '(())', ')()(()()()', ')()())))((', '()']



 is_valid_evals = [str(isValid(random_paren)) for random_paren in random_parens] len_evals = [len(random_paren) for random_paren in random_parens]



 fig = px.histogram(is_valid_evals, title="Count of is-balanced for random walk parentheses strings") fig.show() 






Lược đồ tạo dữ liệu số 2: Trình tự lồng ngẫu nhiên tham lam

Chúng ta có thể chia nhỏ việc xây dựng một chuỗi dấu ngoặc đơn cân bằng thành các đơn vị riêng biệt gồm các dấu ngoặc đơn lồng nhau. Đối với cấu trúc tham lam này, ở mỗi bước trong quá trình tạo chuỗi, độ sâu lồng nhau được chọn từ một nhóm có độ sâu khả thi (để tôn trọng độ dài chuỗi mục tiêu.)


Ví dụ: đối với độ dài mục tiêu 6 , có thể phân tách lồng nhau duy nhất sau đây:


-> [2, 1], [1, 2], [1,1,1] or [3]

Corresponding to:

-> (())(), ()(()), ()()(), ((()))



 def get_balanced_parens(nest_depth: int) -> str: """Generate parentheses at the required nesting depth.""" return (PARENS[0] * nest_depth) + (PARENS[1] * nest_depth) assert get_balanced_parens(3) == '((()))'



 def get_balanced_sequence_parens(nest_depth_sequence: List[int]) -> str: """Return a parentheses string following the nesting depth sequence from a given list.""" return ''.join(get_balanced_parens(nest_depth) for nest_depth in nest_depth_sequence) assert get_balanced_sequence_parens([1,1,2,3]) == '()()(())((()))'



 def get_random_depth_sequence(target_paren_len: int) -> List[int]: depth_sequence = [] while target_paren_len > 0: depth = randint(1, target_paren_len / 2) depth_sequence.append(depth) target_paren_len -= 2 * depth return depth_sequence rand_depth_seq = get_random_depth_sequence(10) print(rand_depth_seq) # Example output: '[3, 1, 1]' assert sum([2 * depth for depth in rand_depth_seq]) == 10



 def get_random_sequence_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: random_depth_sequences = [get_random_depth_sequence( randrange(*length_range, 2) ) for _ in range(parens_num)] random_parens = [ get_balanced_sequence_parens(random_depth_sequence) for random_depth_sequence in random_depth_sequences ] return random_parens, random_depth_sequences



Nhận Parens cân bằng

 random_seq_parens, depth_sequences = get_random_sequence_parens(100000, (2, 11)) is_valid_evals = [str(isValid(random_paren)) for random_paren in random_seq_parens] len_evals = [len(random_paren) for random_paren in random_seq_parens]


Chúng ta hãy xem tần số của độ sâu làm tổ


 depth_freq = {} for seq in depth_sequences: for depth in seq: depth_freq.setdefault(depth, 0) depth_freq[depth] += 1 depth_freq # output -> {2: 39814, 1: 100088, 3: 20127, 4: 9908, 5: 4012}



 depth_seq_hist = px.histogram(depth_sequences, title="Frequence of nesting depths in 'Random Nesting Depth Sequence' Output") depth_seq_hist.show() 


Tần số độ sâu lệch




Và bây giờ, để xem tần số độ dài.


 paren_len_hist = px.histogram(len_evals, title="Frequency of string lengths") paren_len_hist.show() 


Tần số chiều dài chuỗi khá phẳng


Ghi chú tạo dữ liệu

Lưu ý rằng có sự căng thẳng giữa các thuộc tính tiềm năng sau trong quá trình phân phối dữ liệu của chúng tôi.


  1. Mọi độ dài chuỗi đều có khả năng như nhau.
  2. Mọi chuỗi con có độ sâu lồng nhau đều có khả năng như nhau trên tất cả các chuỗi.


Điều này là do các chuỗi con có độ sâu lồng ghép thấp sẽ có nhiều cơ hội xuất hiện trong một chuỗi lồng ghép ngẫu nhiên nhất định, như được hiển thị trong các sơ đồ ở trên.


Để chống lại xu hướng tự nhiên này của chuỗi hoàn toàn ngẫu nhiên, khi tạo ra một chuỗi con dấu ngoặc đơn nhất định, chúng ta có thể lấy mẫu từ một phân bố bị lệch để tạo ra các giá trị lồng sâu hơn có nhiều khả năng xảy ra hơn.

Điều này sẽ được xem xét lại sau lần đầu tiên vượt qua khóa huấn luyện.


 px.histogram(random_seq_parens, title="Frequency of balanced Parentheses").show() 




Tạo dấu ngoặc đơn không cân bằng

Tập dữ liệu của chúng tôi không thể chỉ có dấu ngoặc đơn cân bằng. Vì vậy, chúng tôi có thể tạo chiến lược tạo dữ liệu để lấy các chuỗi không cân bằng từ tập dữ liệu cân bằng của mình.


 def _flip_idx(idx): return (idx + 1) % 2 assert _flip_idx(0) == 1 assert _flip_idx(1) == 0



 def make_parens_unbalanced(paren: str) -> str: """Take balanced-parentheses and randomly mutate it till it's unbalanced. Both the number of mutations and indices are chosen at random. """ paren_idx_dict = {'(': 0, ')': 1} paren_list = list(paren) num_flipped_positions = randint(1, len(paren)) while isValid(''.join(paren_list)): flip_points = sample(range(len(paren)), num_flipped_positions) for flip_idx in flip_points: idx_char = paren_idx_dict[paren_list[flip_idx]] flipped_idx = _flip_idx(idx_char) paren_list[flip_idx] = PARENS[flipped_idx] return ''.join(paren_list) assert not isValid(make_parens_unbalanced('((()))'))


Nhận bộ dữ liệu Parens không cân bằng


 unbal_random_seq_parens = [make_parens_unbalanced(paren) for paren in random_seq_parens]



Đào tạo người mẫu

Bây giờ chúng ta đã có các tập dữ liệu, để giải trí, chúng ta sẽ viết kiến trúc Transformer từ đầu.


Đầu tiên một số cấu hình


 @dataclass class Config: context_len = 12 d_vocab: int = 5 d_out_vocab: int = 2 d_model: int = 56 d_head = 28 d_mlp = 56 * 4 causal_attention = False num_heads = 2 num_layers = 3 init_range: float = 1 PAD_TOKEN_IDX = 1


Sau đó, mã thông báo của chúng tôi để phân tích cú pháp đầu vào:


 class Tokenizer: def __init__(self, vocab: str, context_width: Int, enforce_context: bool=False): self.START_TOKEN, START_TOKEN_IDX = "<start>", 0 self.PAD_TOKEN, PAD_TOKEN_IDX = "<pad>", 1 self.END_TOKEN, END_TOKEN_IDX = "<end>", 2 util_tokens_t_to_i = {self.START_TOKEN: START_TOKEN_IDX, self.PAD_TOKEN: PAD_TOKEN_IDX, self.END_TOKEN: END_TOKEN_IDX} util_tokens_i_to_t = {START_TOKEN_IDX: self.START_TOKEN, PAD_TOKEN_IDX: self.PAD_TOKEN, END_TOKEN_IDX: self.END_TOKEN} self.enforce_context = enforce_context self.context_width = context_width self.vocab = vocab self.t_to_i = {**util_tokens_t_to_i, **{token: token_id + 3 for token_id, token in enumerate(self.vocab)}} self.i_to_t = {**util_tokens_i_to_t, **{token_id + 3: token for token_id, token in enumerate(self.vocab)}} @staticmethod def pad_sequence(sequence: str, end_token: str, pad_token: str, max_length: Int, enforce_context: bool) -> List[str]: if not enforce_context: # Truncate if sequence length is greater sequence = sequence[:max_length] else: assert len(sequence) <= max_length, f"Sequence length is greater than the max allowed data length: {max_length}" return list(sequence) + [end_token] + [pad_token] * (max_length - len(sequence)) def tokenize(self, data: Union[str, List[str]]) -> Int[Tensor, "batch seq"]: if isinstance(data, str): data = [data] def _list_tokens_to_id(tokens: List[str]) -> List[Int]: return [self.t_to_i[token] for token in tokens] # to leave room for start and end tokens max_seq_len = self.context_width - 2 data_as_tokens = [ _list_tokens_to_id([ self.START_TOKEN, *self.pad_sequence(seq, self.END_TOKEN, self.PAD_TOKEN, max_seq_len, self.enforce_context), ]) for seq in data ] return t.tensor(data_as_tokens)


(Un) Nhúng


 class EmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_E = t.nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model)) t.nn.init.normal_(self.W_E, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq"]) -> Int[Tensor, "batch seq d_model"]: return self.W_E[x] class UnEmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_U = t.nn.Parameter(t.empty(cfg.d_model, cfg.d_out_vocab)) t.nn.init.normal_(self.W_U, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq d_model"]) -> Int[Tensor, "batch seq d_out_vocab"]: return x @ self.W_U class PositionalEmbedding(t.nn.Module): def __init__(self, cfg: Config): super().__init__() denom = t.exp( t.arange(0, cfg.d_model, 2) * -(math.log(10000.0) / cfg.d_model) ) pos = t.arange(0, cfg.context_len).unsqueeze(1) param = pos * denom P_E = t.zeros(cfg.context_len, cfg.d_model) P_E[:, 0::2] = t.sin(param) P_E[:, 1::2] = t.cos(param) P_E = P_E.unsqueeze(0) self.register_buffer("P_E", P_E) def forward(self, x): _batch, seq_len, d_model = x.shape x = x + self.P_E[..., :seq_len, :d_model].requires_grad_(False) return x


Định mức lớp tiện dụng


 class LayerNorm(t.nn.Module): def __init__(self, cfg): super().__init__() self.scale = t.nn.Parameter(t.ones(cfg.d_model)) self.bias = t.nn.Parameter(t.zeros(cfg.d_model)) def forward(self, x): mean = t.mean(x, dim=2, keepdim=True) var = t.var(x, dim=2, keepdim=True, unbiased=False) y = (x - mean) / (var + 0.00001).sqrt() return (y * self.scale) + self.bias


Và cuối cùng là Chú ý!


 class AttentionLayer(t.nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32)) self.cfg = cfg self.W_Q = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_K = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_V = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_O = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_head, cfg.d_model)) self.b_Q = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_K = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_V = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_O = t.nn.Parameter(t.zeros(cfg.d_model)) t.nn.init.normal_(self.W_Q, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_K, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_V, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_O, mean=0.0, std=cfg.init_range) def forward(self, params): #TODO: revisit implementing pad_mask with hooks x, pad_mask = params Q = einops.einsum(x, self.W_Q, 'bs dm, h dm dh -> bsh dh') + self.b_Q K = einops.einsum(x, self.W_K, 'bs dm, h dm dh -> bsh dh') + self.b_K V = einops.einsum(x, self.W_V, 'bs dm, h dm dh -> bsh dh') + self.b_V attention_scores = einops.einsum(Q, K, 'b s_q h dh, b s_k h dh -> bh s_q s_k') scaled_attention_scores = attention_scores / (self.cfg.d_head ** 0.5) if self.cfg.causal_attention: scaled_attention_scores = self.apply_causal_mask(scaled_attention_scores) scaled_attention_scores = self.apply_padding_mask(scaled_attention_scores, pad_mask) attention_patterns = t.nn.Softmax(dim=-1)(scaled_attention_scores) post_attention_values = einops.einsum( attention_patterns, V, 'bh s_q s_k, b s_k h dh -> b s_q h dh' ) out = einops.einsum( post_attention_values, self.W_O, 'b s_q h dh, h dh dm -> b s_q dm' ) + self.b_O return out def apply_causal_mask(self, attention_scores): b, h, s_q, s_k = attention_scores.shape mask = t.tril(t.ones(s_q,s_k)).bool() return t.where(mask, attention_scores, self.IGNORE) def apply_padding_mask(self, attention_scores, pad_mask): return t.where(pad_mask, attention_scores, self.IGNORE)



Lớp MLP


 class LinearLayer(t.nn.Module): def __init__(self, in_dim, out_dim, include_bias=True): super().__init__() self.include_bias = include_bias self.W = t.nn.Parameter(t.empty(in_dim, out_dim)) t.nn.init.normal_(self.W, mean=0.0, std=cfg.init_range) self.b = None if include_bias: self.b = t.zeros(out_dim) def forward(self, x: Int[Tensor, "batch seq in_dim"]) -> Int[Tensor, "batch seq out_dim"]: out = x @ self.W if self.include_bias: out = out + self.b return out class MLP(t.nn.Module): def __init__(self, cfg): super().__init__() self.in_layer = LinearLayer(cfg.d_model, cfg.d_mlp) self.out_layer = LinearLayer(cfg.d_mlp, cfg.d_model) self.non_linearity = t.nn.ReLU() def forward(self, x): post_W_in = self.in_layer(x) post_non_lin = self.non_linearity(post_W_in) return self.out_layer(post_non_lin)



Đặt nó lại với nhau thành một máy biến áp


 class TransformerBlock(t.nn.Module): def __init__(self, cfg): super().__init__() self.ln1 = LayerNorm(cfg) self.attention = AttentionLayer(cfg) self.ln2 = LayerNorm(cfg) self.mlp = MLP(cfg) def forward(self, params): x, pad_mask = params resid_mid = self.attention((self.ln1(x), pad_mask)) + x resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid return resid_post, pad_mask


 class Transformer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.embed = EmbedLayer(cfg) self.pos_embed = PositionalEmbedding(cfg) self.final_ln = LayerNorm(cfg) self.unembed = UnEmbedLayer(cfg) self.blocks = t.nn.Sequential(*([TransformerBlock(cfg)] * cfg.num_layers)) def forward(self, x): #TODO: revisit implementing pad_mask with hooks pad_mask = self.get_pad_mask(x) res_post_pos_embed = self.pos_embed(self.embed(x)) post_blocks, _ = self.blocks((res_post_pos_embed, pad_mask)) logits = self.unembed(self.final_ln(post_blocks)) return logits def get_pad_mask(self, x): batch, seq = x.shape return einops.repeat(x != self.cfg.PAD_TOKEN_IDX, 'batch seq -> batch 1 seq_q seq', seq_q=seq)


Tiện ích đào tạo


 def cross_entropy_loss(output, targets): log_probs = output.log_softmax(dim=-1) predictions = log_probs[:, 0] batch, out_dim = predictions.shape true_output = predictions[range(batch), targets] return -true_output.sum() / batch def test(model, data, loss_func): inputs, targets = data with t.no_grad(): output = model(inputs) loss = loss_func(output, targets) return loss def train(model, data, optimizer, loss_func): inputs, targets = data optimizer.zero_grad() output = model(inputs) loss = loss_func(output, targets) loss.backward() optimizer.step() return loss



Cấu hình luyện tập


 cfg = Config() tokenizer = Tokenizer('()', 12, True) inputs = tokenizer.tokenize([*unbal_random_seq_parens, *random_seq_parens]) targets = t.tensor([*([0] * len(unbal_random_seq_parens)), *([1] * len(random_seq_parens))]) rand_indices = t.randperm(targets.shape[0]) rand_inputs = inputs[rand_indices, :] rand_targets = targets[rand_indices] model = Transformer(cfg) adamW = t.optim.AdamW(model.parameters(), lr=0.01)


Đào tạo thực tế


 batch_size = 10000 train_size = int(0.7 * batch_size) epochs = 15 for epoch in range(epochs): for batch_id in range(0, rand_inputs.shape[0], batch_size): rand_inputs_batch, rand_targets_batch = rand_inputs[batch_id : batch_id + batch_size], rand_targets[batch_id : batch_id + batch_size] train_input, train_target = rand_inputs_batch[:train_size, :], rand_targets_batch[:train_size] test_input, test_target = rand_inputs_batch[train_size:, :], rand_targets_batch[train_size:] train(model, (train_input, train_target), adamW, cross_entropy_loss) test_loss = test(model, (test_input, test_target), cross_entropy_loss) print(f'Loss: {test_loss} on epoch: {epoch}/{epochs}') 


Đào tạo bão hòa




Trong Phần 3, chúng ta sẽ điều tra nội bộ của mạng được đào tạo này. Chúng tôi sẽ thực hiện điều này bằng cách xem xét các mẫu chú ý và áp dụng một số công cụ chẩn đoán về khả năng diễn giải Cơ học, chẳng hạn như bản vá kích hoạt để xây dựng mô hình cơ học nhằm hiểu cách mạng đã giải quyết nhiệm vụ này.


Cảm ơn bạn đã đọc đến đây và sẽ sớm gặp bạn ở Phần 3!