POC: multiple model/configuration DeepSpeed support #3097

Merged (65 commits), Sep 13, 2024
Changes from 23 commits
Commits
b8b77e7
Bookmark
muellerzr Sep 9, 2024
602f6f6
Migratory
muellerzr Sep 10, 2024
0dd96a1
Uncomment
muellerzr Sep 10, 2024
8db38da
Rm name to model for now
muellerzr Sep 10, 2024
2f8c5dd
Rm container
muellerzr Sep 10, 2024
4b1bc15
Left: test
muellerzr Sep 10, 2024
3488b11
Allow only wrapping one model
muellerzr Sep 10, 2024
3a42fcd
Add warning but only ref once
muellerzr Sep 10, 2024
26b3c2b
Refine
muellerzr Sep 10, 2024
a059e36
Update src/accelerate/accelerator.py
muellerzr Sep 10, 2024
8a3e628
Finish stas nits
muellerzr Sep 10, 2024
bcfcf55
Merge branch 'muellerzr-multiple-model-deepspeed' of https://github.c…
muellerzr Sep 10, 2024
599986c
Merge branch 'main' into muellerzr-multiple-model-deepspeed
muellerzr Sep 10, 2024
447091b
Clean
muellerzr Sep 10, 2024
7049f6c
Fixup test + test writing
muellerzr Sep 10, 2024
9ff5296
Fully working
muellerzr Sep 10, 2024
db37ea2
Fin
muellerzr Sep 10, 2024
5ebbae7
Nit
muellerzr Sep 10, 2024
fb53edd
Quality
muellerzr Sep 10, 2024
d1078b8
Update src/accelerate/accelerator.py
muellerzr Sep 10, 2024
d01aea4
Actionable error
muellerzr Sep 10, 2024
082f8ec
Merge branch 'muellerzr-multiple-model-deepspeed' of https://github.c…
muellerzr Sep 10, 2024
4f8adfd
Make note of when its enabled
muellerzr Sep 10, 2024
1c6d50d
Apply suggestions from code review
muellerzr Sep 10, 2024
faef63b
Merge tests
muellerzr Sep 10, 2024
5a593df
Merge branch 'muellerzr-multiple-model-deepspeed' of https://github.c…
muellerzr Sep 10, 2024
be80f87
Merge
muellerzr Sep 10, 2024
f054684
Add currently broken test script
muellerzr Sep 11, 2024
a191f3b
Push the working implementation
muellerzr Sep 11, 2024
844355b
Fin
muellerzr Sep 11, 2024
87dba32
Add guards for user behavior
muellerzr Sep 11, 2024
fa462ff
Test nits
muellerzr Sep 11, 2024
fa291b3
TODO: finish knowledge distillation example
muellerzr Sep 11, 2024
4593c52
Update tests/deepspeed/test_deepspeed_multiple_model.py
muellerzr Sep 11, 2024
3ff05bf
Allow for dict-like interface
muellerzr Sep 11, 2024
d99a431
Merge branch 'muellerzr-multiple-model-deepspeed' of https://github.c…
muellerzr Sep 11, 2024
6a804c9
Get rid of disable
muellerzr Sep 11, 2024
1d81c4c
Uncomment
muellerzr Sep 11, 2024
6abff04
Complete rewrite to force a dict to be used
muellerzr Sep 12, 2024
7b703b6
Working tests/fin
muellerzr Sep 12, 2024
f1939f5
Use name as stas suggestion
muellerzr Sep 12, 2024
96bf3f9
Clean
muellerzr Sep 12, 2024
c7c6ccc
docnit
muellerzr Sep 12, 2024
2def7aa
toctree
muellerzr Sep 12, 2024
0df2940
toctree
muellerzr Sep 12, 2024
43d9def
Missing ref
muellerzr Sep 12, 2024
0eae50c
Put in break
muellerzr Sep 12, 2024
718e5ef
Smaller diff
muellerzr Sep 12, 2024
a079267
Make note on how to use zeroinit
muellerzr Sep 12, 2024
2e2f0dc
Make note about accelerator ds plugin
muellerzr Sep 12, 2024
7d4e712
More docnits
muellerzr Sep 12, 2024
837272f
Apply suggestions from code review
muellerzr Sep 12, 2024
a17e03e
Limit users to not pass in another ds plugin to another accelerator
muellerzr Sep 12, 2024
3e35dd3
not implemented err + Make a note about why no params
muellerzr Sep 12, 2024
08c2c5f
Apply suggestions from code review from Stas
muellerzr Sep 12, 2024
f889b88
Add deepspeed_plugins arg + update doc
muellerzr Sep 12, 2024
91ab419
Plugin -> plugins
muellerzr Sep 12, 2024
70bc728
Change enable() -> select()
muellerzr Sep 12, 2024
f193ece
Update ref properly + test
muellerzr Sep 12, 2024
2375a0a
Conflict
muellerzr Sep 12, 2024
c9baf0d
Be consistent, model1,model2...
muellerzr Sep 12, 2024
6d66b92
first_, second_
muellerzr Sep 12, 2024
254a3ef
A few more auto values
muellerzr Sep 12, 2024
88cbbdf
Apply suggestions from code review
muellerzr Sep 12, 2024
dd77813
Apply suggestions from code review
muellerzr Sep 13, 2024
69 changes: 51 additions & 18 deletions src/accelerate/accelerator.py
@@ -116,6 +116,7 @@
DeepSpeedSchedulerWrapper,
DummyOptim,
DummyScheduler,
get_active_deepspeed_plugin,
)

if is_megatron_lm_available():
@@ -179,9 +180,10 @@ class Accelerator:
the execution on one process only.
dataloader_config (`DataLoaderConfiguration`, *optional*):
A configuration for how the dataloaders should be handled in distributed scenarios.
deepspeed_plugin ([`~utils.DeepSpeedPlugin`], *optional*):
deepspeed_plugin ([`~utils.DeepSpeedPlugin`] or list of [`~utils.DeepSpeedPlugin`], *optional*):
Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
directly using *accelerate config*. If using multiple plugins, the first one will be the active one by
default. Manually call `plugin.enable()` to activate a different plugin.
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
@@ -285,11 +287,17 @@ def __init__(
DeepSpeedPlugin() if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" else None
)
else:
assert isinstance(
deepspeed_plugin, DeepSpeedPlugin
), "`deepspeed_plugin` must be an `accelerate.utils.DeepSpeedPlugin` object."
if isinstance(deepspeed_plugin, (tuple, list)):
for plugin in deepspeed_plugin:
if not isinstance(plugin, DeepSpeedPlugin):
raise TypeError("`deepspeed_plugin` must be a DeepSpeedPlugin object.")
elif isinstance(deepspeed_plugin, DeepSpeedPlugin):
deepspeed_plugin = [deepspeed_plugin]
else:
raise TypeError("`deepspeed_plugin` must be a DeepSpeedPlugin object.")

if deepspeed_plugin is not None:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" # use DeepSpeed if plugin is provided
if deepspeed_plugin:
if not is_deepspeed_available():
raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.")
if is_mlu_available():
@@ -304,8 +312,11 @@
mixed_precision = (
os.environ.get("ACCELERATE_MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
)
deepspeed_plugin.set_mixed_precision(mixed_precision)
deepspeed_plugin.set_deepspeed_weakref()
for plugin in deepspeed_plugin:
plugin.set_mixed_precision(mixed_precision)
# The first plugin is always the active one
deepspeed_plugin[0].enable()
self.deepspeed_engine_wrapped = None

if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
fsdp_plugin, FullyShardedDataParallelPlugin
@@ -541,6 +552,23 @@ def __init__(

check_os_kernel()

@property
def deepspeed_plugin(self):
"""
Returns the currently active DeepSpeedPlugin.

If using multiple plugins, the first one will be the active one by default. Manually call `plugin.enable()` to
activate a different plugin.
"""
return self.state.deepspeed_plugin

@property
def deepspeed_plugins(self):
"""
Returns all of the DeepSpeedPlugins
"""
return self.state.deepspeed_plugins

@property
def use_distributed(self):
"""
@@ -1640,7 +1668,7 @@ def _prepare_deepspeed(self, *args):

ds_initialize = msamp_deepspeed.initialize

deepspeed_plugin = self.state.deepspeed_plugin
deepspeed_plugin = get_active_deepspeed_plugin(self.state)

is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
result = [
@@ -1684,13 +1712,15 @@
)

config_kwargs = {
"train_micro_batch_size_per_gpu": batch_size_per_device,
"train_batch_size": batch_size_per_device
* deepspeed_plugin.get_value("gradient_accumulation_steps")
* self.num_processes,
"gradient_clipping": 1.0,
"zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
}
# This is skipped when preparing just a model
if batch_size_per_device is not None:
config_kwargs["train_micro_batch_size_per_gpu"] = batch_size_per_device
config_kwargs["train_batch_size"] = (
batch_size_per_device * deepspeed_plugin.get_value("gradient_accumulation_steps") * self.num_processes
)

model = None
optimizer = None
@@ -1852,16 +1882,19 @@
):
result[i] = scheduler
# pointing for deepspeed_engine_wrapped.backward()
self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
if self.deepspeed_engine_wrapped is None:
self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
else:
logger.warning(
"A wrapped DeepSpeed engine reference is currently tied to this `Accelerator()` instance. "
"If you want to call `accelerator.backward()` referencing a new model/engine, "
"please create a separate `Accelerator()` instance and call `accelerator.prepare()` on it."
)
self._models.append(engine)
if optimizer is not None:
self._optimizers.append(optimizer)
if scheduler is not None:
self._schedulers.append(scheduler)
if len(self._models) > 1:
raise AssertionError(
"You can't use the same `Accelerator()` instance with multiple models when using DeepSpeed"
)
return tuple(result)

def _prepare_megatron_lm(self, *args):
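As documented in the updated `Accelerator` docstring above, `deepspeed_plugin` now accepts either a single `DeepSpeedPlugin` or a list, with the first entry active by default. A minimal sketch of that flow (the ZeRO stages, the `bf16` choice, and the two-plugin setup are illustrative assumptions; it presumes DeepSpeed is installed and the script is started through `accelerate launch`):

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Two hypothetical configurations: ZeRO-2 for the trainable model, ZeRO-3 for a second model.
zero2_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
zero3_plugin = DeepSpeedPlugin(zero_stage=3, gradient_accumulation_steps=1)

accelerator = Accelerator(deepspeed_plugin=[zero2_plugin, zero3_plugin], mixed_precision="bf16")

# The first plugin in the list is the active one by default.
assert accelerator.deepspeed_plugin is zero2_plugin
assert list(accelerator.deepspeed_plugins) == [zero2_plugin, zero3_plugin]

# Switch the active configuration before preparing objects that should use it.
zero3_plugin.enable()
assert accelerator.deepspeed_plugin is zero3_plugin
```

Note that this reflects the 23-commit state shown here; later commits in the list above rename `enable()` to `select()` and move to a dict of named plugins.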
32 changes: 28 additions & 4 deletions src/accelerate/state.py
@@ -861,7 +861,7 @@ def __init__(
self.__dict__.update(PartialState._shared_state)
self._check_initialized(mixed_precision, cpu)
if not self.initialized:
self.deepspeed_plugin = None
self.deepspeed_plugins = []
self.use_ipex = None
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
@@ -900,7 +900,11 @@ def __init__(
os.environ["XLA_DOWNCAST_BF16"] = str(0)
self.downcast_bfloat = False
elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
self.deepspeed_plugin = deepspeed_plugin
# Just in case a user manually creates an `AcceleratorState`/bypasses the `Accelerator`
if not isinstance(deepspeed_plugin, (list, tuple)):
deepspeed_plugin = [deepspeed_plugin]
self.deepspeed_plugins = deepspeed_plugin
self.distributed_type = DistributedType.DEEPSPEED
elif self.distributed_type in [
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
@@ -946,7 +950,10 @@ def initialized(self) -> bool:
def __repr__(self):
repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
if self.distributed_type == DistributedType.DEEPSPEED:
repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
from accelerate.utils.deepspeed import get_active_deepspeed_plugin

active_plugin = get_active_deepspeed_plugin(self)
repr += f"ds_config: {active_plugin.deepspeed_config}\n"
return repr

def _check_initialized(self, mixed_precision=None, cpu=None):
@@ -975,7 +982,10 @@ def use_fp16(self):
@property
def mixed_precision(self):
if self.distributed_type == DistributedType.DEEPSPEED:
config = self.deepspeed_plugin.deepspeed_config
from accelerate.utils.deepspeed import get_active_deepspeed_plugin

active_plugin = get_active_deepspeed_plugin(self)
config = active_plugin.deepspeed_config
if config.get("fp16", {}).get("enabled", False):
mixed_precision = "fp16"
elif config.get("bf16", {}).get("enabled", False):
@@ -1092,6 +1102,20 @@ def local_main_process_first(self):
with PartialState().local_main_process_first():
yield

@property
def deepspeed_plugin(self):
"""
Returns the currently active DeepSpeedPlugin.

If using multiple plugins, the first one will be the active one by default. Manually call `plugin.enable()` to
activate a different plugin.
"""
if self.distributed_type != DistributedType.DEEPSPEED:
return None
from accelerate.utils.deepspeed import get_active_deepspeed_plugin

return get_active_deepspeed_plugin(self)

def print(self, *args, **kwargs):
PartialState().print(*args, **kwargs)

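With the state now tracking a list of plugins, `AcceleratorState.deepspeed_plugin` (and the `mixed_precision`/`__repr__` paths above) resolves through the currently active plugin. A small lookup sketch, assuming a DeepSpeed-configured `Accelerator` like the one in the previous example has already been created:

```python
from accelerate.state import AcceleratorState
from accelerate.utils import DistributedType

state = AcceleratorState()

if state.distributed_type == DistributedType.DEEPSPEED:
    active = state.deepspeed_plugin        # only the currently enabled plugin
    registered = state.deepspeed_plugins   # every plugin that was passed in
    print(f"Active ZeRO stage: {active.zero_stage}, {len(registered)} plugin(s) registered")
else:
    # Outside of DeepSpeed the property returns None instead of raising.
    assert state.deepspeed_plugin is None
```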
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -188,6 +188,7 @@
DummyOptim,
DummyScheduler,
HfDeepSpeedConfig,
get_active_deepspeed_plugin,
)

from .bnb import has_4bit_bnb_layers, load_and_quantize_model
31 changes: 25 additions & 6 deletions src/accelerate/utils/dataclasses.py
@@ -1098,6 +1098,9 @@ def __post_init__(self):
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
self.zero3_init_flag = False
# NOTE: Set to False by default, will be set to `True` automatically if it's the first plugin passed
# to the `Accelerator`'s `deepspeed_plugin` param, *or* `plugin.enable()` is manually called
self.enabled = False

# Ignore if it's already set
if self.enable_msamp and "msamp" not in self.deepspeed_config:
@@ -1196,13 +1199,13 @@ def set_mixed_precision(self, mixed_precision):
def set_deepspeed_weakref(self):
from .imports import is_transformers_available

ds_config = copy.deepcopy(self.deepspeed_config)
if self.zero3_init_flag:
if not is_transformers_available():
raise Exception(
"When `zero3_init_flag` is set, it requires Transformers to be installed. "
"Please run `pip install transformers`."
)
ds_config = copy.deepcopy(self.deepspeed_config)
if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto":
ds_config["gradient_accumulation_steps"] = 1
if (
@@ -1213,12 +1216,12 @@ def set_deepspeed_weakref(self):
if ds_config.get("train_batch_size", None) == "auto":
del ds_config["train_batch_size"]

if compare_versions("transformers", "<", "4.33"):
from transformers.deepspeed import HfDeepSpeedConfig
else:
from transformers.integrations import HfDeepSpeedConfig
if compare_versions("transformers", "<", "4.33"):
from transformers.deepspeed import HfDeepSpeedConfig
else:
from transformers.integrations import HfDeepSpeedConfig

self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa
self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa

def is_zero3_init_enabled(self):
return self.zero3_init_flag
@@ -1284,6 +1287,22 @@ def set_moe_leaf_modules(self, model):
transformer_moe_cls.append(transformer_cls)
set_z3_leaf_modules(model, transformer_moe_cls) # z3_leaf

def enable(self):
"""
Sets the HfDeepSpeedWeakref to use the current deepspeed plugin configuration
"""
self.set_deepspeed_weakref()
from accelerate.state import AcceleratorState

if AcceleratorState._shared_state != {}:
for plugin in AcceleratorState().deepspeed_plugins:
if plugin is not self:
plugin.disable()
self.enabled = True

def disable(self):
self.enabled = False


@dataclass
class FullyShardedDataParallelPlugin:
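The new `enabled` flag plus `enable()`/`disable()` keep exactly one plugin active per `AcceleratorState`: enabling a plugin refreshes the `HfDeepSpeedConfig` weakref and disables every other registered plugin. A short sketch of that invariant, under the same hypothetical two-plugin setup as before:

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

zero2_plugin = DeepSpeedPlugin(zero_stage=2)
zero3_plugin = DeepSpeedPlugin(zero_stage=3)
accelerator = Accelerator(deepspeed_plugin=[zero2_plugin, zero3_plugin])

# The Accelerator enabled the first plugin during construction.
assert zero2_plugin.enabled and not zero3_plugin.enabled

# Enabling another plugin flips the flags so only one stays active at a time.
zero3_plugin.enable()
assert zero3_plugin.enabled and not zero2_plugin.enabled
```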
18 changes: 18 additions & 0 deletions src/accelerate/utils/deepspeed.py
@@ -19,6 +19,24 @@

from ..optimizer import AcceleratedOptimizer
from ..scheduler import AcceleratedScheduler
from ..state import AcceleratorState
from .dataclasses import DeepSpeedPlugin, DistributedType


def get_active_deepspeed_plugin(state: AcceleratorState) -> DeepSpeedPlugin:
"""
Returns the currently active DeepSpeedPlugin.

Raises:
ValueError: If DeepSpeed was not enabled and this function is called.
"""
if state.distributed_type != DistributedType.DEEPSPEED:
raise ValueError(
"Couldn't retrieve the active `DeepSpeedPlugin` as none were enabled. "
"Please make sure that the `Accelerator` is configured for `deepspeed` "
"and that a `DeepSpeedPlugin` has been enabled (`plugin.enable()`) before calling this."
)
return next(plugin for plugin in state.deepspeed_plugins if plugin.enabled)


class HfDeepSpeedConfig:
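`get_active_deepspeed_plugin` gives code that only has an `AcceleratorState` (rather than the `Accelerator` itself) a single way to look up the enabled plugin, raising if DeepSpeed is not in use. A usage sketch under the same assumptions as the earlier examples:

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, get_active_deepspeed_plugin

plugins = [DeepSpeedPlugin(zero_stage=2), DeepSpeedPlugin(zero_stage=3)]
accelerator = Accelerator(deepspeed_plugin=plugins)

active = get_active_deepspeed_plugin(accelerator.state)
assert active is plugins[0]                    # the first plugin is enabled by default
assert active is accelerator.deepspeed_plugin  # same object the Accelerator property resolves
```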
34 changes: 34 additions & 0 deletions tests/deepspeed/ds_config_zero3_model_only.json
@@ -0,0 +1,34 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto"
},
"train_batch_size": 1
}
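This model-only ZeRO-3 config deliberately omits the `optimizer` and `scheduler` sections and pins `train_batch_size` to 1, matching the `_prepare_deepspeed` change above that skips the batch-size keys when only a model is prepared. A sketch of loading it into a secondary plugin, e.g. for a frozen reference or teacher model (the relative path is an assumption about where the file sits in your project):

```python
from accelerate.utils import DeepSpeedPlugin

# Hypothetical path; point it at wherever this JSON config is stored.
inference_plugin = DeepSpeedPlugin(hf_ds_config="tests/deepspeed/ds_config_zero3_model_only.json")

assert inference_plugin.hf_ds_config.is_zero3()
# The fp16/bf16 "auto" flags are resolved when the Accelerator calls
# set_mixed_precision() on each plugin it receives.
```

Such a plugin would typically go in as the second entry of the `deepspeed_plugin` list, to be enabled only when the frozen model is prepared.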