Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Dec 10, 2024
1 parent 3568627 commit b4bf313
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions benchmarks/python/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: BSD-3-Clause
import pytest
from nvfuser import FusionDefinition, DataType
from .core import run_benchmark
from .core import run_benchmark, clear_dynamo_cache, with_executor
import torch


Expand Down Expand Up @@ -1743,9 +1743,9 @@ def test_rope_variations_nvf_benchmark(
run_benchmark(benchmark, fd.execute, inputs)


def toy_rope(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int
) -> torch.Tensor:
# [x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int]
def toy_rope(inputs: list):
x, cos, sin, rope_n_elem = inputs
x_rope = x[..., :rope_n_elem]
x1 = x_rope[..., : rope_n_elem // 2] # (B, nh, T, hs/2)
x2 = x_rope[..., rope_n_elem // 2 :] # (B, nh, T, hs/2)
Expand Down Expand Up @@ -1782,5 +1782,5 @@ def test_toy_rope_benchmark(
run_benchmark(
benchmark,
benchmark_fn,
[x, weight, bias, num_groups],
[x, cos, sin, rope_n_elem],
)

0 comments on commit b4bf313

Please sign in to comment.