import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # 64 # how many independent sequences will we process in parallel?
block_size = 32 # 256 # what is the maximum context length for predictions?

device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
# ------------


# data loading
def get_batch(split, train_data, val_data):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def estimate_loss(model, train_data, val_data):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, train_data, val_data)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B, T, C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class GPT(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # better init, not covered in the original GPT video, but important; will cover in a followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensors of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
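

# ------------
# Minimal usage sketch (not part of the original module): shows how get_batch,
# estimate_loss, GPT.forward and GPT.generate fit together. It assumes a
# character-level corpus in a hypothetical 'input.txt', a learning rate of 1e-3,
# and the max_iters/eval_interval values below; the author's actual training
# script may differ.
if __name__ == '__main__':
    torch.manual_seed(1337)

    # load and encode the corpus as character ids (assumed setup)
    with open('input.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: ''.join([itos[i] for i in l])

    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9 * len(data)) # 90/10 train/val split (assumed)
    train_data, val_data = data[:n], data[n:]

    model = GPT(vocab_size).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) # assumed learning rate

    max_iters, eval_interval = 5000, 500 # assumed values
    for it in range(max_iters):
        # periodically evaluate mean loss on the train and val splits
        if it % eval_interval == 0:
            losses = estimate_loss(model, train_data, val_data)
            print(f"step {it}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # sample a batch, compute the loss, and take one optimizer step
        xb, yb = get_batch('train', train_data, val_data)
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # generate a short sample from the trained model, starting from a single zero token
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))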