From e3dea4b2e03da6fb7ea70db89602909081a7967b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
 <45557362+qgallouedec@users.noreply.github.com>
Date: Fri, 17 Nov 2023 23:50:23 +0100
Subject: [PATCH] Release 2.2.1: Hotfix file closing (#1754)

* new closing policy

* revert #1742

* Add tests and update changelog

---------

Co-authored-by: Antonin Raffin 
---
 docs/misc/changelog.rst               | 13 ++++-
 stable_baselines3/common/save_util.py | 69 +++++++++++++++------------
 stable_baselines3/version.txt         |  2 +-
 tests/test_save_load.py               | 27 +++++++++++
 4 files changed, 78 insertions(+), 33 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 3c2643674..364657cbf 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,10 +3,16 @@
 Changelog
 ==========
 
-Release 2.2.0 (2023-11-16)
+Release 2.2.1 (2023-11-17)
 --------------------------
 **Support for options at reset, bug fixes and better error messages**
 
+.. note::
+
+  SB3 v2.2.0 was yanked after a breaking change was found in `GH#1751 <https://github.com/DLR-RM/stable-baselines3/issues/1751>`_.
+  Please use SB3 v2.2.1 and not v2.2.0.
+
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
@@ -32,7 +38,9 @@ Bug Fixes:
 - Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
 - Fixed check_env for Sequence observation space (@corentinlger)
 - Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
-- Fixed ResourceWarning when loading and saving models (files were not closed)
+- Fixed ResourceWarning when loading and saving models (files were not closed); please note that only paths are closed automatically,
+  the behavior stays the same for tempfiles (they need to be closed manually),
+  and the behavior is now consistent when loading/saving the replay buffer
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
@@ -76,6 +84,7 @@ Others:
 - Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
 - Fixed ``stable_baselines3/common/policies.py`` type hints
 - Switched to ``mypy`` only for checking types
+- Added tests to check consistency when saving/loading files
 
 Documentation:
 ^^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py
index 40681b591..0cbf6d4e2 100644
--- a/stable_baselines3/common/save_util.py
+++ b/stable_baselines3/common/save_util.py
@@ -308,28 +308,31 @@ def save_to_zip_file(
     :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
     :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
     """
-    with open_path(save_path, "w", verbose=0, suffix="zip") as save_path:
-        # data/params can be None, so do not
-        # try to serialize them blindly
+    file = open_path(save_path, "w", verbose=0, suffix="zip")
+    # data/params can be None, so do not
+    # try to serialize them blindly
+    if data is not None:
+        serialized_data = data_to_json(data)
+
+    # Create a zip-archive and write our objects there.
+    with zipfile.ZipFile(file, mode="w") as archive:
+        # Do not try to save "None" elements
         if data is not None:
-            serialized_data = data_to_json(data)
-
-        # Create a zip-archive and write our objects there.
- with zipfile.ZipFile(save_path, mode="w") as archive: - # Do not try to save "None" elements - if data is not None: - archive.writestr("data", serialized_data) - if pytorch_variables is not None: - with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file: - th.save(pytorch_variables, pytorch_variables_file) - if params is not None: - for file_name, dict_ in params.items(): - with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file: - th.save(dict_, param_file) - # Save metadata: library version when file was saved - archive.writestr("_stable_baselines3_version", sb3.__version__) - # Save system info about the current python env - archive.writestr("system_info.txt", get_system_info(print_info=False)[1]) + archive.writestr("data", serialized_data) + if pytorch_variables is not None: + with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file: + th.save(pytorch_variables, pytorch_variables_file) + if params is not None: + for file_name, dict_ in params.items(): + with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file: + th.save(dict_, param_file) + # Save metadata: library version when file was saved + archive.writestr("_stable_baselines3_version", sb3.__version__) + # Save system info about the current python env + archive.writestr("system_info.txt", get_system_info(print_info=False)[1]) + + if isinstance(save_path, (str, pathlib.Path)): + file.close() def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None: @@ -344,10 +347,12 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, ver :param obj: The object to save. :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler: - # Use protocol>=4 to support saving replay buffers >= 4Gb - # See https://docs.python.org/3/library/pickle.html - pickle.dump(obj, file_handler, protocol=pickle.HIGHEST_PROTOCOL) + file = open_path(path, "w", verbose=verbose, suffix="pkl") + # Use protocol>=4 to support saving replay buffers >= 4Gb + # See https://docs.python.org/3/library/pickle.html + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) + if isinstance(path, (str, pathlib.Path)): + file.close() def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any: @@ -360,8 +365,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in path actually exists. If path is a io.BufferedIOBase the path exists. 
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler: - return pickle.load(file_handler) + file = open_path(path, "r", verbose=verbose, suffix="pkl") + obj = pickle.load(file) + if isinstance(path, (str, pathlib.Path)): + file.close() + return obj def load_from_zip_file( @@ -391,14 +399,14 @@ def load_from_zip_file( :return: Class parameters, model state_dicts (aka "params", dict of state_dict) and dict of pytorch variables """ - load_path = open_path(load_path, "r", verbose=verbose, suffix="zip") + file = open_path(load_path, "r", verbose=verbose, suffix="zip") # set device to cpu if cuda is not available device = get_device(device=device) # Open the zip archive and load data try: - with zipfile.ZipFile(load_path) as archive: + with zipfile.ZipFile(file) as archive: namelist = archive.namelist() # If data or parameters is not in the # zip archive, assume they were stored @@ -451,5 +459,6 @@ def load_from_zip_file( # load_path wasn't a zip file raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e finally: - load_path.close() + if isinstance(load_path, (str, pathlib.Path)): + file.close() return data, params, pytorch_variables diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index ccbccc3dc..c043eea77 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0 +2.2.1 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 778d944f9..e7123e984 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -3,6 +3,7 @@ import json import os import pathlib +import tempfile import warnings import zipfile from collections import OrderedDict @@ -752,7 +753,33 @@ def test_dqn_target_update_interval(tmp_path): # Turn warnings into errors @pytest.mark.filterwarnings("error") def test_no_resource_warning(tmp_path): + # Check behavior of save/load + # see https://github.com/DLR-RM/stable-baselines3/issues/1751 + # check that files are properly closed # Create a PPO agent and save it PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole") PPO.load(tmp_path / "dqn_cartpole") + + PPO("MlpPolicy", "CartPole-v1").save(str(tmp_path / "dqn_cartpole")) + PPO.load(str(tmp_path / "dqn_cartpole")) + + # Do the same but in memory, should not close the file + with tempfile.TemporaryFile() as fp: + PPO("MlpPolicy", "CartPole-v1").save(fp) + PPO.load(fp) + assert not fp.closed + + # Same but with replay buffer + model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200) + model.save_replay_buffer(tmp_path / "replay") + model.load_replay_buffer(tmp_path / "replay") + + model.save_replay_buffer(str(tmp_path / "replay")) + model.load_replay_buffer(str(tmp_path / "replay")) + + with tempfile.TemporaryFile() as fp: + model.save_replay_buffer(fp) + fp.seek(0) + model.load_replay_buffer(fp) + assert not fp.closed
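
The closing policy the patch converges on is "close only what you opened": when a save/load helper receives a str or pathlib.Path it opens the file itself and is responsible for closing it, but when it receives an already-open file object (e.g. a tempfile) closing stays with the caller, and the rule now applies identically to model files and replay buffers. Below is a minimal, self-contained sketch of that pattern; the save_object helper is hypothetical, not SB3 API (SB3 implements this via open_path in save_util.py):

    import pathlib
    import pickle
    import tempfile
    from typing import Any, BinaryIO, Union

    def save_object(path: Union[str, pathlib.Path, BinaryIO], obj: Any) -> None:
        # Hypothetical helper mirroring the patch's closing policy (not SB3 API).
        opened_here = isinstance(path, (str, pathlib.Path))
        file = open(path, "wb") if opened_here else path
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
        # Close only what we opened ourselves; a caller-supplied buffer stays
        # open so the caller can keep seeking/reading it.
        if opened_here:
            file.close()

    # A path is opened and closed internally (no ResourceWarning)...
    save_object("obj.pkl", {"a": 1})
    # ...while a caller-owned buffer is left open, as the new tests assert.
    with tempfile.TemporaryFile() as fp:
        save_object(fp, {"a": 1})
        assert not fp.closed
        fp.seek(0)
        assert pickle.load(fp) == {"a": 1}

This is also why the patch drops the v2.2.0 `with open_path(...)` context managers: a with block closes its target unconditionally, including buffers the caller still needs (the regression reported in GH#1751), whereas the explicit isinstance check closes paths only.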