[PyTorch] Add context parallel support for packed dataset in THD format #9540

Status: Closed (wants to merge 7 commits)
Changes from 4 commits
54 changes: 38 additions & 16 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -23,6 +23,7 @@
 from typing import Any, Dict, Iterator, List, Optional, Union

 import torch
+import transformer_engine_extensions as tex
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from pkg_resources import packaging
@@ -1175,22 +1176,23 @@ def get_batch_on_this_context_parallel_rank(self, batch):
         cp_size = parallel_state.get_context_parallel_world_size()
         if cp_size > 1:
             cp_rank = parallel_state.get_context_parallel_rank()
-            for key, val in batch.items():
-                if val is not None:
-                    seq_dim = 1 if key != 'attention_mask' else 2
-                    val = val.view(
-                        *val.shape[0:seq_dim],
-                        2 * cp_size,
-                        val.shape[seq_dim] // (2 * cp_size),
-                        *val.shape[(seq_dim + 1) :],
-                    )
-                    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
-                        non_blocking=True
-                    )
-                    val = val.index_select(seq_dim, index)
-                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
-                    batch[key] = val
+
+            # check if the batch is not in THD format
+            if 'cu_seqlens' not in batch:
+                for key, val in batch.items():
+                    if val is not None:
+                        seq_dim = 1 if key != 'attention_mask' else 2
+                        val = val.view(
+                            *val.shape[0:seq_dim],
+                            2 * cp_size,
+                            val.shape[seq_dim] // (2 * cp_size),
+                            *val.shape[(seq_dim + 1) :],
+                        )
+                        index = torch.tensor(
+                            [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
+                        ).cuda(non_blocking=True)
+                        val = val.index_select(seq_dim, index)
+                        val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                        batch[key] = val
         batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub

         return batch
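The hunk above wraps the pre-existing BSHD split in a THD check. For reference, that BSHD branch is the usual load-balanced context-parallel partition: the sequence dimension is viewed as 2 * cp_size chunks, and rank r keeps chunk r together with its mirror chunk 2 * cp_size - r - 1, so every rank sees a balanced mix of early and late tokens under causal attention. A minimal standalone sketch of the same indexing (illustrative only, not part of this diff):

```python
import torch

def split_for_cp_rank(val: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Keep the two sequence chunks assigned to one context-parallel rank.

    Mirrors the BSHD branch above: the sequence is cut into 2 * cp_size
    chunks and rank r keeps chunks r and (2 * cp_size - r - 1).
    """
    val = val.view(
        *val.shape[:seq_dim],
        2 * cp_size,
        val.shape[seq_dim] // (2 * cp_size),
        *val.shape[seq_dim + 1 :],
    )
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2 :])

# Toy example: batch of 1, sequence length 8, cp_size 2.
tokens = torch.arange(8).unsqueeze(0)                    # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(split_for_cp_rank(tokens, cp_size=2, cp_rank=0))   # tensor([[0, 1, 6, 7]])
print(split_for_cp_rank(tokens, cp_size=2, cp_rank=1))   # tensor([[2, 3, 4, 5]])
```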
@@ -1261,6 +1263,26 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
                 )
                 raise e

+            # get packed sequences for this context parallel rank
+            cp_size = parallel_state.get_context_parallel_world_size()
+            if cp_size > 1:
+                cp_rank = parallel_state.get_context_parallel_rank()
+                for key in required_keys:
+                    val = batch[key]
+                    if key != "cu_seqlens":
+                        seq_dim = 1 if key != 'attention_mask' else 2
+                        index = tex.thd_get_partitioned_indices(
+                            cu_seqlens, val.size(seq_dim), cp_size, cp_rank
+                        )
+                        val = val.index_select(seq_dim, index)
+                        batch[key] = val
+                cu_seqlens = cu_seqlens // cp_size
             forward_args = {
                 'input_ids': batch['tokens'],
                 'position_ids': batch['position_ids'],
Collaborator:
Does position_ids mean the token id in the packed sequence? How is this argument used in the training forward and backward passes?

Contributor Author:
position_ids is the position of each token within its sequence (e.g. [0, 1, 2, ..., seq_len - 1]). In a packed sequence we have a list of position_ids, since the packed sequence is composed of many individual sequences. I'm not sure whether that is what you mean by token_id. It is used the same way as input_ids in the training forward and backward passes.
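For concreteness, a small sketch (illustrative values only, not code from this PR) of what position_ids and cu_seqlens look like for a packed sample built from two sequences:

```python
# Packed sample made of two sequences of lengths 3 and 4.
# Positions restart at 0 for each constituent sequence, as described above.
seq_lens = [3, 4]
position_ids = [i for seq_len in seq_lens for i in range(seq_len)]
print(position_ids)  # [0, 1, 2, 0, 1, 2, 3]

# cu_seqlens marks the boundaries between the packed sequences.
cu_seqlens = [0]
for seq_len in seq_lens:
    cu_seqlens.append(cu_seqlens[-1] + seq_len)
print(cu_seqlens)    # [0, 3, 7]
```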

                 'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
                 'labels': batch['labels'] if 'labels' in batch else None,
             }
             forward_args['packed_seq_params'] = PackedSeqParams(
                 cu_seqlens_q=cu_seqlens,
                 cu_seqlens_kv=cu_seqlens,
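As background on the THD branch added above: tex.thd_get_partitioned_indices selects, for each packed sequence, the token indices this context-parallel rank keeps, so every rank ends up with seq_len / cp_size tokens of every individual sequence, which matches the cu_seqlens // cp_size adjustment in the diff. Below is a hypothetical pure-PyTorch stand-in; pick_thd_indices is made up for illustration and only approximates the TransformerEngine kernel, assuming the same two-chunk load balancing as the BSHD path and sequence lengths already padded to a multiple of 2 * cp_size:

```python
import torch

def pick_thd_indices(cu_seqlens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Hypothetical stand-in for tex.thd_get_partitioned_indices.

    Each packed sequence is cut into 2 * cp_size chunks and rank r keeps
    chunks r and 2 * cp_size - r - 1, assuming every sequence length is a
    multiple of 2 * cp_size (what prepare_packed_ft_dataset.py pads to).
    """
    indices = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        chunk = (end - start) // (2 * cp_size)
        for c in (cp_rank, 2 * cp_size - cp_rank - 1):
            indices.append(torch.arange(start + c * chunk, start + (c + 1) * chunk))
    return torch.cat(indices)

# Toy pack: two sequences of lengths 4 and 8, cp_size 2.
cu_seqlens = torch.tensor([0, 4, 12])
print(pick_thd_indices(cu_seqlens, cp_size=2, cp_rank=0))
# tensor([ 0,  3,  4,  5, 10, 11])
```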
36 changes: 33 additions & 3 deletions scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
@@ -17,6 +17,7 @@
 from typing import TYPE_CHECKING, Tuple

 import numpy as np
+import torch

Check notice (Code scanning / CodeQL): Unused import. Import of 'torch' is not used.
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

@@ -83,12 +84,20 @@
     # using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings
     # are identical to normal SFT training
     data_cfg = cfg.model.data.train_ds
+    pad_seq_length_to_mult = 16
+    cp_size = cfg.model.context_parallel_size
+
+    # if context parallel is used, each individual data length in one packed dataset sample
+    # needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641
+    if cp_size > 1:
+        pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2)
+
     dataset = GPTSFTDataset(
         file_path=data_cfg.file_names[0],
         tokenizer=get_nmt_tokenizer(library="sentencepiece", tokenizer_model=cfg.tokenizer_path),
         max_seq_length=data_cfg.max_seq_length,
         min_seq_length=data_cfg.min_seq_length,
-        pad_seq_length_to_mult=16,  # adds padding in collate_fn so this value is irrelevant here
+        pad_seq_length_to_mult=pad_seq_length_to_mult,
         add_bos=data_cfg.get('add_bos', False),
         add_eos=data_cfg.get('add_eos', True),
         add_sep=data_cfg.get('add_sep', False),
@@ -109,8 +118,29 @@
         special_tokens=data_cfg.get('chat_prompt_tokens', None),
         is_test=True,
     )
-
-    return np.array([dataset[i] for i in range(len(dataset))])
+    max_seq_length = dataset.max_seq_length
+    pad_id = dataset.tokenizer.eos_id
+    pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
+    dataset = np.array([dataset[i] for i in range(len(dataset))])
+    if cp_size > 1:
+
+        def pre_pad_dataset(data, max_length, pad_id):
+            '''
+            pad each individual data point to the length of max_length
+            '''
+            for key, val in data.items():
+                if key in {'input_ids', 'context_ids'}:
+                    # because input_ids is truncated by 1 for labels in the collate_fn of GPTSFTPackedDataset
+                    # in gpt_sft_dataset.py, we add 1 extra padding here
+                    val = val + [pad_id] * (max_length - len(val) + 1)
+                    data[key] = val
+            return
+
+        ceil_to_nearest = lambda n, m: (n + m - 1) // m * m
+        for data in dataset:
+            max_length = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
+            pre_pad_dataset(data, max_length, pad_id)
Collaborator:
How is the loss_mask handled for padded tokens?
+    return dataset


 @dataclass
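To make the new padding rule concrete, here is a small sketch with made-up numbers (the cp_size, sample length, and pad id below are illustrative only) of how each sample is rounded up to a multiple of max(16, 2 * cp_size) before packing, plus the single extra pad token that compensates for the label shift in the collate_fn:

```python
def ceil_to_nearest(n: int, m: int) -> int:
    """Round n up to the nearest multiple of m."""
    return (n + m - 1) // m * m

# Illustrative values: cp_size = 16, pad/eos id = 0, one sample of 70 tokens,
# model max_seq_length = 2048.
cp_size = 16
pad_seq_length_to_mult = max(16, cp_size * 2)   # 32
max_seq_length = 2048
input_ids = list(range(1, 71))                  # 70 tokens

max_length = min(max_seq_length, ceil_to_nearest(len(input_ids), pad_seq_length_to_mult))
print(max_length)                               # 96

# One extra pad token: the collate_fn later drops a token from input_ids to
# build the shifted labels, and the remaining length must still be a
# multiple of 2 * cp_size.
padded = input_ids + [0] * (max_length - len(input_ids) + 1)
print(len(padded))                              # 97 -> 96 after the label shift
```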