
cp failed when using packed data #15339

@gaojingwei

Description

I am doing SFT on the model using NeMo with data in packed format. It runs normally when cp=1, but when cp > 1, the following error occurs:

File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 400, in forward_step
[rank4]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank4]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/NeMo/nemo/lightning/megatron_parallel.py", line 505, in wrapped_forward_step_func
[rank4]:     batch = _data_step(dataloader_iter)
[rank4]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/NeMo/nemo/collections/llm/gpt/model/base.py", line 667, in data_step
[rank4]:     return self.config.data_step_fn(dataloader_iter)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/NeMo/nemo/collections/llm/gpt/model/base.py", line 111, in gpt_data_step
[rank4]:     output = get_batch_on_this_cp_rank(_batch_required_keys)
[rank4]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/megatron-lm/megatron/core/utils.py", line 1838, in get_batch_on_this_cp_rank
[rank4]:     raise e
[rank4]:   File "/opt/megatron-lm/megatron/core/utils.py", line 1825, in get_batch_on_this_cp_rank
[rank4]:     val.shape[seq_dim] // (2 * cp_size),
[rank4]:     ~~~~~~~~~^^^^^^^^^
[rank4]: IndexError: tuple index out of range

I found the source code and added logging:

from typing import Any, Dict

import torch

from megatron.core import parallel_state


def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
    """Slice batch input along sequence dimension into multiple chunks,
    which are parallelized across GPUs in a context parallel group.
    """

    # With causal masking, each token only attends to its prior tokens. Simply split
    # sequence into CP chunks can result in severe load imbalance. That's to say, chunks
    # at the end of sequence have bigger workload than others. To address this issue,
    # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
    # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
    # that we can get balanced workload among GPUs in a context parallel group.
    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size > 1:
        cp_rank = parallel_state.get_context_parallel_rank()
        for key, val in batch.items():
            if val is not None:
                seq_dim = 1 if key != "attention_mask" else 2

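                # Debug logging I added: print which key and shape trigger the failure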
                try:
                    val = val.view(
                        *val.shape[0:seq_dim],
                        2 * cp_size,
                        val.shape[seq_dim] // (2 * cp_size),
                        *val.shape[(seq_dim + 1) :],
                    )
                    index = torch.zeros(2, dtype=torch.int64, device=val.device)
                    index[0].fill_(cp_rank)
                    index[1].fill_(2 * cp_size - cp_rank - 1)
                    val = val.index_select(seq_dim, index)
                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
                    batch[key] = val
                except Exception as e:
                    print(f"Rank {cp_rank}")
                    print(f"Key: {key}")
                    print(f"Shape: {val.shape} {seq_dim}  {val}")
                    raise e

    return batch
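
For context on what that reshape is doing, here is a toy example I wrote (my own sketch, not code from Megatron or NeMo) of the 2*CP load-balanced split described in the comment above, assuming a tensor that actually has a sequence dimension divisible by 2 * cp_size:

import torch

cp_size, seq_len = 2, 8
tokens = torch.arange(seq_len).unsqueeze(0)                      # shape [1, 8], seq_dim = 1
chunks = tokens.view(1, 2 * cp_size, seq_len // (2 * cp_size))   # [1, 4, 2]: four chunks of two tokens

for cp_rank in range(cp_size):
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    print(cp_rank, chunks.index_select(1, index).reshape(1, -1))
# rank 0 -> tensor([[0, 1, 6, 7]])   (chunk_0 + chunk_3)
# rank 1 -> tensor([[2, 3, 4, 5]])   (chunk_1 + chunk_2)

This only works when val.shape[seq_dim] exists and is divisible by 2 * cp_size, which is exactly what breaks for attention_mask below.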

The logs show:
Rank 0
Key: attention_mask
Shape: torch.Size([1]) 2 tensor([1], device='cuda:0')
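
So for attention_mask the tensor is 1-D with shape [1], while seq_dim is hard-coded to 2 for that key; val.shape[2] does not exist, hence the IndexError. It can be reproduced in isolation (my own minimal sketch, just mirroring the traceback):

import torch

cp_size = 2
attention_mask = torch.tensor([1])   # what GPTSFTPackedDataset puts in the batch
seq_dim = 2                          # get_batch_on_this_cp_rank uses seq_dim=2 for "attention_mask"

attention_mask.shape[seq_dim] // (2 * cp_size)
# IndexError: tuple index out of range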

My data looks like this; the dataset is GPTSFTPackedDataset:

({'tokens': tensor([[167543,   8948,    198,  ..., 167545, 167545, 167545]]), 'labels': tensor([[  8948,    198,   2610,  ..., 167545, 167545, 167545]]), 'loss_mask': tensor([[0, 0, 0,  ..., 0, 0, 0]]), 'position_ids': tensor([[0, 1, 2,  ..., 0, 0, 0]]), 'attention_mask': tensor([1]), 'cu_seqlens': tensor([[    0,  9484,  9665, 10350, 10643, 11403, 12009, 12691, 13199, 13686,
         14763, 15437, 16731, 17120, 18486, 19185, 20122, 21117, 21284, 21619,
         23190, 24816, 26374, 27024, 27591, 28392, 32342, 32581, 32724, 32768,
            -1]], dtype=torch.int32), 'cu_seqlens_argmin': tensor([[30]]), 'max_seqlen': tensor([[9484]], dtype=torch.int32), 'cu_seqlens_unpadded': tensor([[    0,  9484,  9665, 10350, 10643, 11403, 12009, 12691, 13199, 13686,
         14763, 15437, 16731, 17120, 18486, 19185, 20122, 21117, 21284, 21619,
         23190, 24816, 26374, 27024, 27591, 28392, 32342, 32581, 32724, 32724,
            -1]], dtype=torch.int32), 'cu_seqlens_unpadded_argmin': tensor([[30]]), 'token_count': [32724]}, 0, 0)
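
From the shapes above, only tokens, labels, loss_mask and position_ids have a full-length sequence dimension; attention_mask, the cu_seqlens* tensors and max_seqlen are per-pack metadata with no such dimension. A quick check over the dict (my own sketch, with batch being the dict printed above):

import torch  # batch is the dict shown above

for key, val in batch.items():
    if torch.is_tensor(val):
        print(key, tuple(val.shape))
# tokens / labels / loss_mask / position_ids -> (1, seq_len): can be split along dim 1
# attention_mask -> (1,), cu_seqlens / cu_seqlens_unpadded -> (1, 31), max_seqlen -> (1, 1):
# none of these has a dimension the view(..., 2 * cp_size, ...) reshape can use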

I tried setting cp=2 and cp=4 and encountered the same issue in both cases. I also tested it on both the NeMo 2504 and NeMo 2509 containers.
