diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py
index f262b05f27..f7883d9b14 100644
--- a/megatron/core/datasets/blended_dataset.py
+++ b/megatron/core/datasets/blended_dataset.py
@@ -74,6 +74,7 @@ def __init__(
         unique_identifiers["split"] = self.split.name
         unique_identifiers["weights"] = self.weights
         unique_identifiers["size"] = self.size
+        unique_identifiers["renormalize_blend_weights"] = self.config.renormalize_blend_weights
 
         self.unique_description = json.dumps(
             unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py
index baa87ae925..0230faf5e0 100644
--- a/megatron/core/datasets/blended_megatron_dataset_builder.py
+++ b/megatron/core/datasets/blended_megatron_dataset_builder.py
@@ -150,7 +150,8 @@ def build(self) -> List[Optional[TopLevelDataset]]:
                     for i, dataset_and_size in enumerate(zip(dataset.datasets, sizes)):
                         if len(dataset_and_size[0]) < dataset_and_size[1]:
                             raise IndexError(
-                                f"{type(dataset).__name__} blend goes out of bounds for {type([dataset_and_size[0]]).__name__} {i} for {dataset.split.name} split"
+                                f"The {dataset.split.name} blend oversamples (N = {dataset_and_size[1]}) {type(dataset_and_size[0]).__name__} {i} (len = {len(dataset_and_size[0])}). "
+                                f"Set renormalize_blend_weights to True and re-run. File an issue if the problem is not resolved."
                             )
 
         return datasets
@@ -208,7 +209,10 @@ def _build_blended_dataset_splits(
                 if split[i] is not None:
                     weights_i = weights
                     if weights_i is not None and self.sizes[i] is not None:
-                        size_i = sum(list(zip(*sizes_per_dataset))[i])
+                        size_per_dataset = list(zip(*sizes_per_dataset))[i]
+                        size_i = sum(size_per_dataset)
+                        if self.config.renormalize_blend_weights:
+                            weights_i = list(map(lambda _size: _size / size_i, size_per_dataset))
                     elif weights_i is None:
                         try:
                             weights_i = [
@@ -272,7 +276,10 @@ def _build_blended_dataset_splits(
 
                 # Build top-level dataset
                 if weights is not None and self.sizes[i] is not None:
-                    size = list(map(sum, zip(*sizes_per_dataset)))[i]
+                    size_per_dataset = list(zip(*sizes_per_dataset))[i]
+                    size = sum(size_per_dataset)
+                    if self.config.renormalize_blend_weights:
+                        weights = list(map(lambda _size: _size / size, size_per_dataset))
                 elif weights is None:
                     try:
                         weights = [
diff --git a/megatron/core/datasets/blended_megatron_dataset_config.py b/megatron/core/datasets/blended_megatron_dataset_config.py
index 10cd5909b9..52bc31f62e 100644
--- a/megatron/core/datasets/blended_megatron_dataset_config.py
+++ b/megatron/core/datasets/blended_megatron_dataset_config.py
@@ -34,6 +34,12 @@ class BlendedMegatronDatasetConfig:
     'blend'. Defauls to None.
     """
 
+    renormalize_blend_weights: bool = False
+    """Renormalize the blend weights to account for mid-level dataset oversampling done to ensure
+    fulfillment of the requested number of samples. Defaults to False for backward compatibility
+    in the data sample order.
+    """
+
     split: Optional[str] = None
     """The split string, a comma separated weighting for the dataset splits when drawing samples
     from a single distribution. Not to be used with 'blend_per_split'. Defaults to None.
@@ -64,8 +70,7 @@ class BlendedMegatronDatasetConfig:
     """The MegatronTokenizer instance or None. Required for datasets which do online tokenization."""
 
     def __post_init__(self) -> None:
-        """Do asserts and set fields post init
-        """
+        """Do asserts and set fields post init"""
         if self.blend_per_split is not None and any(self.blend_per_split):
             assert self.blend is None, "blend and blend_per_split are incompatible"
             assert self.split is None, "split and blend_per_split are incompatible"
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index b252723a55..188e9873a1 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1559,6 +1559,11 @@ def _add_data_args(parser):
                        '(3) a list of prefixes e.g. prefix1 prefix2. '
                        'For (3), weights are inferred from the lengths of the contributing datasets. '
                        'This argument is exclusive to the other independent --*-data-path arguments.')
+    group.add_argument('--renormalize-blend-weights', action='store_true',
+                       help='Renormalize the blend weights to account for the mid-level dataset '
+                            'oversampling done to ensure fulfillment of the requested number of '
+                            'samples. Use this option if prompted. Defaults to False for backward '
+                            'compatibility in the data sample order.')
     group.add_argument('--split', type=str, default=None,
                        help='Comma-separated list of proportions for training,'
                             ' validation, and test split. For example the split '
diff --git a/pretrain_bert.py b/pretrain_bert.py
index f5c553029c..35884ecdc4 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -154,6 +154,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         path_to_cache=args.data_cache_path,
         tokenizer=tokenizer,
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 949f1571c7..9658e0700f 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -195,6 +195,7 @@ def core_gpt_dataset_config_from_args(args):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         num_dataset_builder_threads=args.num_dataset_builder_threads,
         path_to_cache=args.data_cache_path,
diff --git a/pretrain_mamba.py b/pretrain_mamba.py
index f2dbb97e67..9132ce2c62 100644
--- a/pretrain_mamba.py
+++ b/pretrain_mamba.py
@@ -186,6 +186,7 @@ def core_gpt_dataset_config_from_args(args):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         num_dataset_builder_threads=args.num_dataset_builder_threads,
         path_to_cache=args.data_cache_path,
diff --git a/pretrain_retro.py b/pretrain_retro.py
index a0d8f9d922..0aecbf14ce 100644
--- a/pretrain_retro.py
+++ b/pretrain_retro.py
@@ -189,6 +189,7 @@ def train_valid_test_datasets_provider(train_valid_test_num_samples):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         split_preprocessing=retro_config.retro_split_preprocessing,
         path_to_cache=args.data_cache_path,
diff --git a/pretrain_t5.py b/pretrain_t5.py
index d3960cbd32..69cbc0d5f2 100644
--- a/pretrain_t5.py
+++ b/pretrain_t5.py
@@ -213,6 +213,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples: int):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=args.split,
         path_to_cache=args.data_cache_path,
         tokenizer=tokenizer,
diff --git a/tests/unit_tests/data/test_builder.py b/tests/unit_tests/data/test_builder.py
index 141c67b31d..8f149dcffb 100644
--- a/tests/unit_tests/data/test_builder.py
+++ b/tests/unit_tests/data/test_builder.py
@@ -110,7 +110,11 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
     config = BlendedMegatronDatasetConfig(
         random_seed=1234,
         sequence_length=_SEQUENCE_LENGTH,
-        blend_per_split=[blends[Split.train], None, None,],
+        blend_per_split=[
+            blends[Split.train],
+            None,
+            None,
+        ],
     )
     try:
         datasets = BlendedMegatronDatasetBuilder(
@@ -123,7 +127,11 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
     config = BlendedMegatronDatasetConfig(
         random_seed=1234,
         sequence_length=_SEQUENCE_LENGTH,
-        blend_per_split=[get_blend_from_list([paths[Split.train][0]]), None, None,],
+        blend_per_split=[
+            get_blend_from_list([paths[Split.train][0]]),
+            None,
+            None,
+        ],
     )
     datasets = BlendedMegatronDatasetBuilder(
         TestDataset, [1000, None, None], lambda: True, config
@@ -179,7 +187,11 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
     config = BlendedMegatronDatasetConfig(
         random_seed=1234,
         sequence_length=_SEQUENCE_LENGTH,
-        blend_per_split=[blends_unweighted[Split.train], None, None,],
+        blend_per_split=[
+            blends_unweighted[Split.train],
+            None,
+            None,
+        ],
     )
     datasets = BlendedMegatronDatasetBuilder(
         TestDataset, [1000, None, None], lambda: True, config
@@ -219,7 +231,25 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
     config = BlendedMegatronDatasetConfig(
         random_seed=1234,
         sequence_length=_SEQUENCE_LENGTH,
-        blend_per_split=[blends[Split.train], blends[Split.valid], blends[Split.test],],
+        blend_per_split=[blends[Split.train], None, None],
+        renormalize_blend_weights=True,
+    )
+    datasets = BlendedMegatronDatasetBuilder(
+        TestDataset, [1000, None, None], lambda: True, config
+    ).build()
+    assert (
+        len(datasets[0]) >= 1000
+        and len(datasets[0]) <= 1000 * (1 + _MARGIN) + _NUM_DATASETS
+    )
+
+    config = BlendedMegatronDatasetConfig(
+        random_seed=1234,
+        sequence_length=_SEQUENCE_LENGTH,
+        blend_per_split=[
+            blends[Split.train],
+            blends[Split.valid],
+            blends[Split.test],
+        ],
     )
     datasets = BlendedMegatronDatasetBuilder(
         TestDataset, [100, 100, 100], lambda: True, config
@@ -336,6 +366,26 @@ def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
     # W = S / sum(S)
     #
     ##
+    config = BlendedMegatronDatasetConfig(
+        random_seed=1234,
+        sequence_length=_SEQUENCE_LENGTH,
+        blend=blends[Split.train],
+        split="990,9,1",
+        renormalize_blend_weights=True,
+    )
+    datasets = BlendedMegatronDatasetBuilder(
+        TestDataset, [100000, 1000, 1], lambda: True, config
+    ).build()
+    assert (
+        len(datasets[0]) >= 100000
+        and len(datasets[0]) <= 100000 * (1 + _MARGIN) + _NUM_DATASETS
+    )
+    assert (
+        len(datasets[1]) >= 1000
+        and len(datasets[1]) <= 1000 * (1 + _MARGIN) + _NUM_DATASETS
+    )
+    assert len(datasets[2]) >= 1 and len(datasets[2]) <= 1 * (1 + _MARGIN) + _NUM_DATASETS
+
     config = BlendedMegatronDatasetConfig(
         random_seed=1234,
         sequence_length=_SEQUENCE_LENGTH,
diff --git a/tools/retro/preprocess_data.py b/tools/retro/preprocess_data.py
index dd36eb0667..444a64e584 100644
--- a/tools/retro/preprocess_data.py
+++ b/tools/retro/preprocess_data.py
@@ -110,6 +110,7 @@ def get_gpt_chunk_datasets(config):
             get_blend_from_list(args.valid_data_path),
             get_blend_from_list(args.test_data_path)
         ],
+        renormalize_blend_weights=args.renormalize_blend_weights,
         split=config.retro_gpt_split,
         split_preprocessing=config.retro_gpt_split,
         path_to_cache=config.retro_gpt_data_cache_path,
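
Note (not part of the patch): a minimal standalone sketch of the renormalization arithmetic that --renormalize-blend-weights enables. It mirrors the size_per_dataset / weights recomputation added in _build_blended_dataset_splits above; the sample counts below are illustrative, not taken from the repository.

# Sketch of the weight renormalization under renormalize_blend_weights=True.
# sizes_per_dataset holds, per mid-level dataset, the (possibly oversampled)
# number of samples requested for each split; the counts here are made up.
sizes_per_dataset = [[700, 8, 1], [300, 4, 1]]  # one row per dataset: [train, valid, test]

split_index = 0  # the train split
size_per_dataset = list(zip(*sizes_per_dataset))[split_index]  # (700, 300)
size = sum(size_per_dataset)  # 1000

# Re-derive the blend weights from the per-dataset counts, so the top-level
# blend stays consistent with what the mid-level datasets can actually serve.
weights = [s / size for s in size_per_dataset]
print(weights)  # [0.7, 0.3]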