From 2ddf015cd9840a2a1675f5208be6eb2e86e4d045 Mon Sep 17 00:00:00 2001 From: Jan-Hendrik Ewers Date: Mon, 9 Oct 2023 11:21:12 +0100 Subject: [PATCH] fix: Follow PEP8 guidelines and evaluate falsy to truthy with `not` rather than `is False`. (#1707) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: Follow PEP8 guidelines and evaluate falsy to truth with `not` rather than `is False`. https://docs.python.org/2/library/stdtypes.html#truth-value-testing * chore: Update changelog inline with intent of changes in PR #1707 Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * fix: Change `is False` to `not` as per PEP8 * chore: Remove superfluous comment about `is False` * test: One On- and one Off-Policy algorithm (A2C and SAC respectively), with settings to speed up testing * Update changelog * chore: Remove EvalCallback as it's not actually required * Update changelog.rst * Rm duplicated "others" section in changelog.rst --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 3 +- stable_baselines3/common/buffers.py | 2 +- stable_baselines3/common/callbacks.py | 1 - .../common/off_policy_algorithm.py | 4 +-- .../common/on_policy_algorithm.py | 4 +-- stable_baselines3/common/policies.py | 2 +- tests/test_callbacks.py | 36 +++++++++++++++++++ 7 files changed, 44 insertions(+), 8 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 80f1dbd24..fada24827 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -9,6 +9,7 @@ Release 2.2.0a7 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ - Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version +- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle) New Features: ^^^^^^^^^^^^^ @@ -1462,7 +1463,7 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor +@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index a230a31b8..306b43571 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -561,7 +561,7 @@ def __init__( if psutil is not None: mem_available = psutil.virtual_memory().available - assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage" + assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage" # disabling as this adds quite a bit of complexity # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702 self.optimize_memory_usage = optimize_memory_usage diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 54f1b97e5..5089bba2b 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -554,7 +554,6 @@ def __init__(self, reward_threshold: float, verbose: int = 0): def _on_step(self) -> bool: assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``" - # Convert np.bool_ to bool, otherwise callback() is False won't work continue_training = bool(self.parent.best_mean_reward < self.reward_threshold) if self.verbose >= 1 and not continue_training: print( diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 2caaf8e97..e8dcac4a4 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -330,7 +330,7 @@ def learn( log_interval=log_interval, ) - if rollout.continue_training is False: + if not rollout.continue_training: break if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: @@ -556,7 +556,7 @@ def collect_rollouts( # Give access to local variables callback.update_locals(locals()) # Only stop training if return value is False, not when it is None. - if callback.on_step() is False: + if not callback.on_step(): return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False) # Retrieve reward and episode length if using Monitor wrapper diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 1e0f9e6c9..4f9bb0809 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -186,7 +186,7 @@ def collect_rollouts( # Give access to local variables callback.update_locals(locals()) - if callback.on_step() is False: + if not callback.on_step(): return False self._update_info_buffer(infos) @@ -265,7 +265,7 @@ def learn( while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) - if continue_training is False: + if not continue_training: break iteration += 1 diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 9bb7b11d2..5f57672f2 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -90,7 +90,7 @@ def __init__( self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs # Automatically deactivate dtype and bounds checks - if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)): + if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)): self.features_extractor_kwargs.update(dict(normalized_image=True)) def _update_features_extractor( diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index f8b0e5486..d159c43e8 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -4,9 +4,11 @@ import gymnasium as gym import numpy as np import pytest +import torch as th from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer from stable_baselines3.common.callbacks import ( + BaseCallback, CallbackList, CheckpointCallback, EvalCallback, @@ -123,6 +125,40 @@ def test_eval_callback_vec_env(): assert eval_callback.last_mean_reward == 100.0 +class AlwaysFailCallback(BaseCallback): + def __init__(self, *args, callback_false_value, **kwargs): + super().__init__(*args, **kwargs) + self.callback_false_value = callback_false_value + + def _on_step(self) -> bool: + return self.callback_false_value + + +@pytest.mark.parametrize( + "model_class,model_kwargs", + [ + (A2C, dict(n_steps=1, stats_window_size=1)), + ( + SAC, + dict( + learning_starts=1, + buffer_size=1, + batch_size=1, + ), + ), + ], +) +@pytest.mark.parametrize("callback_false_value", [False, np.bool_(0), th.tensor(0, dtype=th.bool)]) +def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_false_value): + assert not callback_false_value # Sanity check to ensure parametrized values are valid + env_id = select_env(model_class) + model = model_class("MlpPolicy", env_id, **model_kwargs, policy_kwargs=dict(net_arch=[2])) + alwaysfailcallback = AlwaysFailCallback(callback_false_value=callback_false_value) + model.learn(10, callback=alwaysfailcallback) + + assert alwaysfailcallback.n_calls == 1 + + def test_eval_success_logging(tmp_path): n_bits = 2 n_envs = 2