fixed position ids
kshitijkg committed Apr 4, 2024
1 parent 81f5899 commit 96b3051
Showing 1 changed file with 16 additions and 8 deletions.
mytests/multimodal_utils_test.py (24 changes: 16 additions & 8 deletions)
@@ -4,24 +4,27 @@
 import torch
 from megatron import utils

-def test_get_sequence_ids():
+def test_get_sequence_and_position_ids():
     eps = 1e-7

     # Case 1: 1 sequence
     tokens = torch.tensor([[1, 2, 3]])
-    sequence_ids = utils.get_sequence_ids(tokens, eos_token_id=3, bos_token_id=None)
+    sequence_ids, position_ids = utils.get_sequence_and_position_ids(tokens, eos_token_id=3, bos_token_id=None)
     assert torch.all(torch.abs(sequence_ids - torch.tensor([[0, 0, 0]])) < eps)
+    assert torch.all(torch.abs(position_ids - torch.tensor([[0, 1, 0]])) < eps)

     # Case 2: 2 sequences
     tokens = torch.tensor([[1, 2, 3, 4, 5, 6]])
-    sequence_ids = utils.get_sequence_ids(tokens, eos_token_id=3, bos_token_id=None)
+    sequence_ids, position_ids = utils.get_sequence_and_position_ids(tokens, eos_token_id=3, bos_token_id=None)
     assert torch.all(torch.abs(sequence_ids - torch.tensor([[0, 0, 0, 1, 1, 1]])) < eps)
+    assert torch.all(torch.abs(position_ids - torch.tensor([[0, 1, 0, 1, 2, 3]])) < eps)

     # Case 3: multiple samples
     tokens = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 1, 3, 1, 3, 4]])
-    sequence_ids = utils.get_sequence_ids(tokens, eos_token_id=3, bos_token_id=None)
+    sequence_ids, position_ids = utils.get_sequence_and_position_ids(tokens, eos_token_id=3, bos_token_id=None)
     assert torch.all(torch.abs(sequence_ids - torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 2]])) < eps)
+    assert torch.all(torch.abs(position_ids - torch.tensor([[0, 1, 0, 1, 2, 3], [0, 1, 0, 1, 0, 1]])) < eps)

 def test_get_shifted_multimodal_position_ids():
     eps = 1e-7
@@ -317,7 +320,7 @@ def test_get_multimodal_attn_mask():

     shifted_text_positions, shifted_vision_positions, shifted_audio_positions = utils.get_shifted_multimodal_position_ids(input_info, position_pad_id=-1)
     shifted_multimodal_position_ids = torch.cat((shifted_text_positions, shifted_vision_positions), dim=1)
-    attention_mask = utils.get_multimodal_attn_mask(
+    attention_mask, position_ids = utils.get_multimodal_attn_mask(
         text_tokens=input_info["text"]["input"],
         vision_positions=input_info["vision"]["positions"],
         audio_positions=None,
@@ -339,7 +342,7 @@ def test_get_multimodal_attn_mask():
     text_positions = torch.tensor([[0, 2, 3, 5, 6], [0, 1, 2, 4, -1]])
     vision_positions = torch.tensor([[1, 4, -1], [3, -1, -1]])
     '''
-    correct_mask =torch.tensor([
+    correct_mask = torch.tensor([
         [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
@@ -367,9 +370,14 @@ def test_get_multimodal_attn_mask():
     correct_mask = correct_mask < 0.5
     assert torch.all(attention_mask == correct_mask)

+    correct_positions = torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5],
+                                      [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4]], dtype=torch.int32)
+    assert torch.all(position_ids == correct_positions)
+
+
 # Main function
 def main():
-    test_get_sequence_ids()
+    test_get_sequence_and_position_ids()
     test_get_shifted_multimodal_position_ids()
     test_get_proxy_tokens()
     test_get_multimodal_mask()
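The assertions above pin down the intended semantics: a sample's sequence ids count how many EOS tokens occur strictly before each index, and its position ids restart at 0 at every EOS token. Below is a minimal sketch consistent with those assertions; the actual utils.get_sequence_and_position_ids in the repository may differ, for example in how it treats bos_token_id, which these tests leave as None.

# Sketch only: one vectorized implementation that reproduces the values
# asserted in test_get_sequence_and_position_ids above.
import torch

def get_sequence_and_position_ids(tokens, eos_token_id, bos_token_id=None):
    # bos_token_id handling omitted: the tests pass bos_token_id=None.
    batch, seq_len = tokens.shape
    eos_mask = (tokens == eos_token_id).long()
    # Sequence id = number of EOS tokens strictly before each index
    # (an exclusive cumulative sum of the EOS mask).
    sequence_ids = torch.cumsum(eos_mask, dim=1) - eos_mask
    # Position id = offset from the most recent EOS at or before each
    # index (or from the start if none), so positions reset to 0 on EOS.
    arange = torch.arange(seq_len, device=tokens.device).expand(batch, -1)
    last_eos_index = torch.cummax(arange * eos_mask, dim=1).values
    position_ids = arange - last_eos_index
    return sequence_ids, position_ids

For Case 3 above, passing torch.tensor([[1, 2, 3, 4, 5, 6], [1, 1, 3, 1, 3, 4]]) with eos_token_id=3 yields sequence ids [[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 2]] and position ids [[0, 1, 0, 1, 2, 3], [0, 1, 0, 1, 0, 1]], matching the expected tensors in the test.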
