
Iterating a streaming dataset correctly in random order #8015

@unrealwill

Description

Describe the bug

For training to work properly, it is considered good practice in machine learning to sample the dataset uniformly at random. I am not sure how a user is expected to do this correctly when the creator of the dataset did a bad job.

https://huggingface.co/docs/hub/datasets-webdataset

> **Shuffle**
>
> Generally, datasets in WebDataset formats are already shuffled and ready to feed to a DataLoader. But you can still reshuffle the data with WebDataset's approximate shuffling.
>
> In addition to shuffling the list of shards, WebDataset uses a buffer to shuffle a dataset without any cost to speed:
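The "approximate shuffling" mentioned above boils down to a fixed-size buffer. A minimal sketch (my own simplification, not the actual webdataset implementation) looks like this:

```python
import random

def buffer_shuffle(stream, buffer_size, seed=0):
    """Minimal sketch of buffer-based approximate shuffling (not the real
    webdataset code): keep a fixed-size buffer and, each time a new item
    arrives, yield a random element from the buffer. An element can never
    appear more than ~buffer_size positions earlier than its input
    position, so items that are far apart in the stream never mix."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    rng.shuffle(buffer)  # drain whatever is left at the end of the stream
    yield from buffer
```

The key property is the bounded mixing radius: an example can be delayed arbitrarily long inside the buffer, but it can never be emitted more than about `buffer_size` positions earlier than it arrived.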

Let me expose the general problem with this specific dataset.
https://huggingface.co/datasets/jackyhate/text-to-image-2M

I'll only consider the single node, single worker case to highlight the problem, which is somewhat hidden and attenuated when training using multiple machines.

Here is the recommended way to load it, as shown on the dataset's main page:

```python
# copy pasted from https://huggingface.co/datasets/jackyhate/text-to-image-2M
from datasets import load_dataset

num_shards = 46  # Number of webdataset tar files
urls = [base_url.format(i=i) for i in range(num_shards)]  # base_url as defined on the dataset page
dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)

# Example of iterating through the dataset
for image in dataset:
    print(image)  # single image in row with associated columns
    break
```

The dataset is composed of 46 shards from multiple sources: some shards have data generated with Flux, and others have data generated with DALL-E. There are usually 50,000 images per file, except for data_000034.tar and data_000046.tar, which have fewer files because they are the end files of some version of the dataset.

Initially I thought you were using the webdataset library, which has at least tried to address the problem of shuffling the dataset (although it is poorly documented and has bad default behavior; see https://github.com/webdataset/webdataset/blob/e0953f9bba17b416d5792d5a263b171c266e78be/src/webdataset/mix.py ).

But it seems that you are just iterating the shards in order and then iterating the examples of each shard in order.

```python
def _generate_examples(self, tar_paths, tar_iterators):
    image_field_names = [
        field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Image)
    ]
    audio_field_names = [
        field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio)
    ]
    all_field_names = list(self.info.features.keys())
    for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)):
        for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)):
            for field_name in all_field_names:
                if field_name not in example:
                    example[field_name] = None
            for field_name in image_field_names + audio_field_names:
                if example[field_name] is not None:
                    example[field_name] = {
                        "path": example["__key__"] + "." + field_name,
                        "bytes": example[field_name],
                    }
            yield Key(tar_idx, example_idx), example
```

If I understand correctly, the recommended way to shuffle the data, as described in https://huggingface.co/docs/datasets/loading, is:

```python
ds = ds.shuffle(seed=42, buffer_size=10_000)  # shuffles the shard order + uses a shuffle buffer
```

or to shuffle with a buffer in the DataLoader.

This is problematic because the dataset was not shuffled by its creator, which means that samples are highly correlated. For example, samples coming from the Flux shards have much more detail than those coming from the DALL-E shards, and they are not mixed by the shuffle buffer, which results in the loss function showing a periodic pattern per epoch, typical of bad shuffling behavior.
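To make the scale of the problem concrete, here is a back-of-the-envelope calculation using the shard and buffer sizes from this issue. Because a shuffle buffer only delays items, an output position p can only contain examples from input positions at most p + buffer_size, which bounds how much two adjacent shards can mix:

```python
# A shuffle buffer only delays items: output position p can contain
# examples from input positions <= p + buffer_size, and nothing later.
# With 50_000-example shards and a 10_000-element buffer, every output
# position p < shard_size - buffer_size is therefore guaranteed to come
# from the first shard alone.
shard_size = 50_000
buffer_size = 10_000
guaranteed_unmixed = (shard_size - buffer_size) / shard_size
print(guaranteed_unmixed)  # 0.8
```

So at least 80% of the first shard's stretch of the stream contains no samples at all from the following shard, no matter the seed.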

Of course, we could just reshuffle the dataset beforehand, but these datasets are huge and we would need to do it for every version of the dataset.

The random mixing and round-robin mixing that can be used in the webdataset library can also be problematic when the shards have various sizes: if we stop at the exhaustion of the first shard, we never visit the examples in the later portions of the other shards. Also, if the shards' source datasets are not uniformly balanced, there are periods of time where a shard is not sampled; for example, once the smallest shard has been exhausted, samples from its class won't be seen until the next epoch.
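The first failure mode can be demonstrated with a toy round-robin mixer that stops as soon as any source is exhausted (a simplified stand-in for one of the strategies in webdataset's mix module, not its actual code):

```python
def round_robin_until_exhausted(*iterables):
    # Round-robin mixing that stops as soon as any source runs out.
    iterators = [iter(it) for it in iterables]
    while True:
        for it in iterators:
            try:
                yield next(it)
            except StopIteration:
                return  # the first empty shard ends the "epoch"

small_shard = list(range(10))        # 10 examples
large_shard = list(range(100, 200))  # 100 examples
seen = list(round_robin_until_exhausted(small_shard, large_shard))
# Only the first 10 of the large shard's 100 examples are ever visited:
print(sum(1 for x in seen if x >= 100))  # 10
```

The last 90 examples of the large shard are simply never seen in that epoch.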

The proper way is to sample from the shards so that they are all exhausted at roughly the same time, i.e. to use a smaller sampling probability for smaller shards. But that requires knowing the sizes of the shards beforehand, and if you also do per-shard shuffling like the webdataset library does, you need to open all shards simultaneously, which requires a lot of memory: nshards * buffer_size.
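A sketch of that size-proportional strategy (a hypothetical helper of my own, assuming shard sizes are known up front):

```python
import random

def sample_proportional_to_remaining(shards, seed=0):
    """Hypothetical sketch: draw each next example from a shard chosen
    with probability proportional to its remaining size, so all shards
    are exhausted at roughly the same time. Requires knowing the shard
    sizes beforehand."""
    rng = random.Random(seed)
    iterators = [iter(shard) for shard in shards]
    remaining = [len(shard) for shard in shards]
    while sum(remaining) > 0:
        # weighted choice over the remaining counts
        i = rng.choices(range(len(shards)), weights=remaining)[0]
        remaining[i] -= 1
        yield next(iterators[i])

shards = [list(range(0, 40)), list(range(100, 110))]  # 40 vs 10 examples
mixed = list(sample_proportional_to_remaining(shards, seed=42))
```

Every example is visited exactly once per epoch, and the smaller shard's examples stay spread across the whole epoch instead of disappearing early. Adding a per-shard shuffle buffer on top of each iterator is what drives the nshards * buffer_size memory cost.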

Hopefully I have been clear enough. This problem occurs with webdataset and probably also with every other streamable format you handle; it would be good to make this work in a standardized, robust and bug-free way, as the bugs it generates are usually of the silent kind: training appears to work better, but the model generalizes less well because training exploited correlations in the ordering of the dataset.

Steps to reproduce the bug

Observe the periodic behavior of the loss function when training a simple neural network on the dataset.

Expected behavior

A behavior somewhat similar to a standard uniform shuffle

Environment info

not relevant
