Skip to content

Commit

Permalink
fix qwen2 import failure in test (#394)
Browse files Browse the repository at this point in the history
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
  • Loading branch information
ByronHsu authored Nov 19, 2024
1 parent 8e72763 commit 11ec97b
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions test/transformers/test_qwen2vl_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@

import pytest
import torch
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLRotaryEmbedding,
apply_multimodal_rotary_pos_emb,
)

try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLRotaryEmbedding,
apply_multimodal_rotary_pos_emb,
)

IS_QWEN_AVAILABLE = True
except Exception:
IS_QWEN_AVAILABLE = False

from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
from liger_kernel.transformers.functional import liger_qwen2vl_mrope
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb


@pytest.mark.skipif(
not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers."
)
@pytest.mark.parametrize("bsz", [1, 2])
@pytest.mark.parametrize("seq_len", [128, 131])
@pytest.mark.parametrize("num_q_heads, num_kv_heads", [(64, 8), (28, 4), (12, 2)])
Expand Down Expand Up @@ -87,6 +96,9 @@ def test_correctness(
torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol)


@pytest.mark.skipif(
not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers."
)
@pytest.mark.parametrize(
"bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section",
[
Expand Down

0 comments on commit 11ec97b

Please sign in to comment.