Custom Types DataLoader #3008

Open
wants to merge 7 commits into base: main
6 changes: 4 additions & 2 deletions docs/source/concept_guides/internal_mechanism.md
@@ -28,7 +28,7 @@ Then, when calling [`~Accelerator.prepare`], the library:
- wraps your model(s) in the container adapted for the distributed setup,
- wraps your optimizer(s) in an [`~optimizer.AcceleratedOptimizer`],
- wraps your scheduler(s) in an [`~scheduler.AcceleratedScheduler`]
- creates a new version of your dataloader(s) in a [`~data_loader.DataLoaderShard`] or [`~data_loader.DataLoaderDispatcher`]
- creates a new version of your dataloader(s) in a [`~data_loader.DataLoaderShard`], [`~data_loader.DataLoaderDispatcher`], or [`~data_loader.CustomTypesDataLoader`]

While the model(s), optimizer(s), and scheduler(s) are just put in simple wrappers, the dataloader(s) are re-created. This is mostly
because PyTorch does not let the user change the `batch_sampler` of a dataloader once it's been created and the
@@ -42,7 +42,9 @@ The [`~data_loader.DataLoaderShard`] subclasses `DataLoader` to add the followin
- it puts the batches on the proper device before yielding them (unless you have opted out of
`device_placement=True`).

The [`~data_loader.DataLoaderDispatcher`] subclasses differs from the [`~data_loader.DataLoaderShard`] in that when iterating through the `DataLoader`, the data is all starting from process 0 and *then* split and sent off to each process rather than it happening at the dataset level.
The [`~data_loader.DataLoaderDispatcher`] subclass differs from the [`~data_loader.DataLoaderShard`] in that when iterating through the `DataLoader`, the data is all starting from process 0 and *then* split and sent off to each process rather than it happening at the dataset level.

The [`~data_loader.CustomTypesDataLoader`] subclass differs from the [`~data_loader.DataLoaderShard`] and [`~data_loader.DataLoaderDispatcher`] in that it wraps arbitrary iterables carrying custom logic, giving users more control over how the data is produced; the dataloader itself only invokes `__iter__()` on the wrapped object and moves the yielded data to the appropriate device.
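
For example, with `custom_types` enabled in the `DataLoaderConfiguration`, a custom iterable can be passed straight to [`~Accelerator.prepare`]. A minimal sketch (the `RecordStream` class here is a hypothetical iterable used only for illustration):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration


class RecordStream:
    # A hypothetical iterable yielding values that can be converted to tensors
    def __iter__(self):
        return iter(range(8))


dataloader_config = DataLoaderConfiguration(custom_types=True, custom_type_batch_size=4)
accelerator = Accelerator(dataloader_config=dataloader_config)

# `prepare` wraps the iterable in a `CustomTypesDataLoader` and moves each batch to `accelerator.device`
dataloader = accelerator.prepare(RecordStream())
for batch in dataloader:
    ...  # two batches: tensor([0, 1, 2, 3]) and tensor([4, 5, 6, 7])
```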

The random number generator synchronization will by default synchronize:

1 change: 1 addition & 0 deletions docs/source/package_reference/torch_wrappers.md
@@ -27,6 +27,7 @@ when calling [`~Accelerator.prepare`].
[[autodoc]] data_loader.IterableDatasetShard
[[autodoc]] data_loader.DataLoaderShard
[[autodoc]] data_loader.DataLoaderDispatcher
[[autodoc]] data_loader.CustomTypesDataLoader

## Optimizers

33 changes: 27 additions & 6 deletions src/accelerate/accelerator.py
@@ -24,6 +24,7 @@
import sys
import warnings
from collections import OrderedDict
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial
from types import MethodType
@@ -1182,13 +1183,20 @@ def print(self, *args, **kwargs):
def _prepare_one(self, obj, first_pass=False, device_placement=None):
# First pass of preparation: DataLoader, model, optimizer
if first_pass:
if isinstance(obj, torch.utils.data.DataLoader):
return self.prepare_data_loader(obj, device_placement=device_placement)
elif isinstance(obj, torch.nn.Module):
if isinstance(obj, torch.nn.Module):
return self.prepare_model(obj, device_placement=device_placement)
elif isinstance(obj, torch.optim.Optimizer):
optimizer = self.prepare_optimizer(obj, device_placement=device_placement)
return optimizer
elif isinstance(obj, torch.utils.data.DataLoader) or (
self.dataloader_config.custom_types and isinstance(obj, Iterable)
):
return self.prepare_data_loader(
obj,
device_placement=device_placement,
custom_types=self.dataloader_config.custom_types,
custom_type_batch_size=self.dataloader_config.custom_type_batch_size,
)
# Second pass of preparation: LR scheduler (which need the full list of optimizers)
elif isinstance(obj, LRScheduler):
scheduler = self.prepare_scheduler(obj)
@@ -1990,22 +1998,33 @@ def _prepare_msamp(self, *args):
return tuple(result)

def prepare_data_loader(
self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
self,
data_loader: Union[torch.utils.data.DataLoader, Iterable],
device_placement=None,
slice_fn_for_dispatch=None,
custom_types: bool = False,
custom_type_batch_size: int = None,
):
"""
Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.

Args:
data_loader (`torch.utils.data.DataLoader`):
A vanilla PyTorch DataLoader to prepare
data_loader (`Union[torch.utils.data.DataLoader, Iterable]`):
A vanilla PyTorch DataLoader to prepare, or a custom Iterable to be wrapped in a dataloader if
`custom_types=True`.
device_placement (`bool`, *optional*):
Whether or not to place the batches on the proper device in the prepared dataloader. Will default to
`self.device_placement`.
slice_fn_for_dispatch (`Callable`, *optional*`):
If passed, this function will be used to slice tensors across `num_processes`. Will default to
[`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will
be ignored otherwise.
custom_types (`bool`, *optional*, defaults to `False`):
Whether or not the `data_loader` argument is a custom iterable type, or a vanilla dataloader wrapping an
instance of a custom iterable type.
custom_type_batch_size (`int`, *optional*):
Batch size to be used if `custom_types=True`.

Example:

@@ -2038,6 +2057,8 @@ def prepare_data_loader(
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
non_blocking=self.non_blocking,
custom_types=custom_types,
custom_type_batch_size=custom_type_batch_size,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
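
# Usage sketch for the new arguments (assuming an `Accelerator` instance `accelerator` and any iterable
# `my_iterable` that yields values convertible to tensors):
#
#     dataloader = accelerator.prepare_data_loader(my_iterable, custom_types=True, custom_type_batch_size=4)
#     for batch in dataloader:
#         ...  # each batch is a tensor already placed on `accelerator.device`
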
77 changes: 76 additions & 1 deletion src/accelerate/data_loader.py
@@ -13,6 +13,7 @@
# limitations under the License.

import math
from collections.abc import Iterable
from contextlib import suppress
from typing import Callable, List, Optional, Union

@@ -564,6 +565,69 @@ def dataloader(self):
return self._loader


class CustomTypesDataLoader(DataLoader, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that can handle custom iterable types as long as they yield objects that can be
converted to PyTorch tensors.

Args:
data_or_loader (`Union[torch.utils.data.dataloader.DataLoader, Iterable]`):
An arbitrary iterable, or a `DataLoader` wrapping one, whose yielded data will be moved to the provided device.
batch_size (`int`, *optional*, defaults to `None`):
Batch size to be used in the created `DataLoader`. Note that if the object provided is already a
`DataLoader`, the `batch_size` attribute of that loader will be used.
device (`torch.device`):
The target device for the returned `DataLoader`.
_non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, `DataLoader` will utilize non-blocking host-to-device transfers. If the `DataLoader` has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.

Returns:
`torch.utils.data.dataloader.DataLoader`: A new data loader that will invoke `__iter__()` on the encapsulated
iterable and move yielded data to the provided device.
"""

def __init__(
self,
data_or_loader: Union[DataLoader, Iterable],
batch_size: Optional[int] = None,
device: Optional[torch.device] = None,
_non_blocking: bool = False,
**kwargs,
):
if isinstance(data_or_loader, DataLoader):
data = data_or_loader.dataset
if batch_size is not None and batch_size != data_or_loader.batch_size:
raise ValueError(
"Provided custom types batch size conflicts with the batch size of wrapped DataLoader"
)
batch_size = data_or_loader.batch_size
else:
if batch_size is None:
raise ValueError("`custom_types` enabled, but `custom_type_batch_size` is None")
data = data_or_loader
self.device = device
self._non_blocking = _non_blocking
super().__init__(self._build_iterable_dataset(data), batch_size=batch_size)

def _build_iterable_dataset(self, iter_type):
# If it's already an iterable dataset, we don't need to do anything
if isinstance(iter_type, IterableDataset):
return iter_type
# If it isn't, we create a thin wrapper to make it into one

class WrappedIterable(IterableDataset):
def __iter__(self):
return iter(iter_type)

return WrappedIterable()

def __iter__(self):
# Iterate through the data; if the device is configured, move the data to it
for batch in super().__iter__():
yield send_to_device(batch, self.device, non_blocking=self._non_blocking)
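
# Usage sketch (assuming `my_iterable` is any iterable yielding values that can be converted to tensors):
#
#     loader = CustomTypesDataLoader(my_iterable, batch_size=4, device=torch.device("cuda"))
#     for batch in loader:
#         ...  # each batch is a tensor already on the requested device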


class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
@@ -812,6 +876,8 @@ def prepare_data_loader(
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
non_blocking: bool = False,
custom_types: bool = False,
custom_type_batch_size: Optional[int] = None,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -873,7 +939,11 @@ def prepare_data_loader(
non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.

custom_types (`bool`, *optional*, defaults to `False`):
If set to `True`, the dataloader will accept custom iterable types that yield values which can be
converted to torch tensors. If `True`, the values of `dispatch_batches` and `split_batches` will be
ignored.
custom_type_batch_size (`int`, *optional*):
Batch size to be used if `custom_types=True`, for custom iterables not already wrapped by a dataloader.

Returns:
`torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
Expand All @@ -885,6 +955,11 @@ def prepare_data_loader(

</Tip>
"""
if custom_types:
return CustomTypesDataLoader(
dataloader, batch_size=custom_type_batch_size, non_blocking=non_blocking, device=device
)

if dispatch_batches is None:
if not put_on_device:
dispatch_batches = False
14 changes: 14 additions & 0 deletions src/accelerate/utils/dataclasses.py
Expand Up @@ -720,6 +720,20 @@ class DataLoaderConfiguration:
" prepared dataloader has `pin_memory` set to `True` to work properly."
},
)
custom_types: bool = field(
default=False,
metadata={
"help": "If set to `True`, the data prepared by the Accelerator may wrap custom iterables, as long as it"
" yields types that can be converted into torch tensors. If `True`, the values of `split_batches` and"
" `dispatch_batches` will not be used."
},
)
custom_type_batch_size: int = field(
default=None,
metadata={
"help": "Value to be used for the batch size of wrapped custom iterables. Only used if `custom_types` is `True`."
},
)
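
# Usage sketch for the two new fields (assuming a custom iterable is later passed to `Accelerator.prepare`):
#
#     dataloader_config = DataLoaderConfiguration(custom_types=True, custom_type_batch_size=4)
#     accelerator = Accelerator(dataloader_config=dataloader_config)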


@dataclass
85 changes: 85 additions & 0 deletions tests/test_data_loader.py
@@ -15,18 +15,23 @@
import random
import unittest

import pytest
from parameterized import parameterized
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

from accelerate import Accelerator
from accelerate.data_loader import (
BatchSamplerShard,
CustomTypesDataLoader,
DataLoaderDispatcher,
DataLoaderShard,
IterableDatasetShard,
SkipBatchSampler,
SkipDataLoader,
skip_first_batches,
)
from accelerate.test_utils.testing import require_cuda
from accelerate.utils import DataLoaderConfiguration


class RandomIterableDataset(IterableDataset):
@@ -396,3 +401,83 @@ def test_end_of_dataloader_dispatcher(self):
# Test it also works on the second iteration
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)

@staticmethod
def _get_custom_iterable(data):
class MyCustomType:
def __init__(self):
self.data = data

def __iter__(self):
return iter(self.data)

return MyCustomType()

@staticmethod
def check_custom_types_iterable(dataloader, expected_batches, device=None):
assert isinstance(dataloader, CustomTypesDataLoader)
assert len(expected_batches) == len(list(dataloader))
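# Check that the custom dataloader can be iterated more than once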
for _ in range(2):
for batch, expected_batch in zip(dataloader, expected_batches):
# And that each time we get the expected tensor on the device we specified
assert batch.tolist() == expected_batch
if device is not None:
assert batch.device.type == device

@parameterized.expand(
[
("nested under dataloader wrapper", True),
("without nested dataloader wrapper", False),
]
)
@require_cuda
def test_custom_types_dataloader(self, _, wrap_with_dataloader):
device = "cuda"
custom_iterable = self._get_custom_iterable(data=list(range(8)))
if wrap_with_dataloader:
custom_iterable = DataLoader(custom_iterable, batch_size=4)
kwargs = {}
else:
kwargs = {"batch_size": 4}
dataloader = CustomTypesDataLoader(custom_iterable, device=device, **kwargs)
expected_batches = [[0, 1, 2, 3], [4, 5, 6, 7]]
self.check_custom_types_iterable(dataloader, expected_batches, device)

@parameterized.expand(
[
("nested under dataloader wrapper", True),
("without nested dataloader wrapper", False),
]
)
def test_custom_types_via_prepare(self, _, wrap_with_dataloader):
batch_size = 4
dataloader_config = DataLoaderConfiguration(custom_types=True)
custom_iterable = self._get_custom_iterable(data=list(range(8)))
if wrap_with_dataloader:
# If it's a data loader, we pull the batch size off the dataloader
custom_iterable = DataLoader(custom_iterable, batch_size=batch_size)
else:
# Otherwise we need to specify it through the dataloader config
dataloader_config.custom_type_batch_size = batch_size
accelerator = Accelerator(dataloader_config=dataloader_config)
dataloader = accelerator.prepare(custom_iterable)
expected_batches = [[0, 1, 2, 3], [4, 5, 6, 7]]
self.check_custom_types_iterable(dataloader, expected_batches)

def test_prepare_custom_types_dataloader_is_idempotent(self):
accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(custom_types=True))
custom_iterable = DataLoader(self._get_custom_iterable(data=list(range(8))), batch_size=4)
dataloader = CustomTypesDataLoader(custom_iterable)
prepared_dataloader = accelerator.prepare(dataloader)
assert isinstance(prepared_dataloader, CustomTypesDataLoader)
assert dataloader.dataset == prepared_dataloader.dataset

def test_prepare_custom_types_dataloader_conflicting_batch_sizes(self):
# Ensure we can't pass a batch size for custom types and a wrapped
# dataloader unless the batch sizes are the same value
accelerator = Accelerator(
dataloader_config=DataLoaderConfiguration(custom_types=True, custom_type_batch_size=2)
)
dataloader = DataLoader(self._get_custom_iterable(data=list(range(8))), batch_size=4)
with pytest.raises(ValueError):
accelerator.prepare(dataloader)