diff --git a/docs/source/concept_guides/internal_mechanism.md b/docs/source/concept_guides/internal_mechanism.md index e0b715dfa63..e67dfba8ee5 100644 --- a/docs/source/concept_guides/internal_mechanism.md +++ b/docs/source/concept_guides/internal_mechanism.md @@ -28,7 +28,7 @@ Then, when calling [`~Accelerator.prepare`], the library: - wraps your model(s) in the container adapted for the distributed setup, - wraps your optimizer(s) in an [`~optimizer.AcceleratedOptimizer`], - wraps your scheduler(s) in an [`~scheduler.AcceleratedScheduler`] -- creates a new version of your dataloader(s) in a [`~data_loader.DataLoaderShard`] or [`~data_loader.DataLoaderDispatcher`] +- creates a new version of your dataloader(s) in a [`~data_loader.DataLoaderShard`], [`~data_loader.DataLoaderDispatcher`], or [`~data_loader.CustomTypesDataLoader`] While the model(s), optimizer(s), and scheduler(s) are just put in simple wrappers, the dataloader(s) are re-created. This is mostly because PyTorch does not let the user change the `batch_sampler` of a dataloader once it's been created and the @@ -42,7 +42,9 @@ The [`~data_loader.DataLoaderShard`] subclasses `DataLoader` to add the followin - it puts the batches on the proper device before yielding them (unless you have opted out of `device_placement=True`). -The [`~data_loader.DataLoaderDispatcher`] subclasses differs from the [`~data_loader.DataLoaderShard`] in that when iterating through the `DataLoader`, the data is all starting from process 0 and *then* split and sent off to each process rather than it happening at the dataset level. +The [`~data_loader.DataLoaderDispatcher`] subclass differs from the [`~data_loader.DataLoaderShard`] in that when iterating through the `DataLoader`, the data is all starting from process 0 and *then* split and sent off to each process rather than it happening at the dataset level. + +The [`~data_loader.CustomTypesDataLoader`] subclass differs from the [`~data_loader.DataLoaderShard`] and [`~data_loader.DataLoaderDispatcher`] in that it can be used to wrap iterables with custom logic to give users more control over how the dataloader should manipulate the data; the dataloader itself only invokes `__iter__()` and moves the data to the appropriate device. The random number generator synchronization will by default synchronize: diff --git a/docs/source/package_reference/torch_wrappers.md b/docs/source/package_reference/torch_wrappers.md index 17350e3441f..9d3cc251c2c 100644 --- a/docs/source/package_reference/torch_wrappers.md +++ b/docs/source/package_reference/torch_wrappers.md @@ -27,6 +27,7 @@ when calling [`~Accelerator.prepare`]. 
[[autodoc]] data_loader.IterableDatasetShard [[autodoc]] data_loader.DataLoaderShard [[autodoc]] data_loader.DataLoaderDispatcher +[[autodoc]] data_loader.CustomTypesDataLoader ## Optimizers diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 902b9f4dbc7..52e6e534e06 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -24,6 +24,7 @@ import sys import warnings from collections import OrderedDict +from collections.abc import Iterable from contextlib import contextmanager from functools import partial from types import MethodType @@ -1182,13 +1183,20 @@ def print(self, *args, **kwargs): def _prepare_one(self, obj, first_pass=False, device_placement=None): # First pass of preparation: DataLoader, model, optimizer if first_pass: - if isinstance(obj, torch.utils.data.DataLoader): - return self.prepare_data_loader(obj, device_placement=device_placement) - elif isinstance(obj, torch.nn.Module): + if isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) elif isinstance(obj, torch.optim.Optimizer): optimizer = self.prepare_optimizer(obj, device_placement=device_placement) return optimizer + elif isinstance(obj, torch.utils.data.DataLoader) or ( + self.dataloader_config.custom_types and isinstance(obj, Iterable) + ): + return self.prepare_data_loader( + obj, + device_placement=device_placement, + custom_types=self.dataloader_config.custom_types, + custom_type_batch_size=self.dataloader_config.custom_type_batch_size, + ) # Second pass of preparation: LR scheduler (which need the full list of optimizers) elif isinstance(obj, LRScheduler): scheduler = self.prepare_scheduler(obj) @@ -1990,15 +1998,21 @@ def _prepare_msamp(self, *args): return tuple(result) def prepare_data_loader( - self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None + self, + data_loader: Union[torch.utils.data.DataLoader, Iterable], + device_placement=None, + slice_fn_for_dispatch=None, + custom_types: bool = False, + custom_type_batch_size: int = None, ): """ Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use [`Accelerator.prepare`] instead. Args: - data_loader (`torch.utils.data.DataLoader`): - A vanilla PyTorch DataLoader to prepare + data_loader (`Union[torch.utils.data.DataLoader, Iterable]`): + A vanilla PyTorch DataLoader to prepare, or a custom Iterable to be wrapped in a dataloader if + `custom_types=True`. device_placement (`bool`, *optional*): Whether or not to place the batches on the proper device in the prepared dataloader. Will default to `self.device_placement`. @@ -2006,6 +2020,11 @@ def prepare_data_loader( If passed, this function will be used to slice tensors across `num_processes`. Will default to [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be ignored otherwise. + custom_types (`bool`, *optional*, defaults to `False`): + Whether or not the `data_loader` argument is a custom iterable type, or a vanilla dataloader wrapped around an instance + of a custom iterable type. + custom_type_batch_size (`int`, *optional*): + Batch size to be used if `custom_types=True`. 
Example: @@ -2038,6 +2057,8 @@ def prepare_data_loader( slice_fn_for_dispatch=slice_fn_for_dispatch, use_seedable_sampler=self.use_seedable_sampler, non_blocking=self.non_blocking, + custom_types=custom_types, + custom_type_batch_size=custom_type_batch_size, ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index f0e88c645ec..b5cbff43c96 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from collections.abc import Iterable from contextlib import suppress from typing import Callable, List, Optional, Union @@ -564,6 +565,69 @@ def dataloader(self): return self._loader +class CustomTypesDataLoader(DataLoader, DataLoaderStateMixin): + """ + Subclass of a PyTorch `DataLoader` that can handle custom iterable types as long as they yield objects that can be + converted to PyTorch tensors. + + Args: + data_or_loader (`Union[torch.utils.data.dataloader.DataLoader, Iterable]`): + The data or `DataLoader` wrapping an arbitrary iterable to be moved to the provided device. + batch_size (`int`, *optional*, defaults to `None`): + Batch size to be used in the created `DataLoader`. Note that if the object provided is already a + `DataLoader`, the `batch_size` attribute of that loader will be used. + device (`torch.device`): + The target device for the returned `DataLoader`. + _non_blocking (`bool`, *optional*, defaults to `False`): + If set to `True`, `DataLoader` will utilize non-blocking host-to-device transfers. If the `DataLoader` has + `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. + + Returns: + `torch.utils.data.dataloader.DataLoader`: A new data loader that will invoke `__iter__()` on the encapsulated + iterable and move yielded data to the provided device. 
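+
+    Example (a minimal illustrative sketch; `MyIterable` below stands in for any user-defined iterable that
+    yields values convertible to tensors):
+
+    ```python
+    >>> from accelerate.data_loader import CustomTypesDataLoader
+
+    >>> class MyIterable:
+    ...     def __iter__(self):
+    ...         return iter(range(8))
+
+    >>> dataloader = CustomTypesDataLoader(MyIterable(), batch_size=4, device="cpu")
+    >>> [batch.tolist() for batch in dataloader]
+    [[0, 1, 2, 3], [4, 5, 6, 7]]
+    ```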
+ """ + + def __init__( + self, + data_or_loader: Union[DataLoader, Iterable], + batch_size: Optional[int] = None, + device: Optional[torch.device] = None, + _non_blocking: bool = False, + **kwargs, + ): + if isinstance(data_or_loader, DataLoader): + data = data_or_loader.dataset + if batch_size is not None and batch_size != data_or_loader.batch_size: + raise ValueError( + "Provided custom types batch size conflicts with the batch size of wrapped DataLoader" + ) + batch_size = data_or_loader.batch_size + else: + if batch_size is None: + raise ValueError("`custom_types` enabled, but `custom_type_batch_size` is None") + data = data_or_loader + self.device = device + self._non_blocking = _non_blocking + super().__init__(self._build_iterable_dataset(data), batch_size=batch_size) + + def _build_iterable_dataset(self, iter_type): + # If it's already an iterable dataset, we can don't need to do anything + if isinstance(iter_type, IterableDataset): + return iter_type + # If it isn't, we create a thin wrapper to make it into one + + class WrappedIterable(IterableDataset): + def __iter__(self): + return iter(iter_type) + + return WrappedIterable() + + def __iter__(self): + # Iterate through the data; if the device is configured, move the data to it + for batch in super().__iter__(): + yield (send_to_device(batch, self.device, non_blocking=self._non_blocking)) + + class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): """ Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each @@ -812,6 +876,8 @@ def prepare_data_loader( slice_fn_for_dispatch: Optional[Callable] = None, use_seedable_sampler: bool = False, non_blocking: bool = False, + custom_types: bool = False, + custom_type_batch_size: Optional[int] = None, ) -> DataLoader: """ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. @@ -873,7 +939,11 @@ def prepare_data_loader( non_blocking (`bool`, *optional*, defaults to `False`): If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. - + custom_types (`bool`, *optional*, defaults to `False`): + If set to `True`, dataloader will accept custom_types that yield values that can be converted to Torch + tensors. If `True`, the value of `dispatch_batches` and `split_batches` will be ignored. + custom_type_batch_size (`int`, *optional*): + Batch size to be used if custom_types=`True` for custom iterables not already wrapped by a dataloader. Returns: `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches @@ -885,6 +955,11 @@ def prepare_data_loader( """ + if custom_types: + return CustomTypesDataLoader( + dataloader, batch_size=custom_type_batch_size, non_blocking=non_blocking, device=device + ) + if dispatch_batches is None: if not put_on_device: dispatch_batches = False diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index cf41bc76b62..2d73d8c6e89 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -720,6 +720,20 @@ class DataLoaderConfiguration: " prepared dataloader has `pin_memory` set to `True` to work properly." }, ) + custom_types: bool = field( + default=False, + metadata={ + "help": "If set to `True`, the data prepared by the Accelerator may wrap custom iterables, as long as it" + " yields types that can be converted into torch tensors. 
If `True`, the values of `split_batches` and" + " `dispatch_batches` will not be used." + }, + ) + custom_type_batch_size: int = field( + default=None, + metadata={ + "help": "Value to be used for the batch size of wrapped custom iterables. Only used if `custom_types` is `True`." + }, + ) @dataclass diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 2f360d71bcb..b36224cae33 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -15,11 +15,14 @@ import random import unittest +import pytest +from parameterized import parameterized from torch.utils.data import BatchSampler, DataLoader, IterableDataset from accelerate import Accelerator from accelerate.data_loader import ( BatchSamplerShard, + CustomTypesDataLoader, DataLoaderDispatcher, DataLoaderShard, IterableDatasetShard, @@ -27,6 +30,8 @@ SkipDataLoader, skip_first_batches, ) +from accelerate.test_utils.testing import require_cuda +from accelerate.utils import DataLoaderConfiguration class RandomIterableDataset(IterableDataset): @@ -396,3 +401,83 @@ def test_end_of_dataloader_dispatcher(self): # Test it also works on the second iteration for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) + + @staticmethod + def _get_custom_iterable(data): + class MyCustomType: + def __init__(self): + self.data = data + + def __iter__(self): + return iter(self.data) + + return MyCustomType() + + @staticmethod + def check_custom_types_iterable(dataloader, expected_batches, device=None): + assert isinstance(dataloader, CustomTypesDataLoader) + assert len(expected_batches) == len(list(dataloader)) + for _ in range(2): + for batch, expected_batch in zip(dataloader, expected_batches): + # And that each time we get the expected tensor on the device we specified + assert batch.tolist() == expected_batch + if device is not None: + assert batch.device.type == device + + @parameterized.expand( + [ + ("nested under dataloader wrapper", True), + ("without nested dataloader wrapper", False), + ] + ) + @require_cuda + def test_custom_types_dataloader(self, _, wrap_with_dataloader): + device = "cuda" + custom_iterable = self._get_custom_iterable(data=list(range(8))) + if wrap_with_dataloader: + custom_iterable = DataLoader(custom_iterable, batch_size=4) + kwargs = {} + else: + kwargs = {"batch_size": 4} + dataloader = CustomTypesDataLoader(custom_iterable, device=device, **kwargs) + expected_batches = [[0, 1, 2, 3], [4, 5, 6, 7]] + self.check_custom_types_iterable(dataloader, expected_batches, device) + + @parameterized.expand( + [ + ("nested under dataloader wrapper", True), + ("without nested dataloader wrapper", False), + ] + ) + def test_custom_types_via_prepare(self, _, wrap_with_dataloader): + batch_size = 4 + dataloader_config = DataLoaderConfiguration(custom_types=True) + custom_iterable = self._get_custom_iterable(data=list(range(8))) + if wrap_with_dataloader: + # If it's a data loader, we pull the batch size off the dataloader + custom_iterable = DataLoader(custom_iterable, batch_size=batch_size) + else: + # Otherwise we need to specify it through the dataloader config + dataloader_config.custom_type_batch_size = batch_size + accelerator = Accelerator(dataloader_config=dataloader_config) + dataloader = accelerator.prepare(custom_iterable) + expected_batches = [[0, 1, 2, 3], [4, 5, 6, 7]] + self.check_custom_types_iterable(dataloader, expected_batches) + + def test_prepare_custom_types_dataloader_is_idempotent(self): + accelerator = 
Accelerator(dataloader_config=DataLoaderConfiguration(custom_types=True)) + custom_iterable = DataLoader(self._get_custom_iterable(data=list(range(8))), batch_size=4) + dataloader = CustomTypesDataLoader(custom_iterable) + prepared_dataloader = accelerator.prepare(dataloader) + assert isinstance(prepared_dataloader, CustomTypesDataLoader) + assert dataloader.dataset == prepared_dataloader.dataset + + def test_prepare_custom_types_dataloader_conflicting_batch_sizes(self): + # Ensure we can't pass a batch size for custom types and a wrapped + # dataloader unless the batch sizes are the same value + accelerator = Accelerator( + dataloader_config=DataLoaderConfiguration(custom_types=True, custom_type_batch_size=2) + ) + dataloader = DataLoader(self._get_custom_iterable(data=list(range(8))), batch_size=4) + with pytest.raises(ValueError): + accelerator.prepare(dataloader)
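For reference, here is a minimal end-to-end sketch of the workflow this diff enables, mirroring the new `DataLoaderConfiguration` fields and the tests above (the `MyIterable` class and its data are illustrative placeholders, not part of the diff):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration


class MyIterable:
    """Any object whose `__iter__` yields values convertible to torch tensors."""

    def __iter__(self):
        return iter(range(8))


# Opt in to custom-iterable handling; the batch size must be supplied here because
# the raw iterable is not already wrapped in a DataLoader.
dataloader_config = DataLoaderConfiguration(custom_types=True, custom_type_batch_size=4)
accelerator = Accelerator(dataloader_config=dataloader_config)

# `prepare` routes the iterable to `prepare_data_loader`, which wraps it in a
# `CustomTypesDataLoader` that batches the yielded values and moves them to `accelerator.device`.
dataloader = accelerator.prepare(MyIterable())

for batch in dataloader:
    print(batch.tolist())  # [0, 1, 2, 3] on the first batch, then [4, 5, 6, 7]
```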