Skip to content

Commit

Permalink
update main test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 10, 2024
1 parent 6874a87 commit 042d120
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ def test_rope_positional_encoding():
positions = torch.arange(100).unsqueeze(0)
pos_embeddings = rope(positions)
assert pos_embeddings.shape == (1, 100, 512)


def test_longrope_model_forward():
model = LongRoPEModel(
d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
)
input_ids = torch.randint(0, 50257, (2, 1024))
output = model(input_ids)
assert output.shape == (2, 1024, 512)

0 comments on commit 042d120

Please sign in to comment.