Describe the bug
For training to work properly, it is considered good practice in machine learning to sample the dataset uniformly at random. I am not sure how the user is expected to do this correctly when the creator of the dataset did a bad job.
https://huggingface.co/docs/hub/datasets-webdataset
> **Shuffle**
>
> Generally, datasets in WebDataset formats are already shuffled and ready to feed to a DataLoader. But you can still reshuffle the data with WebDataset's approximate shuffling. In addition to shuffling the list of shards, WebDataset uses a buffer to shuffle a dataset without any cost to speed:
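For readers unfamiliar with the trick, a buffer-based approximate shuffle works roughly like the sketch below (my own toy code, not the actual `datasets`/`webdataset` implementation):

```python
import random

def buffered_shuffle(iterable, buffer_size, rng=None):
    """Approximate shuffle: fill a fixed-size buffer, then for each new
    item swap it with a random buffer slot and emit the evicted item."""
    rng = rng or random.Random(0)
    buf = []
    for item in iterable:
        if len(buf) < buffer_size:
            buf.append(item)
        else:
            i = rng.randrange(buffer_size)
            buf[i], item = item, buf[i]
            yield item
    # drain the buffer at the end of the stream
    rng.shuffle(buf)
    yield from buf

out = list(buffered_shuffle(range(20), 5))  # a permutation of 0..19
```

The key property is that an item can only move roughly `buffer_size` positions away from where it entered the stream, which is what makes the shuffle "approximate".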
Let me expose the general problem with this specific dataset.
https://huggingface.co/datasets/jackyhate/text-to-image-2M
I'll only consider the single-node, single-worker case to highlight the problem, which is somewhat hidden and attenuated when training on multiple machines.
Here is the recommended way to load it, taken from the dataset's main page:

```python
# copy pasted from https://huggingface.co/datasets/jackyhate/text-to-image-2M
from datasets import load_dataset

# base_url is defined on the dataset page
num_shards = 46  # Number of webdataset tar files
urls = [base_url.format(i=i) for i in range(num_shards)]
dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)

# Example of iterating through the dataset
for image in dataset:
    print(image)  # single image in row with associated columns
    break
```
The dataset is composed of 46 shards from multiple sources: some shards contain data generated with Flux, others contain data generated with DALL-E. There are usually 50,000 images per file, except for data_000034.tar and data_000046.tar, which have fewer files because they are the end files of some version of the dataset.
Initially I thought you were using the webdataset library, which has at least tried to think about the problem of shuffling the dataset (although it is poorly documented and has bad default behavior: see https://github.com/webdataset/webdataset/blob/e0953f9bba17b416d5792d5a263b171c266e78be/src/webdataset/mix.py).
But it seems that you are just iterating the shards in order and then iterating the examples of each shard in order.
datasets/src/datasets/packaged_modules/webdataset/webdataset.py
Lines 111 to 130 in fe7353a
```python
def _generate_examples(self, tar_paths, tar_iterators):
    image_field_names = [
        field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Image)
    ]
    audio_field_names = [
        field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio)
    ]
    all_field_names = list(self.info.features.keys())
    for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)):
        for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)):
            for field_name in all_field_names:
                if field_name not in example:
                    example[field_name] = None
            for field_name in image_field_names + audio_field_names:
                if example[field_name] is not None:
                    example[field_name] = {
                        "path": example["__key__"] + "." + field_name,
                        "bytes": example[field_name],
                    }
            yield Key(tar_idx, example_idx), example
```
If I understand correctly, the recommended way to shuffle the data, as described in https://huggingface.co/docs/datasets/loading, is:

```python
ds = ds.shuffle(seed=42, buffer_size=10_000)  # shuffles the shard order + uses a shuffle buffer
```

or to shuffle with a buffer in the DataLoader.
This is problematic because the dataset was not shuffled by its creator, which means samples are highly correlated: for example, samples coming from the Flux shards have a lot more detail than those coming from the DALL-E shards. Since the two are never mixed in the shuffle buffer, the loss function shows a periodic pattern per epoch, typical of bad shuffling behavior.
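To make the failure concrete, here is a small simulation (scaled down to 4 shards, with a toy re-implementation of a shuffle buffer, not the library's actual code): with 50,000-example shards streamed in order through a 10,000-item buffer, the first 40,000 emitted samples can only ever come from shard 0.

```python
import random

SHARD_SIZE = 50_000
NUM_SHARDS = 4          # scaled down from 46 for the demo
BUFFER_SIZE = 10_000

def stream():
    # label each sample with the index of the shard it came from
    for shard in range(NUM_SHARDS):
        for _ in range(SHARD_SIZE):
            yield shard

def buffered_shuffle(it, buffer_size, rng):
    # toy shuffle buffer: swap each incoming item with a random slot
    buf = []
    for item in it:
        if len(buf) < buffer_size:
            buf.append(item)
        else:
            i = rng.randrange(buffer_size)
            buf[i], item = item, buf[i]
            yield item
    rng.shuffle(buf)
    yield from buf

out = list(buffered_shuffle(stream(), BUFFER_SIZE, random.Random(42)))
# the first SHARD_SIZE - BUFFER_SIZE outputs can only come from shard 0
print(set(out[: SHARD_SIZE - BUFFER_SIZE]))  # {0}
```

No choice of seed can help: the buffer simply never holds samples from two sources that are more than `buffer_size` positions apart in the stream.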
Of course we could just reshuffle the datasets beforehand, but they are huge, and we would need to do it for every version of the dataset.
The random mixing and round-robin mixing available in the webdataset library can also be problematic when the shards have various sizes: if we stop when the first shard is exhausted, we never visit the examples in the later portions of the other shards. And if the shards' source datasets are not balanced, there are periods where a source is not sampled at all: once the smallest shard has been exhausted, samples from its class won't be seen until the next epoch.
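The second failure mode is easy to see with a toy round-robin mixer (in the spirit of the classic itertools roundrobin recipe, not webdataset's actual implementation): once the small shard runs out, the tail of the stream contains only the big shard's class.

```python
def round_robin(*iterables):
    """Cycle through shards, dropping each one as it is exhausted."""
    iters = [iter(it) for it in iterables]
    while iters:
        alive = []
        for it in iters:
            try:
                yield next(it)
                alive.append(it)
            except StopIteration:
                pass
        iters = alive

big = ["big"] * 6
small = ["small"] * 2
out = list(round_robin(big, small))
# after the small shard runs out, only "big" samples remain
print(out)  # ['big', 'small', 'big', 'small', 'big', 'big', 'big', 'big']
```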
The proper way is to sample from the shards so that they are all exhausted at roughly the same time, i.e. to use a smaller sampling probability for the smaller shards. But that requires knowing the shard sizes beforehand, and if you also do per-shard shuffling like the webdataset library does, all shards must be open simultaneously, which requires nshards * buffer_size memory.
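A sketch of that size-proportional sampling (my own toy code, not an existing API): pick the next shard with probability proportional to how many examples it has left, so all shards deplete together on average.

```python
import random

def sample_proportional(shard_iters, shard_sizes, rng=None):
    """Draw the next example from a shard chosen with probability
    proportional to its remaining example count."""
    rng = rng or random.Random(0)
    iters = list(shard_iters)
    remaining = list(shard_sizes)
    while any(remaining):
        idx = rng.choices(range(len(iters)), weights=remaining)[0]
        yield next(iters[idx])
        remaining[idx] -= 1

# two shards of different sizes; every example is visited exactly once,
# and the small shard keeps being sampled until the very end of the epoch
mixed = list(sample_proportional(
    [iter(range(100)), iter(range(100, 150))], [100, 50], random.Random(0)
))
```

As noted above, this only works when the shard sizes are known up front.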
Hopefully I have been clear enough. This problem occurs with webdataset and probably with every other streamable format you handle; it would be great if shuffling could be made to work in a standardized, robust and bug-free way. The bugs it generates are usually of the silent kind: training appears to work better, but the model generalizes worse because training exploited correlations in the ordering of the dataset.
Steps to reproduce the bug
Observe the periodic behavior of the loss function when training a simple neural network on the dataset.
Expected behavior
A behavior somewhat similar to a standard uniform shuffle
Environment info
not relevant