diff --git a/tests/test_main.py b/tests/test_main.py index e00ffde..9ab9041 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -31,3 +31,12 @@ def test_rope_positional_encoding(): positions = torch.arange(100).unsqueeze(0) pos_embeddings = rope(positions) assert pos_embeddings.shape == (1, 100, 512) + + +def test_longrope_model_forward(): + model = LongRoPEModel( + d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536 + ) + input_ids = torch.randint(0, 50257, (2, 1024)) + output = model(input_ids) + assert output.shape == (2, 1024, 512)