paint-brush
एलएलएम बनाम लीटकोड (भाग 1 और 2): एल्गोरिदमिक समस्याओं के लिए ट्रांसफॉर्मर्स के समाधान को समझनाद्वारा@boluben
20,299 रीडिंग
20,299 रीडिंग

एलएलएम बनाम लीटकोड (भाग 1 और 2): एल्गोरिदमिक समस्याओं के लिए ट्रांसफॉर्मर्स के समाधान को समझना

द्वारा Bolu Ben-Adeola16m2024/04/16
Read on Terminal Reader

बहुत लंबा; पढ़ने के लिए

यह लेख श्रृंखला ट्रांसफॉर्मर मॉडल की व्याख्यात्मकता पर गहराई से चर्चा करती है, यह जांचती है कि वे वैध कोष्ठक समस्या से निपटने के द्वारा एल्गोरिदम कैसे सीखते हैं। यह डेटा निर्माण, मॉडल प्रशिक्षण को कवर करता है, और भाग 3 में ध्यान पैटर्न और यांत्रिक समझ पर गहराई से नज़र डालने का वादा करता है।
featured image - एलएलएम बनाम लीटकोड (भाग 1 और 2): एल्गोरिदमिक समस्याओं के लिए ट्रांसफॉर्मर्स के समाधान को समझना
Bolu Ben-Adeola HackerNoon profile picture
0-item
1-item

तंत्रिका नेटवर्क के लिए यंत्रवत व्याख्यात्मकता एजेंडे की भावना में, यह पोस्ट (और शायद श्रृंखला में आने वाले अन्य) एक संकीर्ण तकनीकी कार्य से निपटने के लिए एक ट्रांसफार्मर मॉडल द्वारा सीखे गए "एल्गोरिदम" की जांच करता है - "वैध कोष्ठक" लीटकोड समस्या का एक संशोधित संस्करण।


यद्यपि इस कार्य की उपयोगिता, एलएलएम में अपेक्षित अधिक सामान्य नेक्स्ट-टोकन भविष्यवाणी की तुलना में बहुत अधिक मामूली है, फिर भी यह अभ्यास हमें कुछ प्रारंभिक अंतर्ज्ञानों, जांच उपकरणों और सामान्य ज्ञान-मीमांसा पद्धतियों का पता लगाने में मदद करेगा, जिनका उपयोग आम तौर पर यह जानने के लिए किया जाता है कि मॉडल क्या कर रहे हैं (और हम कैसे जानते हैं कि वे क्या कर रहे हैं।)


ARENA Mechinterp मासिक चुनौतियों का इस पोस्ट पर बहुत बड़ा प्रभाव था, और समस्याओं का पहला सेट वहाँ से आएगा। (आपको निश्चित रूप से कार्यक्रम की जाँच करनी चाहिए।)


श्रृंखला संरचना:

  1. एक कार्य के रूप में एक लीटकोड समस्या चुनें। (भाग 1)
  2. इस पर न्यूनतम व्यवहार्य ट्रांसफॉर्मर मॉडल का प्रशिक्षण करें। (भाग 2)
  3. मॉडल ने क्या सीखा, इसकी जांच करें। (भाग 3)


भाग 1: समस्या

लीटकोड पर देखी गई वैध कोष्ठक समस्या:



समस्या पर कुछ संशोधित प्रतिबंध जिन्हें हम कार्य के लिए उपयोग करेंगे:


  • एकमात्र स्वीकार्य वर्ण “(” और “)” हैं
    • इससे "([)]" जैसे मामलों को संभालने की आवश्यकता समाप्त हो जाती है।


  • अधिकतम इनपुट अनुक्रम 40 वर्ण लम्बा है।
    • त्वरित पुनरावृत्तियों के लिए हमारे मॉडल को छोटा रखने में सहायता करना।



उदाहरण

“(((())))” → मान्य

