Merge branch 'master' into feat/eval-callback-list
araffin authored Jun 7, 2024
2 parents bd70612 + 0b06d8a commit fd74988
Showing 4 changed files with 19 additions and 7 deletions.
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

-Release 2.4.0a1 (WIP)
+Release 2.4.0a2 (WIP)
--------------------------

Breaking Changes:
@@ -18,6 +18,7 @@ Bug Fixes:
and only loads the PyTorch parameters (@peteole)
- Cast type in compute gae method to avoid error when using torch compile (@amjames)
- `CallbackList` now sets the `.parent` attribute of child callbacks to its own `.parent`. This resolves https://github.com/DLR-RM/stable-baselines3/issues/1791 (will-maclean)
+- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)

`SB3-Contrib`_
^^^^^^^^^^^^^^
@@ -1662,4 +1663,4 @@ And all the contributors:
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
-@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @will-maclean
+@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
7 changes: 3 additions & 4 deletions stable_baselines3/common/base_class.py
@@ -692,10 +692,9 @@ def load( # noqa: C901
             if "device" in data["policy_kwargs"]:
                 del data["policy_kwargs"]["device"]
             # backward compatibility, convert to new format
-            if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:
-                saved_net_arch = data["policy_kwargs"]["net_arch"]
-                if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
-                    data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
+            saved_net_arch = data["policy_kwargs"].get("net_arch")
+            if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
+                data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
 
         if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
             raise ValueError(
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a1
+2.4.0a2
12 changes: 12 additions & 0 deletions tests/test_save_load.py
@@ -797,3 +797,15 @@ def test_cast_lr_schedule(tmp_path):
     model = PPO.load(tmp_path / "ppo.zip")
     assert type(model.lr_schedule(1.0)) is float # noqa: E721
     assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))
+
+
+def test_save_load_net_arch_none(tmp_path):
+    """
+    Test that the model is loaded correctly when net_arch is manually set to None.
+    See GH#1928
+    """
+    PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=None)).save(tmp_path / "ppo.zip")
+    model = PPO.load(tmp_path / "ppo.zip")
+    # None has been replaced by the default net arch
+    assert model.policy.net_arch is not None
+    os.remove(tmp_path / "ppo.zip")
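
Note on the `base_class.py` change: the old guard called `len()` on the stored `net_arch`, which raises `TypeError` when the value is `None`. The new guard reads the value with `.get()` and relies on truthiness, so `None`, an empty list, and a missing key all skip the conversion. The sketch below reproduces both guards on a plain dict without importing the library. The helper names `normalize_net_arch_old` and `normalize_net_arch_new` are hypothetical and exist only for illustration; they are not part of stable-baselines3.

```python
def normalize_net_arch_old(policy_kwargs):
    """Hypothetical helper mirroring the guard before this commit."""
    # len(None) raises TypeError, which is the failure this commit fixes.
    if "net_arch" in policy_kwargs and len(policy_kwargs["net_arch"]) > 0:
        saved_net_arch = policy_kwargs["net_arch"]
        if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
            policy_kwargs["net_arch"] = saved_net_arch[0]
    return policy_kwargs


def normalize_net_arch_new(policy_kwargs):
    """Hypothetical helper mirroring the guard after this commit."""
    saved_net_arch = policy_kwargs.get("net_arch")
    # A falsy value (None, [], or a missing key) short-circuits the check.
    if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
        policy_kwargs["net_arch"] = saved_net_arch[0]
    return policy_kwargs


if __name__ == "__main__":
    # Old format: a list wrapping a dict is unwrapped to the dict.
    print(normalize_net_arch_new({"net_arch": [dict(pi=[64], vf=[64])]}))
    # None is left untouched instead of raising.
    print(normalize_net_arch_new({"net_arch": None}))
    try:
        normalize_net_arch_old({"net_arch": None})
    except TypeError as exc:
        print(f"old guard fails: {exc}")
```

Both guards behave the same for non-empty lists, empty lists, and a missing key; they differ only when the stored value is `None`.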
