Skip to content

Commit

Permalink
update train file
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 14, 2024
1 parent 6a7f3de commit aa5b4b7
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# Initialize the accelerator
accelerator = Accelerator()


# %%
class CustomDataset(Dataset):
"""Custom dataset for handling sequences and targets."""
Expand Down Expand Up @@ -183,15 +182,21 @@ def main():
)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LongRoPEModel(
d_model=4096,
n_heads=32,
num_layers=6,
vocab_size=tokenizer.vocab_size,
max_len=2048000, # Set max_len to 2048k tokens
).to(device)
)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Prepare everything with accelerator
model, optimizer, train_loader, val_loader = accelerator.prepare(
model, optimizer, train_loader, val_loader
)

extended_model = model.extend_context(
data_path="../data/raw/enwik8.gz",
Expand All @@ -213,7 +218,7 @@ def main():
optimizer = optim.Adam(recovered_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# train(recovered_model, train_loader, val_loader, optimizer, criterion, device)
train(recovered_model, train_loader, val_loader, optimizer, criterion)


if __name__ == "__main__":
Expand Down

0 comments on commit aa5b4b7

Please sign in to comment.