Update docs for CustomTypesDataLoader
Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks committed Aug 12, 2024
1 parent 3b6eb97 commit 2ed1d56
Showing 2 changed files with 32 additions and 6 deletions.
6 changes: 4 additions & 2 deletions docs/source/concept_guides/internal_mechanism.md
@@ -28,7 +28,7 @@ Then, when calling [`~Accelerator.prepare`], the library:
- wraps your model(s) in the container adapted for the distributed setup,
- wraps your optimizer(s) in an [`~optimizer.AcceleratedOptimizer`],
- wraps your scheduler(s) in an [`~scheduler.AcceleratedScheduler`]
- creates a new version of your dataloader(s) in a [`~data_loader.DataLoaderShard`] or [`~data_loader.DataLoaderDispatcher`]
- creates a new version of your dataloader(s) in a [`~data_loader.DataLoaderShard`], [`~data_loader.DataLoaderDispatcher`], or [`~data_loader.CustomTypesDataLoader`]
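
For reference, a minimal sketch of the `prepare()` call this list describes (standard Accelerate usage, not part of this commit; the toy model, optimizer, scheduler, and dataloader below are purely illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

# Plain PyTorch objects, built as usual.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
dataloader = DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=4)

accelerator = Accelerator()

# prepare() wraps the model for the distributed setup, the optimizer and
# scheduler in their Accelerated* wrappers, and re-creates the dataloader
# as one of the DataLoader subclasses listed above.
model, optimizer, scheduler, dataloader = accelerator.prepare(
    model, optimizer, scheduler, dataloader
)
```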

While the model(s), optimizer(s), and scheduler(s) are just put in simple wrappers, the dataloader(s) are re-created. This is mostly
because PyTorch does not let the user change the `batch_sampler` of a dataloader once it's been created and the
@@ -42,7 +42,9 @@ The [`~data_loader.DataLoaderShard`] subclasses `DataLoader` to add the followin
- it puts the batches on the proper device before yielding them (unless you have opted out of
`device_placement=True`).

The [`~data_loader.DataLoaderDispatcher`] subclasses differs from the [`~data_loader.DataLoaderShard`] in that when iterating through the `DataLoader`, the data is all starting from process 0 and *then* split and sent off to each process rather than it happening at the dataset level.
The [`~data_loader.DataLoaderDispatcher`] subclass differs from the [`~data_loader.DataLoaderShard`] in that when iterating through the `DataLoader`, the data is all starting from process 0 and *then* split and sent off to each process rather than it happening at the dataset level.

The [`~data_loader.CustomTypesDataLoader`] subclass differs from the [`~data_loader.DataLoaderShard`] and [`~data_loader.DataLoaderDispatcher`] in that it can be used to wrap iterables with custom logic to give users more control over how the dataloader should manipulate the data; the dataloader itself only invokes `__iter__()` and moves the data to the appropriate device.
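
For illustration, a hypothetical usage sketch based only on the constructor documented in this commit (`CustomTypesDataLoader(data_or_loader, batch_size=..., device=...)`); the toy `SquareBatches` iterable is an assumption, and recursive device placement of nested structures such as tuples is assumed here rather than shown in the diff:

```python
import torch

from accelerate.data_loader import CustomTypesDataLoader


class SquareBatches:
    """Toy iterable with custom logic: each iteration yields one full batch."""

    def __init__(self, num_batches: int, batch_size: int):
        self.num_batches = num_batches
        self.batch_size = batch_size

    def __iter__(self):
        for step in range(self.num_batches):
            x = torch.arange(self.batch_size, dtype=torch.float32) + step
            # Yield anything that can be converted to / contains PyTorch tensors.
            yield x, x**2


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = CustomTypesDataLoader(SquareBatches(4, 8), batch_size=8, device=device)

for inputs, targets in loader:
    # Batches come straight from SquareBatches.__iter__(), moved to `device`.
    print(inputs.device, inputs.shape)
```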

The random number generator synchronization will by default synchronize:

32 changes: 28 additions & 4 deletions src/accelerate/data_loader.py
@@ -13,6 +13,7 @@
# limitations under the License.

import math
from collections.abc import Iterable
from contextlib import suppress
from typing import Callable, List, Optional, Union

@@ -566,21 +567,44 @@ def dataloader(self):

class CustomTypesDataLoader(DataLoader, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that can handle custom iterable types as long as they yield things that can be
Subclass of a PyTorch `DataLoader` that can handle custom iterable types as long as they yield objects that can be
converted to PyTorch tensors.
Args:
data_or_loader (`Union[torch.utils.data.dataloader.DataLoader, Iterable]`):
The data or `DataLoader` wrapping an arbitrary iterable to be moved to the provided device.
batch_size (`int`, *optional*, defaults to `None`):
Batch size to be used in the created `DataLoader`. Note that if the object provided is already a
`DataLoader`, the `batch_size` attribute of that loader will be used.
device (`torch.device`):
The target device for the returned `DataLoader`.
_non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, `DataLoader` will utilize non-blocking host-to-device transfers. If the `DataLoader` has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
Returns:
`torch.utils.data.dataloader.DataLoader`: A new data loader that will invoke `__iter__()` on the encapsulated
iterable and move yielded data to the provided device.
"""

def __init__(self, data_or_loader, batch_size=None, device=None, _non_blocking: bool = False, **kwargs):
def __init__(
self,
data_or_loader: Union[DataLoader, Iterable],
batch_size: Optional[int] = None,
device: Optional[torch.device] = None,
_non_blocking: bool = False,
**kwargs,
):
if isinstance(data_or_loader, DataLoader):
data = data_or_loader.dataset
if batch_size is not None and batch_size != data_or_loader.batch_size:
raise ValueError(
"provided custom types batch size conflicts with the batch size of wrapped dataloader"
"Provided custom types batch size conflicts with the batch size of wrapped DataLoader"
)
batch_size = data_or_loader.batch_size
else:
if batch_size is None:
raise ValueError("custom_types enabled, but `custom_type_batch_size` is None")
raise ValueError("`custom_types` enabled, but `custom_type_batch_size` is None")
data = data_or_loader
self.device = device
self._non_blocking = _non_blocking
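
For illustration, a short sketch of the two validation paths visible in this hunk; both errors are raised before the truncated remainder of `__init__` runs, so only the failure cases are shown (constructing a working instance depends on code not included in this diff):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import CustomTypesDataLoader

base_loader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)

# Wrapping an existing DataLoader: its batch_size is reused, so a conflicting
# explicit value triggers the first ValueError above.
try:
    CustomTypesDataLoader(base_loader, batch_size=16, device=torch.device("cpu"))
except ValueError as err:
    print(err)

# Wrapping a plain iterable: an explicit batch_size is required, otherwise the
# second ValueError above is raised.
try:
    CustomTypesDataLoader(range(4), device=torch.device("cpu"))
except ValueError as err:
    print(err)
```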
