From ace05162c5b133b1b706ed5fcceecea6ccfdb004 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 4 Jul 2022 14:51:46 +0200 Subject: [PATCH 1/5] Use MPS device when available --- docs/misc/changelog.rst | 4 +++- stable_baselines3/common/utils.py | 31 ++++++++++++++++++++++++------- stable_baselines3/version.txt | 2 +- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 52bf3e48e..576fbc85c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a9 (WIP) +Release 1.5.1a10 (WIP) --------------------------- Breaking Changes: @@ -17,6 +17,8 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Use MacOS Metal "mps" device when available +- Save cloudpickle version SB3-Contrib ^^^^^^^^^^^ diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 94cd65827..fdd33a433 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -6,6 +6,7 @@ from itertools import zip_longest from typing import Dict, Iterable, Optional, Tuple, Union +import cloudpickle import gym import numpy as np import torch as th @@ -135,19 +136,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: """ Retrieve PyTorch device. It checks that the requested device is available first. - For now, it supports only cpu and cuda. - By default, it tries to use the gpu. + For now, it supports only CPU and CUDA. + By default, it tries to use the GPU. - :param device: One for 'auto', 'cuda', 'cpu' + :param device: One of "auto", "cuda", "cpu", + or any PyTorch supported device (for instance "mps") :return: """ - # Cuda by default + # MPS/CUDA by default if device == "auto": - device = "cuda" + device = get_available_accelerator() # Force conversion to th.device device = th.device(device) - # Cuda not available + # CUDA not available if device.type == th.device("cuda").type and not th.cuda.is_available(): return th.device("cpu") @@ -483,6 +485,20 @@ def should_collect_more_steps( ) +def get_available_accelerator() -> str: + """ + Return the available accelerator + (currently checking only for CUDA and MPS device) + """ + if hasattr(th, "has_mps") and th.backends.mps.is_available(): + # MacOS Metal GPU + return "mps" + elif th.cuda.is_available(): + return "cuda" + else: + return "cpu" + + def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: """ Retrieve system and python env info for the current system. @@ -496,9 +512,10 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: "Python": platform.python_version(), "Stable-Baselines3": sb3.__version__, "PyTorch": th.__version__, - "GPU Enabled": str(th.cuda.is_available()), + "Accelerator": get_available_accelerator(), "Numpy": np.__version__, "Gym": gym.__version__, + "Cloudpickle": cloudpickle.__version__, } env_info_str = "" for key, value in env_info.items(): diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 125ec275d..c43063fe6 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a9 +1.5.1a10 From 2dcbef99c1f3f48ce1dfa12dcaf57388f9bce7e9 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 13 Aug 2022 15:22:13 +0200 Subject: [PATCH 2/5] Update test --- tests/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 67f2ad1a3..0d1c144e6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -384,9 +384,10 @@ def test_get_system_info(): assert info["Stable-Baselines3"] == str(sb3.__version__) assert "Python" in info_str assert "PyTorch" in info_str - assert "GPU Enabled" in info_str + assert "Accelerator" in info_str assert "Numpy" in info_str assert "Gym" in info_str + assert "Cloudpickle" in info_str def test_is_vectorized_observation(): From 40ed03cddb67907d33edd74f3bf9360d3b97981e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 6 Oct 2023 14:45:24 +0200 Subject: [PATCH 3/5] mps.is_available -> mps.is_built --- stable_baselines3/common/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 892823355..97ad152ba 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -521,7 +521,7 @@ def get_available_accelerator() -> str: Return the available accelerator (currently checking only for CUDA and MPS device) """ - if hasattr(th, "has_mps") and th.backends.mps.is_available(): + if hasattr(th, "has_mps") and th.backends.mps.is_built(): # MacOS Metal GPU return "mps" elif th.cuda.is_available(): From e83924b35cbb1c840cca40a4d39f248edc7729e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:06:39 +0200 Subject: [PATCH 4/5] docstring --- stable_baselines3/common/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 97ad152ba..e028d2385 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -29,8 +29,8 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: """ Seed the different random generators. - :param seed: - :param using_cuda: + :param seed: Seed + :param using_cuda: Whether CUDA is currently used """ # Seed python RNG random.seed(seed) From b85a2a5101238198663e7080674405aedeede5c6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 18 Apr 2024 18:20:27 +0200 Subject: [PATCH 5/5] Fix warning --- stable_baselines3/common/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 1e16c9e44..ff83e0e4c 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -521,7 +521,7 @@ def get_available_accelerator() -> str: Return the available accelerator (currently checking only for CUDA and MPS device) """ - if hasattr(th, "has_mps") and th.backends.mps.is_built(): + if hasattr(th, "backends") and th.backends.mps.is_built(): # MacOS Metal GPU return "mps" elif th.cuda.is_available():