Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MPS device when available #951

Open
wants to merge 36 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ace0516
Use MPS device when available
araffin Jul 4, 2022
9ac6225
Merge branch 'master' into feat/mps-support
araffin Aug 13, 2022
2dcbef9
Update test
araffin Aug 13, 2022
b00ca7f
Merge branch 'master' into feat/mps-support
araffin Aug 16, 2022
06a2124
Merge branch 'master' into feat/mps-support
qgallouedec Sep 28, 2022
6d868c0
Merge branch 'master' into feat/mps-support
qgallouedec Oct 4, 2022
8d79e96
Merge branch 'master' into feat/mps-support
qgallouedec Oct 7, 2022
3276cb0
Merge branch 'master' into feat/mps-support
qgallouedec Oct 10, 2022
f4f6073
Merge branch 'master' into feat/mps-support
qgallouedec Oct 14, 2022
64327c7
Merge branch 'master' into feat/mps-support
qgallouedec Oct 17, 2022
0344c3c
Merge branch 'master' into feat/mps-support
araffin Oct 24, 2022
fa196ab
Merge branch 'master' into feat/mps-support
qgallouedec Nov 2, 2022
efd086e
Merge branch 'master' into feat/mps-support
araffin Nov 18, 2022
7f11843
Merge branch 'master' into feat/mps-support
qgallouedec Dec 7, 2022
c60f681
Merge branch 'master' into feat/mps-support
qgallouedec Dec 20, 2022
92e8d11
Merge branch 'master' into feat/mps-support
araffin Jan 13, 2023
b235c8e
Merge branch 'master' into feat/mps-support
qgallouedec Feb 14, 2023
d4d0536
Merge branch 'master' into feat/mps-support
araffin Apr 3, 2023
0311b62
Merge branch 'master' into feat/mps-support
araffin Apr 21, 2023
086f79a
Merge branch 'master' into feat/mps-support
araffin May 3, 2023
fe606fc
Merge branch 'master' into feat/mps-support
araffin May 24, 2023
34f4819
Merge branch 'master' into feat/mps-support
qgallouedec Jun 30, 2023
ef39571
Merge branch 'master' into feat/mps-support
araffin Aug 17, 2023
d26324c
Merge branch 'master' into feat/mps-support
araffin Aug 30, 2023
1e5dc90
Merge branch 'master' into feat/mps-support
araffin Oct 6, 2023
40ed03c
mps.is_available -> mps.is_built
qgallouedec Oct 6, 2023
e83924b
docstring
qgallouedec Oct 6, 2023
b707480
Merge branch 'master' into feat/mps-support
qgallouedec Nov 2, 2023
81e3c63
Merge branch 'master' into feat/mps-support
araffin Nov 16, 2023
f0e54a7
Merge branch 'master' into feat/mps-support
araffin Jan 10, 2024
d47c586
Merge branch 'master' into feat/mps-support
araffin Apr 18, 2024
b85a2a5
Fix warning
araffin Apr 18, 2024
955382e
Merge branch 'master' into feat/mps-support
araffin Sep 18, 2024
263e657
Merge branch 'master' into feat/mps-support
araffin Oct 29, 2024
7c71688
Merge branch 'master' into feat/mps-support
araffin Nov 8, 2024
9489b1a
Merge branch 'master' into feat/mps-support
araffin Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ New Features:
- Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala)
- Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio)
- The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
- Use MacOS Metal "mps" device when available

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -758,6 +759,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Save cloudpickle version


`SB3-Contrib`_
Expand Down
33 changes: 24 additions & 9 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None:
"""
Seed the different random generators.

:param seed:
:param using_cuda:
:param seed: Seed
:param using_cuda: Whether CUDA is currently used
"""
# Seed python RNG
random.seed(seed)
Expand Down Expand Up @@ -141,19 +141,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
For now, it supports only cpu and cuda.
By default, it tries to use the gpu.
For now, it supports only CPU and CUDA.
By default, it tries to use the GPU.

:param device: One for 'auto', 'cuda', 'cpu'
:param device: One of "auto", "cuda", "cpu",
or any PyTorch supported device (for instance "mps")
:return: Supported Pytorch device
"""
# Cuda by default
# MPS/CUDA by default
if device == "auto":
device = "cuda"
device = get_available_accelerator()
# Force conversion to th.device
device = th.device(device)

# Cuda not available
# CUDA not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")

Expand Down Expand Up @@ -518,6 +519,20 @@ def should_collect_more_steps(
)


def get_available_accelerator() -> str:
"""
Return the available accelerator
(currently checking only for CUDA and MPS device)
"""
if hasattr(th, "backends") and th.backends.mps.is_built():
# MacOS Metal GPU
return "mps"
elif th.cuda.is_available():
return "cuda"
else:
return "cpu"


def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
"""
Retrieve system and python env info for the current system.
Expand All @@ -533,7 +548,7 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
"Python": platform.python_version(),
"Stable-Baselines3": sb3.__version__,
"PyTorch": th.__version__,
"GPU Enabled": str(th.cuda.is_available()),
"Accelerator": get_available_accelerator(),
"Numpy": np.__version__,
"Cloudpickle": cloudpickle.__version__,
"Gymnasium": gym.__version__,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,10 @@ def test_get_system_info():
assert info["Stable-Baselines3"] == str(sb3.__version__)
assert "Python" in info_str
assert "PyTorch" in info_str
assert "GPU Enabled" in info_str
assert "Accelerator" in info_str
assert "Numpy" in info_str
assert "Gym" in info_str
assert "Cloudpickle" in info_str


def test_is_vectorized_observation():
Expand Down
Loading