diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index bf4e4867a..29f91cff2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a1 (WIP) +Release 2.4.0a2 (WIP) -------------------------- Breaking Changes: @@ -18,6 +18,7 @@ Bug Fixes: and only loads the PyTorch parameters (@peteole) - Cast type in compute gae method to avoid error when using torch compile (@amjames) - `CallbackList` now sets the `.parent` attribute of child callbacks to its own `.parent`. This resolves https://github.com/DLR-RM/stable-baselines3/issues/1791 (will-maclean) +- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -1662,4 +1663,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @will-maclean +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 4be61d65e..b2c967405 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -692,10 +692,9 @@ def load( # noqa: C901 if "device" in data["policy_kwargs"]: del data["policy_kwargs"]["device"] # backward compatibility, convert to new format - if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0: - saved_net_arch = data["policy_kwargs"]["net_arch"] - if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict): - data["policy_kwargs"]["net_arch"] = saved_net_arch[0] + saved_net_arch = data["policy_kwargs"].get("net_arch") + if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict): + data["policy_kwargs"]["net_arch"] = saved_net_arch[0] if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]: raise ValueError( diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 48adc0106..e828d3c3d 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a1 +2.4.0a2 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 0162e3650..c7df7b26f 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -797,3 +797,15 @@ def test_cast_lr_schedule(tmp_path): model = PPO.load(tmp_path / "ppo.zip") assert type(model.lr_schedule(1.0)) is float # noqa: E721 assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) + + +def test_save_load_net_arch_none(tmp_path): + """ + Test that the model is loaded correctly when net_arch is manually set to None. + See GH#1928 + """ + PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=None)).save(tmp_path / "ppo.zip") + model = PPO.load(tmp_path / "ppo.zip") + # None has been replaced by the default net arch + assert model.policy.net_arch is not None + os.remove(tmp_path / "ppo.zip")