[PyTorch] Add context parallel support for packed dataset in THD format #9540
Changes from 4 commits
`scripts/nlp_language_modeling/prepare_packed_ft_dataset.py`:

```diff
@@ -17,6 +17,7 @@
 from typing import TYPE_CHECKING, Tuple

+import numpy as np
 import torch

 from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
```

Note (Code scanning / CodeQL, unused import): Import of 'torch' is not used.
```diff
@@ -83,12 +84,20 @@
     # using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings
     # are identical to normal SFT training
     data_cfg = cfg.model.data.train_ds
+    pad_seq_length_to_mult = 16
+    cp_size = cfg.model.context_parallel_size
+
+    # if context parallel is used, each individual data length in one packed dataset sample
+    # needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641
+    if cp_size > 1:
+        pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2)
+
     dataset = GPTSFTDataset(
         file_path=data_cfg.file_names[0],
         tokenizer=get_nmt_tokenizer(library="sentencepiece", tokenizer_model=cfg.tokenizer_path),
         max_seq_length=data_cfg.max_seq_length,
         min_seq_length=data_cfg.min_seq_length,
-        pad_seq_length_to_mult=16,  # adds padding in collate_fn so this value is irrelevant here
+        pad_seq_length_to_mult=pad_seq_length_to_mult,
         add_bos=data_cfg.get('add_bos', False),
         add_eos=data_cfg.get('add_eos', True),
         add_sep=data_cfg.get('add_sep', False),
```
```diff
@@ -109,8 +118,29 @@
         special_tokens=data_cfg.get('chat_prompt_tokens', None),
         is_test=True,
     )
-    return np.array([dataset[i] for i in range(len(dataset))])
+    max_seq_length = dataset.max_seq_length
+    pad_id = dataset.tokenizer.eos_id
+    pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
+    dataset = np.array([dataset[i] for i in range(len(dataset))])
+    if cp_size > 1:
+
+        def pre_pad_dataset(data, max_length, pad_id):
+            '''
+            pad each individual data point to the length of max_length
+            '''
+            for key, val in data.items():
+                if key in {'input_ids', 'context_ids'}:
+                    # because input_ids is truncated by 1 for labels in the collate_fn of
+                    # GPTSFTPackedDataset in gpt_sft_dataset.py, we add 1 extra padding here
+                    val = val + [pad_id] * (max_length - len(val) + 1)
+                    data[key] = val
+            return
+
+        ceil_to_nearest = lambda n, m: (n + m - 1) // m * m
+        for data in dataset:
+            max_length = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
+            pre_pad_dataset(data, max_length, pad_id)
+
+    return dataset


 @dataclass
```

Reviewer: How is the loss_mask handled for padded tokens?

Author: The loss_mask is handled in the packing function here: https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py#L142
Reviewer: Does position_ids mean the token_id in the packed sequence? How is this argument used in the training fwd and bwd?

Author: position_ids is the position of the tokens in a sequence (e.g. [0, 1, 2, ..., seq_len-1]). In a packed sequence we have a list of position_ids, since the packed sequence is composed of many individual sequences. I'm not sure whether that's what you mean by token_id. It's used the same way as input_ids in the training fwd and bwd.
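To illustrate the author's explanation: in a packed sample, the positions restart at 0 for each constituent sequence. A minimal sketch with hypothetical sequence lengths (the helper name is illustrative, not from the PR):

```python
def packed_position_ids(seq_lens):
    # Positions restart at 0 for every individual sequence in the pack,
    # so each sequence is treated independently during fwd/bwd.
    pos = []
    for n in seq_lens:
        pos.extend(range(n))
    return pos

# Packing three sequences of lengths 4, 3, and 2:
print(packed_position_ids([4, 3, 2]))  # [0, 1, 2, 3, 0, 1, 2, 0, 1]
```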