“()()()(” → अमान्य

“)()()()(” → अमान्य



वेनिला समाधान

 def isValid(self, s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


विफलता के मामलों पर नोट्स:


  1. अनुक्रम के अंत में nesting_depth ≠ 0

    “()()()((” → अमान्य


    इसके लिए, यह स्पष्ट नहीं है कि कुछ भी गलत है, जब तक कि हम अंत में नहीं देखते कि अंतिम खुले कोष्ठक में कोई संगत कोष्ठक नहीं है। ध्यान देने वाली बात यह है कि अनुक्रम में ऐसा कोई बिंदु नहीं है, जब तक कि अंत तक हमें यह जानने के लिए पर्याप्त जानकारी न हो कि कुछ गड़बड़ है।


  2. अनुक्रम में किसी भी बिंदु पर nesting_depth < 0

    उदाहरण: “())()()(” → अमान्य


    दूसरी ओर, इस मामले में, तीसरी स्थिति तक यह जानने के लिए पर्याप्त जानकारी है कि अनुक्रम की वैधता अप्राप्य है, इसलिए हम इसे जल्दी ही समाप्त कर सकते हैं।


    ध्यान देने वाली बात यह है कि यह उदाहरण पहला विफलता परीक्षण पास कर लेता क्योंकि अंत में nesting_depth 0 के बराबर होता। इसलिए यह परीक्षण मामला हमें न केवल जल्दी रोकने में मदद करता है, बल्कि यह महत्वपूर्ण भी है। यही बात पहले विफलता मामले के उदाहरण पर भी लागू होती है जहाँ यह परीक्षण 2 पास कर लेता।



अब, हम उम्मीद नहीं करते कि ऑटोरिग्रैसिव ट्रांसफॉर्मर मॉडल समस्या को ठीक उसी तरह हल करेगा, क्योंकि इसकी वास्तुकला अनुक्रम के माध्यम से एक बार लूप करने और यह जाँचने की तुलना में थोड़ा अलग तंत्र प्रदान करती है कि क्या सब ठीक है। हालाँकि, हम यह सुनिश्चित करने के लिए जानते हैं कि ट्रांसफॉर्मर आर्किटेक्चर (और अन्य अनुक्रम प्रसंस्करण आर्किटेक्चर) कम से कम एक अनुक्रम में सभी तत्वों के बारे में जानकारी की खोज और प्रक्रिया करने में सक्षम हैं। यह याद रखना महत्वपूर्ण है कि जबकि समाधान अलग दिख सकता है, समस्या की संरचना समान है और जो ज्ञात है और अनुक्रम में कहाँ है, इस पर कठोर सीमाएँ सत्य बनी रहती हैं चाहे वह लूप और अगर-कथन हो या स्व-ध्यान स्वीप और एमएलपी नॉनलाइनियरिटी का एक समूह हो।


दिलचस्प सवाल यह है कि यह आर्किटेक्चर इस जानकारी का लाभ कैसे उठाता है और क्या यह मौजूदा टूलिंग के साथ आसानी से पहचाना जा सकता है; क्योंकि किसी भी आर्किटेक्चर के पर्याप्त प्रदर्शन वाले समाधान के लिए कम से कम उपरोक्त दो विफलता मामलों के लिए परीक्षण न करना अपरिहार्य है।


यह खिलौना समस्याओं के लाभों में से एक है; इन कठोर गारंटियों के साथ हमें पर्याप्त रूप से समझा जाने वाला संकीर्ण कार्य मिलता है, जो जांच को सूचित करने में मदद कर सकता है, जैसा कि हम जल्द ही देखेंगे।


भाग 2: डेटा और मॉडल

प्रशिक्षण डेटा तैयारी

यहां कुछ लक्ष्य विशेषताएं दी गई हैं जिन्हें हम डेटा उत्पादन के लिए अपना रहे हैं:


  • संतुलित और असंतुलित तारों की समान संख्या।

  • स्ट्रिंग सम लंबाई की होगी, क्योंकि विषम लंबाई वाली स्ट्रिंग स्पष्ट रूप से असंतुलित होती है; जो मॉडल के लिए सीखने के लिए बहुत दिलचस्प अनुमान नहीं होगा।

  • सभी स्ट्रिंग की लम्बाई (2-40) समान रूप से संभावित होनी चाहिए।

  • किसी दी गई स्ट्रिंग लंबाई के लिए, सभी संभावित कोष्ठकों की नेस्टिंग गहराई समान रूप से संभावित होनी चाहिए।


एक सामान्य विषय स्पष्ट है: हम हर विचारणीय वितरण सांख्यिकी को किसी भी दिशा में पूर्वाग्रह को कम करने, मजबूती सुनिश्चित करने और मॉडल के लिए एक विकल्प के रूप में स्पष्ट त्वरित-जीत अनुमानों को अस्वीकार करने के लिए समान रूप से संभावित बनाने की कोशिश कर रहे हैं। विफलता के मामलों को उत्पन्न करने के लिए, हम पहले ऊपर सूचीबद्ध गारंटी के साथ वैध कोष्ठक उत्पन्न करेंगे और फिर असंतुलित होने के लिए उनमें से आधे को उत्परिवर्तित करेंगे।


 from random import randint, randrange, sample from typing import List, Tuple, Union, Optional, Callable, Dict from jaxtyping import Float, Int import torch as t from torch import Tensor import plotly.express as px import einops from dataclasses import dataclass import math



 def isValid(s: str) -> bool: nesting_depth = 0 for bracket in s: if bracket == '(': # An opening bracket increases unresolved nesting depth nesting_depth += 1 elif bracket == ')': # A closing bracket decreases unresolved nesting depth nesting_depth -= 1 # We don't expect to ever have negative unresolved nesting depth, # so we can declare 'invalid' midway through the sequence if we see this if nesting_depth < 0: return False # Final check that all open brackets were closed. return nesting_depth == 0


 assert isValid('()()((((()())())))') == True assert isValid(')()((((()())()))(') == False


डेटा जनरेशन स्कीम #1: रैंडम वॉक

कोष्ठक निर्माण का पहला प्रयास केवल एक यादृच्छिक चाल है। लेकिन जैसा कि आप नीचे दिए गए ग्राफ़ में देख सकते हैं कि असंतुलित कोष्ठकों का उप-स्थान संतुलित कोष्ठकों की तुलना में बहुत बड़ा है; इसलिए हमें स्टोकैस्टिसिटी को अलग तरीके से पेश करना होगा।


 PARENS = ['(', ')'] def get_random_walk_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: range_start, range_end = length_range random_parens = [ # Add 1 to make passed range_end inclusive ''.join(PARENS[randint(0, 1)] for _ in range(randrange(range_start, range_end + 1, 2))) for _ in range(parens_num) ] return random_parens



 random_parens = get_random_walk_parens(1000, (2, 10))



 random_parens[:10] # output [')(', '(((())()', ')(((()()))', '))))))', '))())()(', '))', '(())', ')()(()()()', ')()())))((', '()']



 is_valid_evals = [str(isValid(random_paren)) for random_paren in random_parens] len_evals = [len(random_paren) for random_paren in random_parens]



 fig = px.histogram(is_valid_evals, title="Count of is-balanced for random walk parentheses strings") fig.show() 






डेटा जनरेशन योजना #2: लालची यादृच्छिक नेस्टिंग अनुक्रम

हम संतुलित कोष्ठक स्ट्रिंग के निर्माण को नेस्टेड कोष्ठकों की असतत इकाइयों में विभाजित कर सकते हैं। इस लालची निर्माण के लिए, स्ट्रिंग बनाने की प्रक्रिया में प्रत्येक चरण में व्यवहार्य गहराई की एक टोकरी से एक नेस्टिंग गहराई चुनी जाती है (लक्ष्य स्ट्रिंग लंबाई का सम्मान करने के लिए।)


उदाहरण के लिए लक्ष्य लंबाई 6 के लिए, निम्नलिखित अद्वितीय नेस्टिंग विघटन संभव हैं:


-> [2, 1], [1, 2], [1,1,1] or [3]

Corresponding to:

-> (())(), ()(()), ()()(), ((()))



 def get_balanced_parens(nest_depth: int) -> str: """Generate parentheses at the required nesting depth.""" return (PARENS[0] * nest_depth) + (PARENS[1] * nest_depth) assert get_balanced_parens(3) == '((()))'



 def get_balanced_sequence_parens(nest_depth_sequence: List[int]) -> str: """Return a parentheses string following the nesting depth sequence from a given list.""" return ''.join(get_balanced_parens(nest_depth) for nest_depth in nest_depth_sequence) assert get_balanced_sequence_parens([1,1,2,3]) == '()()(())((()))'



 def get_random_depth_sequence(target_paren_len: int) -> List[int]: depth_sequence = [] while target_paren_len > 0: depth = randint(1, target_paren_len / 2) depth_sequence.append(depth) target_paren_len -= 2 * depth return depth_sequence rand_depth_seq = get_random_depth_sequence(10) print(rand_depth_seq) # Example output: '[3, 1, 1]' assert sum([2 * depth for depth in rand_depth_seq]) == 10



 def get_random_sequence_parens(parens_num: int, length_range: Tuple[int]) -> List[str]: random_depth_sequences = [get_random_depth_sequence( randrange(*length_range, 2) ) for _ in range(parens_num)] random_parens = [ get_balanced_sequence_parens(random_depth_sequence) for random_depth_sequence in random_depth_sequences ] return random_parens, random_depth_sequences



संतुलित माता-पिता पाएं

 random_seq_parens, depth_sequences = get_random_sequence_parens(100000, (2, 11)) is_valid_evals = [str(isValid(random_paren)) for random_paren in random_seq_parens] len_evals = [len(random_paren) for random_paren in random_seq_parens]


आइए नेस्टिंग गहराई की आवृत्तियों को देखें


 depth_freq = {} for seq in depth_sequences: for depth in seq: depth_freq.setdefault(depth, 0) depth_freq[depth] += 1 depth_freq # output -> {2: 39814, 1: 100088, 3: 20127, 4: 9908, 5: 4012}



 depth_seq_hist = px.histogram(depth_sequences, title="Frequence of nesting depths in 'Random Nesting Depth Sequence' Output") depth_seq_hist.show() 


तिरछी गहराई आवृत्तियाँ




और अब, लम्बाई आवृत्तियों को देखें।


 paren_len_hist = px.histogram(len_evals, title="Frequency of string lengths") paren_len_hist.show() 


काफी हद तक समतल स्ट्रिंग-लंबाई आवृत्तियाँ


डेटा जनरेशन नोट

ध्यान दें कि हमारे डेटा वितरण के निम्नलिखित संभावित गुणों के बीच तनाव है।


  1. प्रत्येक स्ट्रिंग की लम्बाई समान रूप से संभावित है।
  2. प्रत्येक नेस्टिंग गहराई सबस्ट्रिंग सभी स्ट्रिंग्स में समान रूप से संभावित होती है।


ऐसा इसलिए है क्योंकि कम नेस्टिंग-गहराई वाले उप-अनुक्रमों को दिए गए यादृच्छिक नेस्टिंग अनुक्रम में दिखने के अधिक अवसर मिलेंगे, जैसा कि ऊपर दिए गए प्लॉट में दिखाया गया है।


विशुद्ध यादृच्छिक अनुक्रम की इस प्राकृतिक प्रवृत्ति का मुकाबला करने के लिए, कोष्ठकों की एक दी गई उप-स्ट्रिंग उत्पन्न करते समय, हम गहरे नेस्ट मानों को अधिक संभावित बनाने के लिए तिरछे वितरण से नमूना ले सकते हैं।

प्रशिक्षण के प्रथम चरण के बाद इस पर पुनः विचार किया जाएगा।


 px.histogram(random_seq_parens, title="Frequency of balanced Parentheses").show() 




असंतुलित कोष्ठक बनाना

हमारे डेटासेट में केवल संतुलित कोष्ठक नहीं हो सकते। इसलिए हम अपने संतुलित डेटासेट से असंतुलित स्ट्रिंग प्राप्त करने के लिए डेटा जनरेशन रणनीति बना सकते हैं।


 def _flip_idx(idx): return (idx + 1) % 2 assert _flip_idx(0) == 1 assert _flip_idx(1) == 0



 def make_parens_unbalanced(paren: str) -> str: """Take balanced-parentheses and randomly mutate it till it's unbalanced. Both the number of mutations and indices are chosen at random. """ paren_idx_dict = {'(': 0, ')': 1} paren_list = list(paren) num_flipped_positions = randint(1, len(paren)) while isValid(''.join(paren_list)): flip_points = sample(range(len(paren)), num_flipped_positions) for flip_idx in flip_points: idx_char = paren_idx_dict[paren_list[flip_idx]] flipped_idx = _flip_idx(idx_char) paren_list[flip_idx] = PARENS[flipped_idx] return ''.join(paren_list) assert not isValid(make_parens_unbalanced('((()))'))


असंतुलित पैरेंस डेटासेट प्राप्त करें


 unbal_random_seq_parens = [make_parens_unbalanced(paren) for paren in random_seq_parens]



मॉडल प्रशिक्षण

अब हमारे पास डेटासेट हैं, और मजे के लिए, हम अपना ट्रांसफॉर्मर आर्किटेक्चर शुरू से लिखने जा रहे हैं।


पहले कुछ कॉन्फ़िगरेशन


 @dataclass class Config: context_len = 12 d_vocab: int = 5 d_out_vocab: int = 2 d_model: int = 56 d_head = 28 d_mlp = 56 * 4 causal_attention = False num_heads = 2 num_layers = 3 init_range: float = 1 PAD_TOKEN_IDX = 1


फिर इनपुट पार्स करने के लिए हमारा टोकेनाइज़र:


 class Tokenizer: def __init__(self, vocab: str, context_width: Int, enforce_context: bool=False): self.START_TOKEN, START_TOKEN_IDX = "<start>", 0 self.PAD_TOKEN, PAD_TOKEN_IDX = "<pad>", 1 self.END_TOKEN, END_TOKEN_IDX = "<end>", 2 util_tokens_t_to_i = {self.START_TOKEN: START_TOKEN_IDX, self.PAD_TOKEN: PAD_TOKEN_IDX, self.END_TOKEN: END_TOKEN_IDX} util_tokens_i_to_t = {START_TOKEN_IDX: self.START_TOKEN, PAD_TOKEN_IDX: self.PAD_TOKEN, END_TOKEN_IDX: self.END_TOKEN} self.enforce_context = enforce_context self.context_width = context_width self.vocab = vocab self.t_to_i = {**util_tokens_t_to_i, **{token: token_id + 3 for token_id, token in enumerate(self.vocab)}} self.i_to_t = {**util_tokens_i_to_t, **{token_id + 3: token for token_id, token in enumerate(self.vocab)}} @staticmethod def pad_sequence(sequence: str, end_token: str, pad_token: str, max_length: Int, enforce_context: bool) -> List[str]: if not enforce_context: # Truncate if sequence length is greater sequence = sequence[:max_length] else: assert len(sequence) <= max_length, f"Sequence length is greater than the max allowed data length: {max_length}" return list(sequence) + [end_token] + [pad_token] * (max_length - len(sequence)) def tokenize(self, data: Union[str, List[str]]) -> Int[Tensor, "batch seq"]: if isinstance(data, str): data = [data] def _list_tokens_to_id(tokens: List[str]) -> List[Int]: return [self.t_to_i[token] for token in tokens] # to leave room for start and end tokens max_seq_len = self.context_width - 2 data_as_tokens = [ _list_tokens_to_id([ self.START_TOKEN, *self.pad_sequence(seq, self.END_TOKEN, self.PAD_TOKEN, max_seq_len, self.enforce_context), ]) for seq in data ] return t.tensor(data_as_tokens)


(अन)एम्बेडिंग


 class EmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_E = t.nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model)) t.nn.init.normal_(self.W_E, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq"]) -> Int[Tensor, "batch seq d_model"]: return self.W_E[x] class UnEmbedLayer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.W_U = t.nn.Parameter(t.empty(cfg.d_model, cfg.d_out_vocab)) t.nn.init.normal_(self.W_U, mean=0.0, std=cfg.init_range) def forward(self, x: Int[Tensor, "batch seq d_model"]) -> Int[Tensor, "batch seq d_out_vocab"]: return x @ self.W_U class PositionalEmbedding(t.nn.Module): def __init__(self, cfg: Config): super().__init__() denom = t.exp( t.arange(0, cfg.d_model, 2) * -(math.log(10000.0) / cfg.d_model) ) pos = t.arange(0, cfg.context_len).unsqueeze(1) param = pos * denom P_E = t.zeros(cfg.context_len, cfg.d_model) P_E[:, 0::2] = t.sin(param) P_E[:, 1::2] = t.cos(param) P_E = P_E.unsqueeze(0) self.register_buffer("P_E", P_E) def forward(self, x): _batch, seq_len, d_model = x.shape x = x + self.P_E[..., :seq_len, :d_model].requires_grad_(False) return x


हैंडी लेयर नॉर्म


 class LayerNorm(t.nn.Module): def __init__(self, cfg): super().__init__() self.scale = t.nn.Parameter(t.ones(cfg.d_model)) self.bias = t.nn.Parameter(t.zeros(cfg.d_model)) def forward(self, x): mean = t.mean(x, dim=2, keepdim=True) var = t.var(x, dim=2, keepdim=True, unbiased=False) y = (x - mean) / (var + 0.00001).sqrt() return (y * self.scale) + self.bias


और अंत में ध्यान दें!


 class AttentionLayer(t.nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32)) self.cfg = cfg self.W_Q = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_K = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_V = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_model, cfg.d_head)) self.W_O = t.nn.Parameter(t.empty(cfg.num_heads, cfg.d_head, cfg.d_model)) self.b_Q = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_K = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_V = t.nn.Parameter(t.zeros(cfg.num_heads, cfg.d_head)) self.b_O = t.nn.Parameter(t.zeros(cfg.d_model)) t.nn.init.normal_(self.W_Q, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_K, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_V, mean=0.0, std=cfg.init_range) t.nn.init.normal_(self.W_O, mean=0.0, std=cfg.init_range) def forward(self, params): #TODO: revisit implementing pad_mask with hooks x, pad_mask = params Q = einops.einsum(x, self.W_Q, 'bs dm, h dm dh -> bsh dh') + self.b_Q K = einops.einsum(x, self.W_K, 'bs dm, h dm dh -> bsh dh') + self.b_K V = einops.einsum(x, self.W_V, 'bs dm, h dm dh -> bsh dh') + self.b_V attention_scores = einops.einsum(Q, K, 'b s_q h dh, b s_k h dh -> bh s_q s_k') scaled_attention_scores = attention_scores / (self.cfg.d_head ** 0.5) if self.cfg.causal_attention: scaled_attention_scores = self.apply_causal_mask(scaled_attention_scores) scaled_attention_scores = self.apply_padding_mask(scaled_attention_scores, pad_mask) attention_patterns = t.nn.Softmax(dim=-1)(scaled_attention_scores) post_attention_values = einops.einsum( attention_patterns, V, 'bh s_q s_k, b s_k h dh -> b s_q h dh' ) out = einops.einsum( post_attention_values, self.W_O, 'b s_q h dh, h dh dm -> b s_q dm' ) + self.b_O return out def apply_causal_mask(self, attention_scores): b, h, s_q, s_k = attention_scores.shape mask = t.tril(t.ones(s_q,s_k)).bool() return t.where(mask, attention_scores, self.IGNORE) def apply_padding_mask(self, attention_scores, pad_mask): return t.where(pad_mask, attention_scores, self.IGNORE)



एमएलपी परतें


 class LinearLayer(t.nn.Module): def __init__(self, in_dim, out_dim, include_bias=True): super().__init__() self.include_bias = include_bias self.W = t.nn.Parameter(t.empty(in_dim, out_dim)) t.nn.init.normal_(self.W, mean=0.0, std=cfg.init_range) self.b = None if include_bias: self.b = t.zeros(out_dim) def forward(self, x: Int[Tensor, "batch seq in_dim"]) -> Int[Tensor, "batch seq out_dim"]: out = x @ self.W if self.include_bias: out = out + self.b return out class MLP(t.nn.Module): def __init__(self, cfg): super().__init__() self.in_layer = LinearLayer(cfg.d_model, cfg.d_mlp) self.out_layer = LinearLayer(cfg.d_mlp, cfg.d_model) self.non_linearity = t.nn.ReLU() def forward(self, x): post_W_in = self.in_layer(x) post_non_lin = self.non_linearity(post_W_in) return self.out_layer(post_non_lin)



इसे एक ट्रांसफॉर्मर में एक साथ रखना


 class TransformerBlock(t.nn.Module): def __init__(self, cfg): super().__init__() self.ln1 = LayerNorm(cfg) self.attention = AttentionLayer(cfg) self.ln2 = LayerNorm(cfg) self.mlp = MLP(cfg) def forward(self, params): x, pad_mask = params resid_mid = self.attention((self.ln1(x), pad_mask)) + x resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid return resid_post, pad_mask


 class Transformer(t.nn.Module): def __init__(self, cfg: Config): super().__init__() self.cfg = cfg self.embed = EmbedLayer(cfg) self.pos_embed = PositionalEmbedding(cfg) self.final_ln = LayerNorm(cfg) self.unembed = UnEmbedLayer(cfg) self.blocks = t.nn.Sequential(*([TransformerBlock(cfg)] * cfg.num_layers)) def forward(self, x): #TODO: revisit implementing pad_mask with hooks pad_mask = self.get_pad_mask(x) res_post_pos_embed = self.pos_embed(self.embed(x)) post_blocks, _ = self.blocks((res_post_pos_embed, pad_mask)) logits = self.unembed(self.final_ln(post_blocks)) return logits def get_pad_mask(self, x): batch, seq = x.shape return einops.repeat(x != self.cfg.PAD_TOKEN_IDX, 'batch seq -> batch 1 seq_q seq', seq_q=seq)


प्रशिक्षण उपयोगिताएँ


 def cross_entropy_loss(output, targets): log_probs = output.log_softmax(dim=-1) predictions = log_probs[:, 0] batch, out_dim = predictions.shape true_output = predictions[range(batch), targets] return -true_output.sum() / batch def test(model, data, loss_func): inputs, targets = data with t.no_grad(): output = model(inputs) loss = loss_func(output, targets) return loss def train(model, data, optimizer, loss_func): inputs, targets = data optimizer.zero_grad() output = model(inputs) loss = loss_func(output, targets) loss.backward() optimizer.step() return loss



प्रशिक्षण कॉन्फ़िगरेशन


 cfg = Config() tokenizer = Tokenizer('()', 12, True) inputs = tokenizer.tokenize([*unbal_random_seq_parens, *random_seq_parens]) targets = t.tensor([*([0] * len(unbal_random_seq_parens)), *([1] * len(random_seq_parens))]) rand_indices = t.randperm(targets.shape[0]) rand_inputs = inputs[rand_indices, :] rand_targets = targets[rand_indices] model = Transformer(cfg) adamW = t.optim.AdamW(model.parameters(), lr=0.01)


वास्तविक प्रशिक्षण


 batch_size = 10000 train_size = int(0.7 * batch_size) epochs = 15 for epoch in range(epochs): for batch_id in range(0, rand_inputs.shape[0], batch_size): rand_inputs_batch, rand_targets_batch = rand_inputs[batch_id : batch_id + batch_size], rand_targets[batch_id : batch_id + batch_size] train_input, train_target = rand_inputs_batch[:train_size, :], rand_targets_batch[:train_size] test_input, test_target = rand_inputs_batch[train_size:, :], rand_targets_batch[train_size:] train(model, (train_input, train_target), adamW, cross_entropy_loss) test_loss = test(model, (test_input, test_target), cross_entropy_loss) print(f'Loss: {test_loss} on epoch: {epoch}/{epochs}') 


प्रशिक्षण संतृप्ति




भाग 3 में, हम इस प्रशिक्षित नेटवर्क के आंतरिक पहलुओं की जांच करेंगे। हम ध्यान पैटर्न को देखकर और मैकेनिस्टिक व्याख्यात्मकता के कुछ निदान उपकरणों को लागू करके ऐसा करेंगे जैसे कि सक्रियण पैचिंग, ताकि यह समझने के लिए एक मैकेनिस्टिक मॉडल बनाया जा सके कि नेटवर्क ने इस कार्य को कैसे हल किया है।


अब तक पढ़ने के लिए धन्यवाद, जल्द ही भाग 3 में आपसे मुलाकात होगी!