From 4efee92fbad70f85aa094e27bd0a740274121795 Mon Sep 17 00:00:00 2001 From: will-maclean <41996719+will-maclean@users.noreply.github.com> Date: Fri, 7 Jun 2024 22:07:28 +1000 Subject: [PATCH] Set CallbackList children's parent correctly (#1939) * Fixing #1791 * Update test and version * Add test for callback after eval * Fix mypy error * Remove tqdm warnings --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 7 ++++--- pyproject.toml | 2 ++ stable_baselines3/common/buffers.py | 2 +- stable_baselines3/common/callbacks.py | 4 ++++ stable_baselines3/common/policies.py | 2 +- stable_baselines3/version.txt | 2 +- tests/test_callbacks.py | 26 ++++++++++++++++++++++++++ 7 files changed, 39 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 758615c39..d6df00956 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a2 (WIP) +Release 2.4.0a3 (WIP) -------------------------- Breaking Changes: @@ -17,7 +17,8 @@ Bug Fixes: - Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore and only loads the PyTorch parameters (@peteole) - Cast type in compute gae method to avoid error when using torch compile (@amjames) -- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) +- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (will-maclean) +- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -1662,4 +1663,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean diff --git a/pyproject.toml b/pyproject.toml index ce0a14e0f..8e20ffe00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,8 @@ filterwarnings = [ "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings "ignore::UserWarning:gymnasium", + # tqdm warning about rich being experimental + "ignore:rich is experimental" ] markers = [ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 651ecdb2d..b2fc5a710 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -419,7 +419,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra :param dones: if the last step was a terminal step (one bool for each env). """ # Convert to numpy - last_values = last_values.clone().cpu().numpy().flatten() + last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment] last_gae_lam = 0 for step in reversed(range(self.buffer_size)): diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 48b6011d1..c7841866b 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -204,6 +204,10 @@ def _init_callback(self) -> None: for callback in self.callbacks: callback.init_callback(self.model) + # Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791 + # pass through the parent callback to all children + callback.parent = self.parent + def _on_training_start(self) -> None: for callback in self.callbacks: callback.on_training_start(self.locals, self.globals) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 3c9b14aaa..f9c4285dc 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -367,7 +367,7 @@ def predict( with th.no_grad(): actions = self._predict(obs_tensor, deterministic=deterministic) # Convert to numpy, and reshape to the original action shape - actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc] + actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment] if isinstance(self.action_space, spaces.Box): if self.squash_output: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e828d3c3d..fdd5a5f23 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a2 +2.4.0a3 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index d159c43e8..ffc37320f 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -264,3 +264,29 @@ def test_checkpoint_additional_info(tmp_path): model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip") model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl") VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env) + + +def test_eval_callback_chaining(tmp_path): + class DummyCallback(BaseCallback): + def _on_step(self): + # Check that the parent callback is an EvalCallback + assert isinstance(self.parent, EvalCallback) + assert hasattr(self.parent, "best_mean_reward") + return True + + stop_on_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1) + + eval_callback = EvalCallback( + gym.make("Pendulum-v1"), + best_model_save_path=tmp_path, + log_path=tmp_path, + eval_freq=32, + deterministic=True, + render=False, + callback_on_new_best=CallbackList([DummyCallback(), stop_on_threshold_callback]), + callback_after_eval=CallbackList([DummyCallback()]), + warn=False, + ) + + model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1) + model.learn(64, callback=eval_callback)