-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathllama31.py
720 lines (646 loc) · 28 KB
/
llama31.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
"""
nano-Llama 3.1
Simpler version you can just forward on 1 GPU, without torchrun.
Changes:
- replace ColumnParallelLinear -> Linear
- replace RowParallelLinear -> Linear
- replace VocabParallelEmbedding -> Embedding
Run example:
python llama31.py \
--ckpt_dir llama-models/models/llama3_1/Meta-Llama-3.1-8B \
--tokenizer_path llama-models/models/llama3_1/Meta-Llama-3.1-8B/tokenizer.model
"""
import os
import glob
import fire
import time
import json
import math
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional, Tuple, TypedDict
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from tokenizer import Tokenizer
# -----------------------------------------------------------------------------
# ModelArgs
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
max_batch_size: int = 32
max_seq_len: int = 2048
flash: bool = False # use flash attention?
def __init__(self, **kwargs):
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
assert self.n_kv_heads <= self.n_heads
assert self.n_heads % self.n_kv_heads == 0
assert self.dim % self.n_heads == 0
# -----------------------------------------------------------------------------
# Transformer
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def apply_scaling(freqs: torch.Tensor):
# RoPE scaling (values obtained from grid search)
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
freqs_cis_real = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return freqs_cis_real
def apply_rotary_emb(x, freqs_cis):
# shape gymnastics let's go
# x is (bs, seqlen, n_heads, head_dim), e.g. (4, 8, 32, 128)
# freqs_cis is (seq_len, head_dim/2, 2), e.g. (8, 64, 2)
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
# xshaped is (bs, seqlen, n_heads, head_dim/2, 2), e.g. (4, 8, 32, 64, 2)
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
# freqs_cis becomes (1, seqlen, 1, head_dim/2, 2), e.g. (1, 8, 1, 64, 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)
# x_out2 at this point is (bs, seqlen, n_heads, head_dim/2, 2), e.g. (4, 8, 32, 64, 2)
x_out2 = x_out2.flatten(3)
# x_out2 is now (bs, seqlen, n_heads, head_dim), e.g. (4, 8, 32, 128)
return x_out2.type_as(x)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class KVCache(nn.Module):
def __init__(self, batch_size, seq_length, n_kv_heads, head_dim, dtype, device):
super().__init__()
cache_shape = (batch_size, seq_length, n_kv_heads, head_dim)
self.register_buffer("cache_k", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer("cache_v", torch.zeros(cache_shape, dtype=dtype, device=device))
def update(self, start_pos, xk, xv):
seqlen = xk.size(1)
self.cache_k[:, start_pos : start_pos + seqlen] = xk
self.cache_v[:, start_pos : start_pos + seqlen] = xv
xk = self.cache_k[:, : start_pos + seqlen]
xv = self.cache_v[:, : start_pos + seqlen]
return xk, xv
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.flash = args.flash # use flash attention?
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# model_parallel_size = fs_init.get_model_parallel_world_size()
model_parallel_size = 1 # AK: model parallel size is 1 for 1 GPU
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False )
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
# will be KVCache object managed by inference context manager
self.cache = None
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
# calculate query, key, value and split out heads
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# rotate query, keys (RoPE)
xq = apply_rotary_emb(xq, freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis)
# KV cache update
if self.cache is not None:
# update the KV cache with current KV and get all the previous KVs
xk, xv = self.cache.update(start_pos, xk, xv)
# repeat k/v heads if n_kv_heads < n_heads (GQA)
xk = repeat_kv(xk, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xv = repeat_kv(xv, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
# make heads be a batch dim
xq, xk, xv = (x.transpose(1, 2) for x in (xq, xk, xv))
# attention
if self.flash:
output = F.scaled_dot_product_attention(xq, xk, xv, mask)
else:
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
# concatenate all the heads
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# output projection
proj = self.wo(output)
return proj
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
# hidden dim gymnastics that Meta simplified only later
hidden_dim = int(2 * hidden_dim / 3)
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.layers = nn.ModuleList(
TransformerBlock(params) for _ in range(params.n_layers)
)
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.freqs_cis = precompute_freqs_cis(
params.dim // params.n_heads,
params.max_seq_len * 2,
params.rope_theta,
params.use_scaled_rope,
)
def forward_inference(self, tokens: torch.Tensor, start_pos: int):
# for use during inference
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack(
[torch.zeros((seqlen, start_pos), device=tokens.device), mask]
).type_as(h)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output
def forward_loss(self, inputs: torch.Tensor, targets: torch.Tensor, ignore_index=-100):
# for use during training
# ignore_index can be set to e.g. self.tokenizer.pad_id in the future
# forward the model first
_bsz, seqlen = inputs.shape
h = self.tok_embeddings(inputs)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[:seqlen]
mask = torch.full((seqlen, seqlen), float("-inf"), device=inputs.device)
mask = torch.triu(mask, diagonal=1)
mask = mask.type_as(h)
start_pos = -1 # -1 disables KV caching logic
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
logits = self.output(h).float()
# and then loss
loss = F.cross_entropy(
input=logits.transpose(1, 2),
target=targets,
reduction="mean",
ignore_index=ignore_index,
)
return loss
def configure_optimizers(self, learning_rate, weight_decay=0.0, betas=(0.9, 0.97), device_type='cuda'):
train_params = []
finetune_type = "all"
if finetune_type == "rmsnorm":
# let's only train the RMSNorm parameters to start
for name, param in self.named_parameters():
if "norm" in name:
train_params.append(param)
elif finetune_type == "all":
# let's train all parameters
for param in self.parameters():
train_params.append(param)
elif finetune_type == "all_no_pos":
# let's train all parameters except the positional embeddings and lm_head
n, m = 0, 0
for name, param in self.named_parameters():
if name == "output.weight":
# do not include
n += 1
continue
elif name == "tok_embeddings.weight":
# do not include and also does not require grad
m += 1
param.requires_grad = False
else:
# do include
train_params.append(param)
assert n == 1, "did not find output.weight"
assert m == 1, "did not find tok_embeddings.weight"
print("number of parameters: ", sum(p.numel() for p in self.parameters()))
print("number of trainable parameters: ", sum(p.numel() for p in train_params))
# Create AdamW optimizer and use the fused version if it is available
fused_available = True #'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(train_params, lr=learning_rate, betas=betas, **extra_args)
return optimizer
# -----------------------------------------------------------------------------
# Llama wrapper
class Llama:
@staticmethod
def build(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
flash: bool = False,
model_parallel_size: Optional[int] = 1,
seed: int = 1,
) -> "Llama":
assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."
assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."
local_rank = 0
torch.cuda.set_device(local_rank)
torch.manual_seed(seed) # seed must be the same in all processes
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(checkpoints)
ckpt_path = checkpoints[0]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
flash=flash,
**params,
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert model_args.vocab_size == tokenizer.n_words
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer)
def __init__(self, model: Transformer, tokenizer: Tokenizer):
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode()
def generate(
self,
prompt_tokens: List[List[int]],
sample_rng: torch.Generator,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
"""
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
max_gen_len (int): Maximum length of the generated text sequence.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
"""
params = self.model.params
bsz = len(prompt_tokens)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= params.max_seq_len
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
# install KV cache in all the Attention layers
for block in self.model.layers:
layer_dtype = block.attention.wq.weight.dtype
layer_device = block.attention.wq.weight.device
block.attention.cache = KVCache(
batch_size=bsz,
seq_length=total_len,
n_kv_heads=params.n_kv_heads,
head_dim=params.dim // params.n_heads,
dtype=layer_dtype,
device=layer_device,
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
logits = self.model.forward_inference(tokens, prev_pos)
stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
for cur_pos in range(min_prompt_len, total_len):
# get the logits for the next token in all the batch rows
logits = self.model.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos)
# sample the next token
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p, sample_rng)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
prev_pos = cur_pos
if all(eos_reached):
break
out_tokens = []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
# cut to after eos tok if any
for stop_token in self.tokenizer.stop_tokens:
try:
eos_idx = toks.index(stop_token)
toks = toks[:eos_idx]
except ValueError:
pass
out_tokens.append(toks)
# clean up the KV cache in all the layers
for block in self.model.layers:
block.attention.cache = None
return out_tokens
def text_completion(
self,
prompts: List[str],
sample_rng: torch.Generator,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
echo: bool = False,
):
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
# encode the (string) prompts to tokens
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
# generate the completions in tokens space
generation_tokens = self.generate(
prompt_tokens=prompt_tokens,
sample_rng=sample_rng,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
echo=echo,
)
# decode the completions back to strings
completions = [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
return completions
def sample_top_p(probs, p, generator):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1, generator=generator)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
# -----------------------------------------------------------------------------
# distributed and sharded data loader
def _peek_data_shard(filename):
# only reads the header, returns header data
with open(filename, "rb") as f:
# first read the header, which is 256 int32 integers (4 bytes each)
header = np.frombuffer(f.read(256*4), dtype=np.int32)
if header[0] != 20240801:
print("ERROR: magic number mismatch in the data .bin file!")
exit(1)
assert header[1] == 7, "unsupported version"
ntok = header[2] # number of tokens (claimed)
return ntok # for now just return the number of tokens
def _load_data_shard(filename):
with open(filename, "rb") as f:
# first read the header, which is 256 int32 integers (4 bytes each)
header = np.frombuffer(f.read(256*4), dtype=np.int32)
assert header[0] == 20240801, "magic number mismatch in the data .bin file"
assert header[1] == 7, "unsupported version"
ntok = header[2] # number of tokens (claimed)
# the rest of it are tokens, stored as uint16
tokens = np.frombuffer(f.read(), dtype=np.uint32)
assert len(tokens) == ntok, "number of tokens read does not match header?"
return tokens
class DistributedShardedDataLoader:
"""
This DataLoader is both:
- distributed (works correctly in case of multiple processes in DDP)
- sharded (supports datasets that are broken up into multiple data shards)
It is not *permuted*, meaning that it itearates over the data in the order
of the dataset on disk, so the user should make sure to shuffle their examples
during the creation of their data shards for best performance.
"""
def __init__(self, filename_pattern, B, T, process_rank, num_processes):
self.process_rank = process_rank
self.num_processes = num_processes
self.B = B
self.T = T
# glob files that match the pattern
self.files = sorted(glob.glob(filename_pattern))
assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
# load and validate all data shards, count number of tokens in total
ntok_total = 0
for fname in self.files:
shard_ntok = _peek_data_shard(fname)
assert shard_ntok >= num_processes * B * T + 1
ntok_total += shard_ntok
self.ntok_total = ntok_total
print(f"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files")
# kick things off
self.current_shard = None
self.reset()
def reset(self):
# we're being a bit clever here: if we already had shard 0 loaded,
# then don't do the work to reload it, just reset the pointer
if self.current_shard != 0:
self.current_shard = 0
self.tokens = _load_data_shard(self.files[self.current_shard])
self.current_position = self.process_rank * self.B * self.T
def advance(self): # advance to next data shard
self.current_shard = (self.current_shard + 1) % len(self.files)
self.current_position = self.process_rank * self.B * self.T
self.tokens = _load_data_shard(self.files[self.current_shard])
def next_batch(self):
B = self.B
T = self.T
buf = self.tokens[self.current_position : self.current_position+B*T+1]
buf = torch.tensor(buf, dtype=torch.long)
x = (buf[:-1]).view(B, T) # inputs
y = (buf[1:]).view(B, T) # targets
# advance the start pointer in current shard
self.current_position += B * T * self.num_processes
# if loading the next batch would be out of bounds advance the shard
if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
self.advance()
return x, y
# -----------------------------------------------------------------------------
# int main
def main(
ckpt_dir: str = "llama-models/models/llama3_1/Meta-Llama-3.1-8B",
tokenizer_path: str = "llama-models/models/llama3_1/Meta-Llama-3.1-8B/tokenizer.model",
temperature: float = 1.0,
top_p: float = 0.9,
max_seq_len: int = 256,
max_gen_len: int = 256,
max_batch_size: int = 8,
flash: bool = True,
):
# load the val data shard
data_loader = DistributedShardedDataLoader(
filename_pattern="tinystories/*_val.bin",
B=max_batch_size,
T=max_seq_len,
process_rank=0,
num_processes=1,
)
llama = Llama.build(
ckpt_dir=ckpt_dir,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
flash=flash,
)
total_batch_size = max_batch_size * max_seq_len
print(f"total_batch_size: {total_batch_size}")
# super simple training loop to start
model = llama.model
model.train()
optimizer = model.configure_optimizers(learning_rate=1e-5, weight_decay=0.0)
for step in range(20):
optimizer.zero_grad()
x, y = data_loader.next_batch()
x, y = x.cuda(), y.cuda()
loss = model.forward_loss(x, y)
loss.backward()
optimizer.step()
print(f"step {step}, loss: {loss.item()}")
# and now generate
model.eval()
prompts: List[str] = [
"Once upon a time",
"One day",
"Lily and George were best friends",
"On a dark and stormy night",
]
sample_rng = torch.Generator(device='cuda')
sample_rng.manual_seed(1337)
t0 = time.time()
results = llama.text_completion(
prompts,
sample_rng=sample_rng,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
)
t1 = time.time()
print(f"Generated in {t1 - t0:.2f} seconds")
for prompt, result in zip(prompts, results):
print(prompt, end="") # AK: change end="\n" to end=""
print(f"{result['generation']}")
print("\n==================================\n")
if __name__ == "__main__":
fire.Fire(main)