add support for custom function for reducing the batch size #3071

Open
wants to merge 2 commits into
base: main
9 changes: 7 additions & 2 deletions src/accelerate/utils/memory.py
@@ -103,7 +103,7 @@ def should_reduce_batch_size(exception: Exception) -> bool:
     return False


-def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
+def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128, reduce_batch_size_fn: callable = None):
     """
     A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
     CUDNN, the batch size is cut in half and passed to `function`
Comment on lines +106 to 109
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure to add `reduce_batch_size_fn` to the docstring, describing what it should take in and return (takes a batch size and returns a modified batch size?)
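One possible shape for that contract, as a sketch rather than what the PR implements: in the diff as written, the default `reduce_batch_size_fn` takes no arguments and updates `batch_size` through `nonlocal`, which a user-supplied function defined outside the decorator cannot do. The example below assumes instead that the callable receives the current batch size and returns the next one to try. The names `halve` and `reduce_by_quarter` are hypothetical.

```python
def halve(batch_size: int) -> int:
    """Assumed contract: take the current batch size, return the next one to try.

    This mirrors the existing default behaviour (`batch_size //= 2`).
    """
    return batch_size // 2


def reduce_by_quarter(batch_size: int) -> int:
    """Hypothetical custom reducer: shrink by 25% instead of 50%.

    Integer arithmetic guarantees the result is strictly smaller for any
    positive batch size, so repeated calls eventually reach zero.
    """
    return (batch_size * 3) // 4


next_default = halve(128)
next_custom = reduce_by_quarter(128)
```

With a signature like this, the docstring entry could state that the callable must return a strictly smaller non-negative integer, since the retry loop only terminates when the batch size shrinks.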

@@ -134,6 +134,11 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
         return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)

     batch_size = starting_batch_size
+    if reduce_batch_size_fn is None:
+        def reduce_batch_size_fn():
+            nonlocal batch_size
+            batch_size = batch_size // 2
+            return batch_size

     def decorator(*args, **kwargs):
         nonlocal batch_size
@@ -154,7 +159,7 @@ def decorator(*args, **kwargs):
             except Exception as e:
                 if should_reduce_batch_size(e):
                     clear_device_cache(garbage_collection=True)
-                    batch_size //= 2
+                    batch_size = reduce_batch_size_fn()
                 else:
                     raise
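To make the retry behaviour concrete, here is a minimal standalone sketch of the pattern this diff is aiming for. It is not the accelerate implementation: `find_executable_batch_size_sketch` is a hypothetical stand-in, `MemoryError` stands in for the `should_reduce_batch_size(e)` check, and the reducer is assumed to take the current batch size as an argument. It also forwards `reduce_batch_size_fn` through `functools.partial`; the context line in the hunk above passes only `starting_batch_size`, so this forwarding is an assumption about the intended behaviour when the decorator is used with arguments.

```python
import functools


def find_executable_batch_size_sketch(function=None, starting_batch_size=128, reduce_batch_size_fn=None):
    """Hypothetical stand-in for the decorator, using only the standard library."""
    if function is None:
        # Forward the reducer so the `@decorator(...)` form does not drop it.
        return functools.partial(
            find_executable_batch_size_sketch,
            starting_batch_size=starting_batch_size,
            reduce_batch_size_fn=reduce_batch_size_fn,
        )

    if reduce_batch_size_fn is None:
        # Default keeps the existing behaviour: halve on each failure.
        def reduce_batch_size_fn(current):
            return current // 2

    batch_size = starting_batch_size

    @functools.wraps(function)
    def decorator(*args, **kwargs):
        nonlocal batch_size
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except MemoryError:
                # Assumption: MemoryError replaces the real OOM/CUDNN detection.
                batch_size = reduce_batch_size_fn(batch_size)

    return decorator


attempts = []


@find_executable_batch_size_sketch(starting_batch_size=100, reduce_batch_size_fn=lambda bs: bs * 3 // 4)
def train(batch_size):
    attempts.append(batch_size)
    if batch_size > 50:
        raise MemoryError("simulated out-of-memory")
    return batch_size


result = train()
```

Here `train` is retried at batch sizes 100, 75, 56 and finally 42, the first value that fits under the simulated limit. Passing the batch size into the reducer keeps the state in one place (the decorator's closure) instead of requiring the user's function to track it.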
