Describe the bug
related to #3718
A KeyError is thrown inside deepspeed.initialize, at runtime/zero/stage_1_and_2.py, line 574, in _create_param_mapping, due to inconsistent usage of model parameters and parameters managed by the optimizer.
Suspected Root Cause
deepspeed.initialize(model=model, optimizer=optimizer, config_params=ds_config_fp16)
This issue can be triggered whenever, in the arguments to deepspeed.initialize, the parameters in optimizer.param_groups are not a subset of model.parameters().
At the fault location, the code tries to look up parameter names stored in self.param_names using the tensors in self.bit16_groups.
self.bit16_groups is populated from optimizer.param_groups, while self.param_names is populated from the model itself.
Thus, if the optimizer's parameters are not exactly a subset of the model's parameters, a KeyError is thrown. This situation is quite common in practice, due to optimization techniques like parameter grouping and ZeRO optimization.
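To make the failure mode concrete, here is a minimal sketch (not DeepSpeed's actual code; the variable names only mirror the attributes mentioned above) of the lookup that fails:

import torch
import torch.nn as nn

# Sketch of the failing lookup in _create_param_mapping (illustrative only).
model = nn.Linear(10, 5)

# Analogue of self.param_names: keyed by the model's own parameter tensors.
param_names = {param: name for name, param in model.named_parameters()}

# Analogue of self.bit16_groups: built from optimizer.param_groups. Here we add
# a flat tensor that is not a model parameter, similar to the fused parameter
# that a previous deepspeed.initialize leaves behind in optimizer.param_groups.
foreign_param = nn.Parameter(torch.zeros(55))
optimizer_params = list(model.parameters()) + [foreign_param]

for p in optimizer_params:
    if p in param_names:
        print("found:", param_names[p])
    else:
        print("KeyError would be raised here: tensor is not a model parameter")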
To Reproduce
We prepared a simple script that reproduces this error. In the script, deepspeed.initialize is accidentally called twice. After the first deepspeed.initialize, optimizer.param_groups is consolidated into a single parameter, which causes the KeyError in the second deepspeed.initialize.
Install deepspeed 0.15.4
Run bug.py using deepspeed --num_gpus=1 bug.py
# bug.py
import torch
import deepspeed
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 5)

    def forward(self, x):
        x = self.fc1(x)
        return self.fc2(x)

# Main function to expose the bug
def expose_bug():
    # Initialize model
    model = SimpleModel()

    # Initialize DeepSpeed configuration for fp16
    ds_config_fp16 = {
        "train_micro_batch_size_per_gpu": 1,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 2},
    }
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

    # optimizer has 4 params now
    print(optimizer.param_groups)

    # Initialize DeepSpeed engine
    model_engine, optim, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config_params=ds_config_fp16)

    # optimizer has 1 param now
    print(optimizer.param_groups)

    # EXCEPTION!!!
    model_engine, optim, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config_params=ds_config_fp16)

if __name__ == "__main__":
    expose_bug()
Notice that the second deepspeed.initialize throws the KeyError exception.
Also notice that the first print of optimizer.param_groups shows 4 params, while the second print shows only one param (whose content is the four original params merged into one).
[Screenshot: optimizer.param_groups prior to deepspeed.initialize]
[Screenshot: optimizer.param_groups after deepspeed.initialize]
Since the merged param does not exist in the model, the second deepspeed.initialize throws a KeyError.
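As a side note, the mismatch can be confirmed between the two initialize calls with a quick check like the following (a sketch reusing the model and optimizer names from bug.py above):

# Quick check between the two deepspeed.initialize calls (sketch, reusing the
# model / optimizer names from bug.py): is every optimizer param still a model param?
model_param_ids = {id(p) for p in model.parameters()}
optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
print("optimizer params are a subset of model params:", optimizer_param_ids <= model_param_ids)
print(len(model_param_ids), "model params vs", len(optimizer_param_ids), "optimizer params")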
Expected behavior / Suggested Fix
We expect two behaviors from DeepSpeed:
1. Forbid deepspeed.initialize on models / optimizers that have already been used in another deepspeed.initialize call.
2. Explicitly check that optimizer.param_groups is a subset of model.parameters() and throw a more user-friendly exception or warning (see the sketch after this list).
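For the second point, the check we have in mind could look roughly like this (an illustrative sketch only, not a patch against DeepSpeed's internals; the function name _validate_optimizer_params is made up):

def _validate_optimizer_params(model, optimizer):
    # Raise a descriptive error if the optimizer manages tensors that are not
    # parameters of the model passed to deepspeed.initialize.
    model_param_ids = {id(p) for p in model.parameters()}
    for group in optimizer.param_groups:
        for p in group["params"]:
            if id(p) not in model_param_ids:
                raise ValueError(
                    "optimizer.param_groups contains a tensor that is not a parameter of "
                    "the given model. This usually means the optimizer was already consumed "
                    "by a previous deepspeed.initialize call."
                )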
ds_report output
collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
[WARNING] using untested triton version (2.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/xxx/python3.10/site-packages/torch']
torch version .................... 2.2.2+cu121
deepspeed install path ........... ['/home/xxx/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.15.4, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.3
deepspeed wheel compiled w. ...... torch 2.2, cuda 12.1
shared memory (/dev/shm) size .... 31.24 GB
I will be more than happy to contribute to the two suggested fixes; let me know what you think!