fixed position ids
kshitijkg committed Apr 4, 2024
1 parent 96b3051 commit eb50ac1
Showing 1 changed file with 38 additions and 14 deletions.
megatron/utils.py
@@ -107,8 +107,34 @@ def get_ltor_masks_and_position_ids(

     return attention_mask, loss_mask, position_ids

+def get_position_ids_concat(is_separator):
+    # convert bools to ints
+    is_separator = is_separator.int()
+
+    # Prepare an array with the same shape, all elements zero
+    result = torch.zeros_like(is_separator)
+
+    # Iterate through each row
+    for row_index in range(is_separator.shape[0]):
+        # Get the indices where the value is 1 in the current row
+        indices = list(torch.where(is_separator[row_index] == 1)[0])
+
+        # Include the 0th index
+        indices.insert(0, 0)
+
+        # Iterate through each index where value is 1
+        for i, index in enumerate(indices):
+
+            # If it's the last '1' in the row, fill from the index to the row end
+            if i == len(indices) - 1:
+                result[row_index, index:] = torch.arange(0, is_separator.shape[1] - index)
+
+            # Else, fill up to the next '1'
+            else:
+                result[row_index, index:indices[i+1]] = torch.arange(0, indices[i+1] - index)
+    return result
+
-def get_sequence_ids(tokens, eos_token_id, bos_token_id):
+def get_sequence_and_position_ids(tokens, eos_token_id, bos_token_id):
     if (eos_token_id is None) and (bos_token_id is None):
         raise ValueError(
             'Must supply a value for either eos_token_id or bos_token_id, but got None for both.'
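For reference, a quick sanity check (not part of the commit) of what the new helper computes: position ids restart at every separator token, so each packed sequence gets its own 0..n-1 range. A minimal sketch, assuming this fork's megatron/utils.py is importable:

    import torch
    from megatron.utils import get_position_ids_concat

    is_separator = torch.tensor([[False, False, True, False, False, True, False, False]])
    print(get_position_ids_concat(is_separator))
    # tensor([[0, 1, 0, 1, 2, 0, 1, 2]], dtype=torch.int32)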
@@ -126,8 +152,8 @@ def get_sequence_ids(tokens, eos_token_id, bos_token_id):
         split_token_id = bos_token_id
         bos_mode = True

-    is_separator = torch.eq(tokens,
-                            split_token_id)  # type: ignore
+    is_separator = torch.eq(tokens, split_token_id)  # type: ignore
+    position_ids = get_position_ids_concat(is_separator)
     cumulative_sep = torch.cumsum(is_separator,
                                   dim=1).to(tokens.dtype)
     # If separator token is bos, we're already done
@@ -136,7 +162,7 @@ def get_sequence_ids(tokens, eos_token_id, bos_token_id):

     # If separator token is eos, right shift 1 space
     left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1))
-    return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
+    return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1), position_ids

 def get_shifted_multimodal_position_ids(input_info, position_pad_id=-1):
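To illustrate the new return signature, a sketch (not from the commit; assumes eos-separated packed sequences and that megatron/utils.py is importable):

    import torch
    from megatron.utils import get_sequence_and_position_ids

    tokens = torch.tensor([[5, 7, 2, 9, 4, 2, 8]])  # eos_token_id = 2
    sequence_ids, position_ids = get_sequence_and_position_ids(tokens, eos_token_id=2, bos_token_id=None)
    # sequence_ids -> tensor([[0, 0, 0, 1, 1, 1, 2]])  (cumsum is right-shifted, so each eos still belongs to the sequence it ends)
    # position_ids -> tensor([[0, 1, 0, 1, 2, 0, 1]])  (positions restart at each separator)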
@@ -249,20 +275,24 @@ def get_multimodal_attn_mask(

     # if attn_uses_sequence_id, then mask across sequence ids to prevent cross sequence attention
     if concat_data:
-        sequence_ids = get_sequence_ids(interleaved_tokens, eos_token_id, bos_token_id)
+        sequence_ids, position_ids = get_sequence_and_position_ids(interleaved_tokens, eos_token_id, bos_token_id)
         if attn_uses_sequence_id:
             sequence_ids_3d_1 = sequence_ids.unsqueeze(-1)
             sequence_ids_3d_2 = sequence_ids.unsqueeze(1)
             matching_sequence_ids = torch.eq(sequence_ids_3d_1, sequence_ids_3d_2)
             mask = mask * matching_sequence_ids
+        else:
+            position_ids = torch.arange(input_seq_length, dtype=torch.long, device=device)  # FIX THIS #TODO
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, input_seq_length)
     else:
         sequence_ids = None
+        position_ids = torch.arange(input_seq_length, dtype=torch.long, device=device)  # FIX THIS #TODO
+        position_ids = position_ids.unsqueeze(0).expand(batch_size, input_seq_length)
     # convert to binary
     mask = mask.view(
         batch_size, 1, input_seq_length, input_seq_length
     )
-    return mask < 0.5, sequence_ids
+    return mask < 0.5, position_ids

 def get_multimodal_ltor_masks_and_position_ids(
     input_info,
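To make the sequence-id masking above concrete, a small self-contained sketch with hypothetical toy values (not from the commit): the elementwise product of the causal mask and the matching-sequence-ids mask lets tokens attend only within their own packed sequence.

    import torch

    sequence_ids = torch.tensor([[0, 0, 0, 1, 1]])
    causal = torch.tril(torch.ones(1, 5, 5))
    matching = torch.eq(sequence_ids.unsqueeze(-1), sequence_ids.unsqueeze(1))
    mask = causal * matching
    # mask[0, 3, 2] == 0: position 3 (sequence 1) cannot attend to position 2 (sequence 0)
    # mask[0, 3, 3] == 1: it can still attend to itself and earlier tokens of sequence 1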
@@ -289,7 +319,7 @@ def get_multimodal_ltor_masks_and_position_ids(
     shifted_multimodal_position_ids = torch.cat((shifted_text_positions, shifted_vision_positions), dim=1)

     # Attention mask (lower triangular). # TODO: INCLUDE Audio in this
-    attention_mask, sequence_ids = get_multimodal_attn_mask(
+    attention_mask, position_ids = get_multimodal_attn_mask(
         text_tokens=input_info["text"]["input"],
         vision_positions=input_info["vision"]["positions"],
         audio_positions=None,
Expand Down Expand Up @@ -327,12 +357,6 @@ def get_multimodal_ltor_masks_and_position_ids(
loss_mask[labels == vision_input_end_token] = 0.0
# loss_mask[labels == vision_gen_start_token] = 0.0

# if concat_data:

# else:

position_ids = torch.arange(input_seq_length, dtype=torch.long, device=labels.device) # FIX THIS #TODO
position_ids = position_ids.unsqueeze(0).expand(batch_size, input_seq_length)
return attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids, labels

def local_rank():
Expand Down
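Side note: the per-row Python loop in get_position_ids_concat can also be written without loops. The sketch below is an alternative vectorized formulation, not what this commit ships; it reproduces the loop's output on the examples above.

    import torch

    def position_ids_concat_vectorized(is_separator: torch.Tensor) -> torch.Tensor:
        # Positions count up from the most recent separator (or from 0 if none yet).
        seq_len = is_separator.shape[1]
        arange = torch.arange(seq_len, device=is_separator.device).expand_as(is_separator)
        last_sep = torch.where(is_separator.bool(), arange, torch.zeros_like(arange))
        last_sep = torch.cummax(last_sep, dim=1).values  # index of latest separator so far
        return arange - last_sep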
