Custom Types DataLoader #3008

Open
wants to merge 7 commits into main from custom_types_dl

Conversation

alex-jw-brooks
Contributor

What does this PR do?

Fixes #2975

This PR adds a CustomTypesDataLoader, which allows custom iterable types to be moved onto the Accelerator state's device. The iterable can be passed either directly or wrapped in a PyTorch DataLoader (which would normally throw a TypeError once you start to iterate over it). This adds two args to the DataLoaderConfiguration class:

  • custom_types: bool; indicates whether or not an instance of this class should be created when prepare() is invoked on a dataloader or iterable type
  • custom_type_batch_size: int; batch size to be used for the CustomTypesDataLoader. If the iterable is already wrapped in a PyTorch DataLoader, this is optional and the batch size will be taken from that DataLoader.

Minimal example:

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

class MyIterableType:
    def __init__(self, data):
        self.data = data
    
    def __iter__(self):
        return iter(self.data)

some_iterable = MyIterableType(list(range(16)))
accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(custom_types=True, custom_type_batch_size=4))

# Passing DataLoader(some_iterable, batch_size=4) 
# would also be fine - if passed in this way, we could omit custom_type_batch_size
custom_type_loader = accelerator.prepare(some_iterable)
for batch in custom_type_loader:
    print(batch)

Running with export ACCELERATE_TORCH_DEVICE=cuda provides the following output:

tensor([0, 1, 2, 3], device='cuda:0')
tensor([4, 5, 6, 7], device='cuda:0')
tensor([ 8,  9, 10, 11], device='cuda:0')
tensor([12, 13, 14, 15], device='cuda:0')
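
For comparison, iterating the same iterable wrapped in a plain torch.utils.data.DataLoader (without this feature) hits the TypeError mentioned above, since the object is treated as a map-style dataset but defines neither __len__ nor __getitem__. A minimal repro, assuming stock PyTorch and the MyIterableType class from the example:

from torch.utils.data import DataLoader

plain_loader = DataLoader(MyIterableType(list(range(16))), batch_size=4)
for batch in plain_loader:
    # TypeError: object of type 'MyIterableType' has no len()
    print(batch)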

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr @BenjaminBossan @SunMarc

@alex-jw-brooks force-pushed the custom_types_dl branch 2 times, most recently from 04187eb to 2ed1d56 on August 12, 2024 at 21:43
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member

To me, it's still not clear what we want to achieve here, even after reading the associated issue. Is the plan really to allow users to pass arbitrary iterables to accelerator.prepare (in which case the docstring of prepare needs to be updated)? Or is it that users still need to pass some type of DataLoader instance but that accelerate should handle this data loader instance better for arbitrary iterables inside that data loader?

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Aug 13, 2024

Good questions @BenjaminBossan! To be honest, after implementing it, I have mixed feelings too. My thoughts are:

  • If the goal is to sidestep the TypeError that would normally be thrown when iterating over a DataLoader and to handle device placement - this does do that, but the way it fits into the current Accelerate API is a bit strange, i.e., wrapping an iterable in a DataLoader that normally would not be usable, then using prepare() to get one that is, feels like an antipattern to me

  • For passing arbitrary iterables to accelerator.prepare() - I have mixed feelings about that as well. I'm not entirely sure the benefit this DataLoader adds is worth extending the API that way, and it's also kind of awkward to handle in the way _prepare_one() gets called - there are probably some weird edge-cases there that are hard to handle cleanly, e.g., if custom_types=True and the user happens to pass multiple iterables, since as things are now, those objects would just get returned 🤔

Those are mostly the reasons it's currently implemented the way it is, where you can pass either an iterable or a wrapped DataLoader. I think the class itself is interesting and could imagine it being useful somewhere, but either way feels a bit awkward to expose with respect to the current Accelerator API.

I'd be really interested in hearing @muellerzr's thoughts as well. In any case, if there is a common enough use-case we can elucidate that makes this feature compelling to add, maybe it would also be useful to add a demo somewhere showing how it's intended to be leveraged, either as part of this PR or a follow-up!

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Aug 14, 2024

From an API perspective, what might make more sense would be something like:

  • Split the data loader into two, where one handles the custom types part, and the other literally just takes a dataset and puts the yielded things onto the device
  • Separate the custom-types loader from the accelerator preparation entirely, and just have the preparation step take care of device placement

For example...

...
from accelerate.data_loader import CustomTypesDataLoader

class MyIterableType:
    def __init__(self, data):
        self.data = data
    
    def __iter__(self):
        return iter(self.data)


# Need to think about what the DataLoaderConfiguration would be here still
accelerator = Accelerator(...)

some_iterable = MyIterableType(list(range(16)))
# Wraps the iterable as an `IterableDataset`; does not handle device placement at all.
# If you want to use this, you need to build it yourself and pass it to prepare();
# you can't get one out of prepare() with any DataLoaderConfiguration
custom_types_loader = CustomTypesDataLoader(some_iterable)

# Gives a DataLoaderDeviceMover or something whose only function is to get stuff
# from the provided dataloader and put it on the device
dl = accelerator.prepare(custom_types_loader)

This would preserve the capability to do the same thing without expanding DataLoaderConfiguration/prepare() in a weird direction, reduce the likelihood of bugs with how _prepare_one() gets called, and also avoid potentially passing invalid dataloaders to prepare(), etc.

@muellerzr
Collaborator

muellerzr commented Aug 16, 2024

Hi! Thanks so much for your first attempt at this. I had to trace back through the issues to figure out what I was really after with this "generic only place on device" dataloader, and it was this FR for WebLoader: #2083

The general idea at the time was to support a generic iterable (wrapped in a dataloader-type class) that we could then more easily pipe into custom implementations such as WebLoader.

IMO it should just place the outputs on the device, and have all the mixins our dataloaders have, to make it work with things like gradient accumulation properly (since that's how we check for it).

Let me know if you want to rethink your PR/simplify it at all/etc given this new information.

It's a great effort, and I agree I did not really expound too much on the FR issue at the time. Let me know if this path forward makes sense to you both and we can continue 🚀

(I'm not 100% certain if doing sharding makes sense for dataloaders like this, hence the "Just move to device and make sure gradient accumulation doesn't break")
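
To make that concrete, here is a rough sketch (illustrative only, not this PR's implementation; the DeviceOnlyLoader name is hypothetical, and a real version would also need the gradient-accumulation/mixin plumbing mentioned above) of a loader whose only job is device placement, using accelerate's public send_to_device utility:

from accelerate.utils import send_to_device

class DeviceOnlyLoader:
    """Iterates an existing loader and moves each yielded batch to `device`."""
    def __init__(self, loader, device):
        self.loader = loader
        self.device = device

    def __iter__(self):
        for batch in self.loader:
            # send_to_device recursively handles tensors, dicts, lists, and other nested containers
            yield send_to_device(batch, self.device)

    def __len__(self):
        # Delegate length to the wrapped loader
        return len(self.loader)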

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Aug 20, 2024

Thanks for the guidance @muellerzr! It makes sense to try to rework this PR a bit so that we can discuss further - I think assuming that everything that should be prepared as a DataLoader is a DataLoader is a good idea. Then maybe something like the workaround @BenjaminBossan suggested in the WebLoader issue for debugging, e.g.,

import torch
import webdataset as wds
from accelerate import Accelerator

# Subclassing torch.utils.data.DataLoader as well makes isinstance checks in
# accelerator.prepare() treat this as a dataloader
class MyLoader(wds.WebLoader, torch.utils.data.DataLoader):
    pass

train_dataloader = MyLoader(train_dataset, ...)
accelerator = Accelerator(device_placement=True)  # maybe another flag is needed here? will check current behavior
dl = accelerator.prepare(train_dataloader)

could be a recommended solution for data-loader-like classes that aren't torch DataLoader instances. It's a bit gross looking, but at least the grossness lives outside of accelerate. If I have the bandwidth, I'll try to reproduce the WebLoader issue to see whether the revisions in this PR are able to fix it - it would be nice to validate its usefulness 🤞


github-actions bot commented Oct 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Oct 16, 2024
@SunMarc SunMarc reopened this Oct 17, 2024