diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index d85781646d6..419aab4f0cf 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1421,7 +1421,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e current_device.index if isinstance(current_device, torch.device) else current_device ) - if torch.device(current_device_index) != self.device: + if self.device.type == "cpu" and is_bitsandbytes_multi_backend_available(): + # bnb with multi-backend supports CPU which don't need to check index. + pass + elif torch.device(current_device_index) != self.device: # if on the first device (GPU 0) we don't care if (self.device.index is not None) or (current_device_index != 0): raise ValueError(