Set CallbackList children's parent correctly (#1939)
* Fixing #1791

* Update test and version

* Add test for callback after eval

* Fix mypy error

* Remove tqdm warnings

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
will-maclean and araffin authored Jun 7, 2024
1 parent 0b06d8a commit 4efee92
Showing 7 changed files with 39 additions and 6 deletions.
7 changes: 4 additions & 3 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 2.4.0a2 (WIP)
+Release 2.4.0a3 (WIP)
 --------------------------

 Breaking Changes:
@@ -17,7 +17,8 @@ Bug Fixes:
 - Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore
   and only loads the PyTorch parameters (@peteole)
 - Cast type in compute gae method to avoid error when using torch compile (@amjames)
-- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
+- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (will-maclean)
+- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)

 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
@@ -1662,4 +1663,4 @@ And all the contributors:
 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
 @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
 @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
-@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122
+@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -46,6 +46,8 @@ filterwarnings = [
     "ignore::DeprecationWarning:tensorboard",
     # Gymnasium warnings
     "ignore::UserWarning:gymnasium",
+    # tqdm warning about rich being experimental
+    "ignore:rich is experimental"
 ]
 markers = [
     "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
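The new `filterwarnings` entry uses the `action:message` form, where the message part is matched against the start of the warning text. A minimal sketch of the same idea using only the standard `warnings` module follows. The exact tqdm message text (`"rich is experimental/alpha"`) and the `UserWarning` category are assumptions made for illustration, and `run_with_filter` is a hypothetical helper, not part of this commit.

```python
import warnings

# Assumed warning text; the real tqdm message may differ.
ASSUMED_TQDM_MESSAGE = "rich is experimental/alpha"


def run_with_filter() -> list:
    """Emit two warnings and return the messages that survive the filter."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning by default ...
        warnings.simplefilter("always")
        # ... except those whose message starts with the given pattern.
        warnings.filterwarnings("ignore", message="rich is experimental")
        warnings.warn(ASSUMED_TQDM_MESSAGE, UserWarning)
        warnings.warn("some other warning", UserWarning)
    return [str(w.message) for w in caught]


remaining = run_with_filter()
print(remaining)
```

Only the unrelated warning should remain in `remaining`, since the prefix match suppresses the assumed tqdm message.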
2 changes: 1 addition & 1 deletion stable_baselines3/common/buffers.py
@@ -419,7 +419,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
         :param dones: if the last step was a terminal step (one bool for each env).
         """
         # Convert to numpy
-        last_values = last_values.clone().cpu().numpy().flatten()
+        last_values = last_values.clone().cpu().numpy().flatten()  # type: ignore[assignment]

         last_gae_lam = 0
         for step in reversed(range(self.buffer_size)):
4 changes: 4 additions & 0 deletions stable_baselines3/common/callbacks.py
@@ -204,6 +204,10 @@ def _init_callback(self) -> None:
         for callback in self.callbacks:
             callback.init_callback(self.model)

+            # Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
+            # pass through the parent callback to all children
+            callback.parent = self.parent
+
     def _on_training_start(self) -> None:
         for callback in self.callbacks:
             callback.on_training_start(self.locals, self.globals)
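The effect of the `callback.parent = self.parent` line can be illustrated without installing the library. The sketch below is a simplified model of the pattern, not the library's code: `SketchCallback`, `SketchCallbackList`, and `SketchEvalCallback` are hypothetical stand-ins, and the assumption that the event callback sets itself as parent of the callback it wraps is made for illustration.

```python
from typing import List, Optional


class SketchCallback:
    """Hypothetical stand-in for a callback base class."""

    def __init__(self) -> None:
        self.parent: Optional["SketchCallback"] = None

    def init_callback(self) -> None:
        self._init_callback()

    def _init_callback(self) -> None:
        pass


class SketchCallbackList(SketchCallback):
    """Hypothetical list wrapper that forwards its parent to its children."""

    def __init__(self, callbacks: List[SketchCallback]) -> None:
        super().__init__()
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        for callback in self.callbacks:
            callback.init_callback()
            # Pass the list's own parent through to each child, so children
            # see the event callback rather than None.
            callback.parent = self.parent


class SketchEvalCallback(SketchCallback):
    """Hypothetical event callback that owns a child callback."""

    def __init__(self, callback_on_new_best: Optional[SketchCallback] = None) -> None:
        super().__init__()
        self.best_mean_reward = float("-inf")
        self.callback_on_new_best = callback_on_new_best
        if callback_on_new_best is not None:
            # Assumption: the wrapped callback gets this object as its parent.
            callback_on_new_best.parent = self

    def _init_callback(self) -> None:
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback()


child = SketchCallback()
eval_cb = SketchEvalCallback(callback_on_new_best=SketchCallbackList([child]))
eval_cb.init_callback()
print(child.parent is eval_cb)
```

Without the forwarding line in `SketchCallbackList._init_callback`, `child.parent` would stay `None` in this model, which mirrors the situation the added test guards against.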
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
@@ -367,7 +367,7 @@ def predict(
         with th.no_grad():
             actions = self._predict(obs_tensor, deterministic=deterministic)
         # Convert to numpy, and reshape to the original action shape
-        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[misc]
+        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[misc, assignment]

         if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a2
+2.4.0a3
26 changes: 26 additions & 0 deletions tests/test_callbacks.py
@@ -264,3 +264,29 @@ def test_checkpoint_additional_info(tmp_path):
     model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip")
     model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
     VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env)
+
+
+def test_eval_callback_chaining(tmp_path):
+    class DummyCallback(BaseCallback):
+        def _on_step(self):
+            # Check that the parent callback is an EvalCallback
+            assert isinstance(self.parent, EvalCallback)
+            assert hasattr(self.parent, "best_mean_reward")
+            return True
+
+    stop_on_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
+
+    eval_callback = EvalCallback(
+        gym.make("Pendulum-v1"),
+        best_model_save_path=tmp_path,
+        log_path=tmp_path,
+        eval_freq=32,
+        deterministic=True,
+        render=False,
+        callback_on_new_best=CallbackList([DummyCallback(), stop_on_threshold_callback]),
+        callback_after_eval=CallbackList([DummyCallback()]),
+        warn=False,
+    )
+
+    model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1)
+    model.learn(64, callback=eval_callback)
