Skip to content

Commit

Permalink
Fix resource warning (#1742)
Browse files Browse the repository at this point in the history
* Fix resource warning

* Add test and update changelog

* Fix for new mypy version
  • Loading branch information
araffin authored Nov 16, 2023
1 parent b413f4c commit 23fbeb5
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 26 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.2.0a11 (WIP)
Release 2.2.0a12 (WIP)
--------------------------
**Support for options at reset, bug fixes and better error messages**

Expand Down Expand Up @@ -32,6 +32,7 @@ Bug Fixes:
- Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD)
- Fixed check_env for Sequence observation space (@corentinlger)
- Prevents instantiating BitFlippingEnv with conflicting observation spaces (@kylesayrs)
- Fixed ResourceWarning when loading and saving models (files were not closed)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
44 changes: 23 additions & 21 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,28 +308,28 @@ def save_to_zip_file(
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
save_path = open_path(save_path, "w", verbose=0, suffix="zip")
# data/params can be None, so do not
# try to serialize them blindly
if data is not None:
serialized_data = data_to_json(data)

# Create a zip-archive and write our objects there.
with zipfile.ZipFile(save_path, mode="w") as archive:
# Do not try to save "None" elements
with open_path(save_path, "w", verbose=0, suffix="zip") as save_path:
# data/params can be None, so do not
# try to serialize them blindly
if data is not None:
archive.writestr("data", serialized_data)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
th.save(dict_, param_file)
# Save metadata: library version when file was saved
archive.writestr("_stable_baselines3_version", sb3.__version__)
# Save system info about the current python env
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
serialized_data = data_to_json(data)

# Create a zip-archive and write our objects there.
with zipfile.ZipFile(save_path, mode="w") as archive:
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
th.save(dict_, param_file)
# Save metadata: library version when file was saved
archive.writestr("_stable_baselines3_version", sb3.__version__)
# Save system info about the current python env
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])


def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
Expand Down Expand Up @@ -450,4 +450,6 @@ def load_from_zip_file(
except zipfile.BadZipFile as e:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
finally:
load_path.close()
return data, params, pytorch_variables
6 changes: 3 additions & 3 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ def step_async(self, actions: np.ndarray) -> None:
def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos, self.reset_infos = zip(*results)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]

def reset(self) -> VecEnvObs:
for env_idx, remote in enumerate(self.remotes):
remote.send(("reset", (self._seeds[env_idx], self._options[env_idx])))
results = [remote.recv() for remote in self.remotes]
obs, self.reset_infos = zip(*results)
obs, self.reset_infos = zip(*results) # type: ignore[assignment]
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a11
2.2.0a12
9 changes: 9 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,3 +747,12 @@ def test_dqn_target_update_interval(tmp_path):
model = DQN.load(tmp_path / "dqn_cartpole")
os.remove(tmp_path / "dqn_cartpole.zip")
assert model.target_update_interval == 100


# Turn warnings into errors
@pytest.mark.filterwarnings("error")
def test_no_resource_warning(tmp_path):
# check that files are properly closed
# Create a PPO agent and save it
PPO("MlpPolicy", "CartPole-v1").save(tmp_path / "dqn_cartpole")
PPO.load(tmp_path / "dqn_cartpole")

0 comments on commit 23fbeb5

Please sign in to comment.