Merge branch 'renormalize-blend-weights' into 'main'
Add option to renormalize blend weights

See merge request ADLR/megatron-lm!1797
ericharper committed Aug 8, 2024
2 parents 0363328 + bf3e0b9 commit 2c47ea2
Showing 11 changed files with 83 additions and 9 deletions.
1 change: 1 addition & 0 deletions megatron/core/datasets/blended_dataset.py
@@ -74,6 +74,7 @@ def __init__(
         unique_identifiers["split"] = self.split.name
         unique_identifiers["weights"] = self.weights
         unique_identifiers["size"] = self.size
+        unique_identifiers["renormalize_blend_weights"] = self.config.renormalize_blend_weights

         self.unique_description = json.dumps(
             unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
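Recording the flag in unique_identifiers matters because the serialized description is hashed to key the on-disk index cache; a minimal sketch of that effect, with illustrative values rather than the exact Megatron internals:

import hashlib
import json

# Illustrative stand-in for the unique_identifiers dict built above.
unique_identifiers = {
    "split": "train",
    "weights": [0.7, 0.3],
    "size": 1000,
    "renormalize_blend_weights": True,  # the newly recorded field
}

# Hashing the JSON description distinguishes cache entries, so toggling the
# flag cannot silently reuse indices built under the other setting.
unique_description = json.dumps(unique_identifiers, indent=4)
print(hashlib.md5(unique_description.encode("utf-8")).hexdigest())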
13 changes: 10 additions & 3 deletions megatron/core/datasets/blended_megatron_dataset_builder.py
@@ -150,7 +150,8 @@ def build(self) -> List[Optional[TopLevelDataset]]:
                 for i, dataset_and_size in enumerate(zip(dataset.datasets, sizes)):
                     if len(dataset_and_size[0]) < dataset_and_size[1]:
                         raise IndexError(
-                            f"{type(dataset).__name__} blend goes out of bounds for {type([dataset_and_size[0]]).__name__} {i} for {dataset.split.name} split"
+                            f"The {dataset.split.name} blend oversamples (N = {dataset_and_size[1]}) {type(dataset_and_size[0]).__name__} {i} (len = {len(dataset_and_size[0])}). "
+                            f"Set renormalize_blend_weights to True and re-run. File an issue if the problem is not resolved."
                         )

         return datasets
@@ -208,7 +209,10 @@ def _build_blended_dataset_splits(
             if split[i] is not None:
                 weights_i = weights
                 if weights_i is not None and self.sizes[i] is not None:
-                    size_i = sum(list(zip(*sizes_per_dataset))[i])
+                    size_per_dataset = list(zip(*sizes_per_dataset))[i]
+                    size_i = sum(size_per_dataset)
+                    if self.config.renormalize_blend_weights:
+                        weights_i = list(map(lambda _size: _size / size_i, size_per_dataset))
                 elif weights_i is None:
                     try:
                         weights_i = [
@@ -272,7 +276,10 @@ def _build_blended_dataset_splits(

             # Build top-level dataset
             if weights is not None and self.sizes[i] is not None:
-                size = list(map(sum, zip(*sizes_per_dataset)))[i]
+                size_per_dataset = list(zip(*sizes_per_dataset))[i]
+                size = sum(size_per_dataset)
+                if self.config.renormalize_blend_weights:
+                    weights = list(map(lambda _size: _size / size, size_per_dataset))
             elif weights is None:
                 try:
                     weights = [
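The renormalization above is plain arithmetic: each dataset's weight becomes its requested sample count divided by the total, so the weights again sum to 1 after per-dataset requests were rounded up. A standalone sketch with assumed values:

def renormalize_weights(size_per_dataset):
    """Mirror of the builder logic above: each weight becomes S / sum(S)."""
    size = sum(size_per_dataset)
    return [_size / size for _size in size_per_dataset]

# Per-dataset requests are rounded up to guarantee the total, so they can
# overshoot the nominal weights; renormalizing restores a consistent blend.
print(renormalize_weights([704, 302]))  # ~[0.6998, 0.3002]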
9 changes: 7 additions & 2 deletions megatron/core/datasets/blended_megatron_dataset_config.py
@@ -34,6 +34,12 @@ class BlendedMegatronDatasetConfig:
     'blend'. Defaults to None.
     """

+    renormalize_blend_weights: bool = False
+    """Renormalize the blend weights to account for mid-level dataset oversampling done to ensure
+    fulfillment of the requested number of samples. Defaults to False for backward
+    compatibility in the data sample order.
+    """
+
     split: Optional[str] = None
     """The split string, a comma separated weighting for the dataset splits when drawing samples
     from a single distribution. Not to be used with 'blend_per_split'. Defaults to None.
@@ -64,8 +70,7 @@ class BlendedMegatronDatasetConfig:
     """The MegatronTokenizer instance or None. Required for datasets which do online tokenization."""

     def __post_init__(self) -> None:
-        """Do asserts and set fields post init
-        """
+        """Do asserts and set fields post init"""
         if self.blend_per_split is not None and any(self.blend_per_split):
             assert self.blend is None, "blend and blend_per_split are incompatible"
             assert self.split is None, "split and blend_per_split are incompatible"
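A hedged usage sketch of the new field, mirroring the constructor calls in the tests below (paths, weights, and the sequence length are placeholders):

from megatron.core.datasets.blended_megatron_dataset_config import (
    BlendedMegatronDatasetConfig,
)

config = BlendedMegatronDatasetConfig(
    random_seed=1234,
    sequence_length=512,
    blend=(["path/to/prefix1", "path/to/prefix2"], [0.7, 0.3]),
    split="990,9,1",
    renormalize_blend_weights=True,  # opt in to the new behavior
)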
5 changes: 5 additions & 0 deletions megatron/training/arguments.py
@@ -1559,6 +1559,11 @@ def _add_data_args(parser):
                            '(3) a list of prefixes e.g. prefix1 prefix2. '
                            'For (3), weights are inferred from the lengths of the contributing datasets. '
                            'This argument is exclusive to the other independent --*-data-path arguments.')
+    group.add_argument('--renormalize-blend-weights', action='store_true',
+                       help='Renormalize the blend weights to account for the mid-level dataset '
+                            'oversampling done to ensure fulfillment of the requested number of '
+                            'samples. Use this option if prompted. Defaults to False for backward '
+                            'compatibility in the data sample order.')
     group.add_argument('--split', type=str, default=None,
                        help='Comma-separated list of proportions for training,'
                             ' validation, and test split. For example the split '
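Because the flag is a store_true action, it defaults to False and flips on only when passed; a minimal self-contained check with a simplified parser (not the full Megatron argument set):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--renormalize-blend-weights', action='store_true',
                    help='Renormalize blend weights after oversampling.')

# Absent -> False (backward-compatible sample order); present -> True.
assert parser.parse_args([]).renormalize_blend_weights is False
assert parser.parse_args(['--renormalize-blend-weights']).renormalize_blend_weights is True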
1 change: 1 addition & 0 deletions pretrain_bert.py
@@ -154,6 +154,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         path_to_cache=args.data_cache_path,
         tokenizer=tokenizer,
1 change: 1 addition & 0 deletions pretrain_gpt.py
@@ -195,6 +195,7 @@ def core_gpt_dataset_config_from_args(args):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         num_dataset_builder_threads=args.num_dataset_builder_threads,
         path_to_cache=args.data_cache_path,
1 change: 1 addition & 0 deletions pretrain_mamba.py
@@ -186,6 +186,7 @@ def core_gpt_dataset_config_from_args(args):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         num_dataset_builder_threads=args.num_dataset_builder_threads,
         path_to_cache=args.data_cache_path,
1 change: 1 addition & 0 deletions pretrain_retro.py
@@ -189,6 +189,7 @@ def train_valid_test_datasets_provider(train_valid_test_num_samples):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         split_preprocessing=retro_config.retro_split_preprocessing,
         path_to_cache=args.data_cache_path,
1 change: 1 addition & 0 deletions pretrain_t5.py
@@ -213,6 +213,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples: int):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         path_to_cache=args.data_cache_path,
         tokenizer=tokenizer,
58 changes: 54 additions & 4 deletions tests/unit_tests/data/test_builder.py
@@ -110,7 +110,11 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
-            blend_per_split=[blends[Split.train], None, None,],
+            blend_per_split=[
+                blends[Split.train],
+                None,
+                None,
+            ],
         )
         try:
             datasets = BlendedMegatronDatasetBuilder(
@@ -123,7 +127,11 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
-            blend_per_split=[get_blend_from_list([paths[Split.train][0]]), None, None,],
+            blend_per_split=[
+                get_blend_from_list([paths[Split.train][0]]),
+                None,
+                None,
+            ],
         )
         datasets = BlendedMegatronDatasetBuilder(
             TestDataset, [1000, None, None], lambda: True, config
@@ -179,7 +187,11 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
-            blend_per_split=[blends_unweighted[Split.train], None, None,],
+            blend_per_split=[
+                blends_unweighted[Split.train],
+                None,
+                None,
+            ],
         )
         datasets = BlendedMegatronDatasetBuilder(
             TestDataset, [1000, None, None], lambda: True, config
@@ -219,7 +231,25 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
-            blend_per_split=[blends[Split.train], blends[Split.valid], blends[Split.test],],
+            blend_per_split=[blends[Split.train], None, None],
+            renormalize_blend_weights=True,
+        )
+        datasets = BlendedMegatronDatasetBuilder(
+            TestDataset, [1000, None, None], lambda: True, config
+        ).build()
+        assert (
+            len(datasets[0]) >= 1000
+            and len(datasets[0]) <= 1000 * (1 + _MARGIN) + _NUM_DATASETS
+        )
+
+        config = BlendedMegatronDatasetConfig(
+            random_seed=1234,
+            sequence_length=_SEQUENCE_LENGTH,
+            blend_per_split=[
+                blends[Split.train],
+                blends[Split.valid],
+                blends[Split.test],
+            ],
         )
         datasets = BlendedMegatronDatasetBuilder(
             TestDataset, [100, 100, 100], lambda: True, config
@@ -336,6 +366,26 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
         # W = S / sum(S)
         #
         ##
+        config = BlendedMegatronDatasetConfig(
+            random_seed=1234,
+            sequence_length=_SEQUENCE_LENGTH,
+            blend=blends[Split.train],
+            split="990,9,1",
+            renormalize_blend_weights=True,
+        )
+        datasets = BlendedMegatronDatasetBuilder(
+            TestDataset, [100000, 1000, 1], lambda: True, config
+        ).build()
+        assert (
+            len(datasets[0]) >= 100000
+            and len(datasets[0]) <= 100000 * (1 + _MARGIN) + _NUM_DATASETS
+        )
+        assert (
+            len(datasets[1]) >= 1000
+            and len(datasets[1]) <= 1000 * (1 + _MARGIN) + _NUM_DATASETS
+        )
+        assert len(datasets[2]) >= 1 and len(datasets[2]) <= 1 * (1 + _MARGIN) + _NUM_DATASETS
+
         config = BlendedMegatronDatasetConfig(
             random_seed=1234,
             sequence_length=_SEQUENCE_LENGTH,
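The assertions above bound the oversampling that renormalization compensates for; a hedged sketch of that bound, with assumed values for the test constants:

_MARGIN = 0.005     # assumed tolerance, standing in for the test constant
_NUM_DATASETS = 2   # assumed number of contributing datasets

def length_within_bounds(actual, requested):
    # The blend may overshoot the request by the oversampling margin plus
    # up to one extra sample per contributing dataset, but never undershoot.
    return requested <= actual <= requested * (1 + _MARGIN) + _NUM_DATASETS

print(length_within_bounds(1006, 1000))  # True: small overshoot is allowed
print(length_within_bounds(990, 1000))   # False: undershoot would be a bug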
1 change: 1 addition & 0 deletions tools/retro/preprocess_data.py
@@ -110,6 +110,7 @@ def get_gpt_chunk_datasets(config):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=config.retro_gpt_split,
         split_preprocessing=config.retro_gpt_split,
         path_to_cache=config.retro_gpt_data_cache_path,
