Skip to content

Commit

Permalink
update main test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 10, 2024
1 parent 03aca7a commit 6874a87
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@ def test_non_uniform_interpolation():
n_hat = 50
interpolated = non_uniform_interpolation(pos_embed, 2.0, lambda_factors, n_hat)
assert interpolated.shape == pos_embed.shape


def test_rope_positional_encoding():
rope = RoPEPositionalEncoding(d_model=512, max_len=100)
positions = torch.arange(100).unsqueeze(0)
pos_embeddings = rope(positions)
assert pos_embeddings.shape == (1, 100, 512)

0 comments on commit 6874a87

Please sign in to comment.