Skip to content

Commit

Permalink
update training test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 10, 2024
1 parent bc4a9c8 commit 1de4958
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,19 @@ def test_validate_targets():
targets = [[1, 2, 3], [4, 5, 6]]
vocab_size = 10
assert validate_targets(targets, vocab_size) == True


def test_train():
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
sequences = [[1, 2, 3], [4, 5, 6]]
targets = [[2, 3, 4], [5, 6, 7]]
dataset = CustomDataset(sequences, targets)
train_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
val_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
device = torch.device("cpu")
model = LongRoPEModel(
d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
train(model, train_loader, val_loader, optimizer, criterion, device, epochs=1)

0 comments on commit 1de4958

Please sign in to comment.