Skip to content

Commit

Permalink
Fixing #1791
Browse files Browse the repository at this point in the history
  • Loading branch information
will-maclean committed May 27, 2024
1 parent 6c00565 commit bd70612
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Bug Fixes:
- Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore
and only loads the PyTorch parameters (@peteole)
- Cast type in compute gae method to avoid error when using torch compile (@amjames)
- `CallbackList` now sets the `.parent` attribute of child callbacks to its own `.parent`. This resolves https://github.com/DLR-RM/stable-baselines3/issues/1791 (will-maclean)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1661,4 +1662,4 @@ And all the contributors:
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @will-maclean
4 changes: 4 additions & 0 deletions stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def _init_callback(self) -> None:
for callback in self.callbacks:
callback.init_callback(self.model)

# Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
# pass through the parent callback to all children
callback.parent = self.parent

def _on_training_start(self) -> None:
for callback in self.callbacks:
callback.on_training_start(self.locals, self.globals)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,31 @@ def test_checkpoint_additional_info(tmp_path):
model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip")
model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env)


def test_eval_callback_chaining():
class CustomCallback(BaseCallback):

def __init__(self, verbose=0):
super().__init__(verbose)

def _on_step(self):
# Draw a figure
return True

eval_env = gym.make("Pendulum-v1")

stop_on_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)

eval_callback = EvalCallback(
eval_env,
best_model_save_path="./logs/",
log_path="./logs/",
eval_freq=199,
deterministic=True,
render=False,
callback_on_new_best=CallbackList([CustomCallback(), stop_on_threshold_callback]),
)

model = PPO("MlpPolicy", "Pendulum-v1")
model.learn(200, callback=eval_callback)

0 comments on commit bd70612

Please sign in to comment.