Skip to content

Commit

Permalink
fix: Follow PEP8 guidelines and evaluate falsy to truthy with not r…
Browse files Browse the repository at this point in the history
…ather than `is False`. (#1707)

* fix: Follow PEP8 guidelines and evaluate falsy to truth with `not` rather than `is False`.

https://docs.python.org/2/library/stdtypes.html#truth-value-testing

* chore: Update changelog inline with intent of changes in PR #1707

Co-authored-by: Quentin Gallouédec <[email protected]>

* fix: Change `is False` to `not` as per PEP8

* chore: Remove superfluous comment about `is False`

* test: One On- and one Off-Policy algorithm (A2C and SAC respectively), with settings to speed up testing

* Update changelog

* chore: Remove EvalCallback as it's not actually required

* Update changelog.rst

* Rm duplicated "others" section in changelog.rst

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Antonin Raffin <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2023
1 parent c6bf251 commit 2ddf015
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Release 2.2.0a7 (WIP)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)

New Features:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -1462,7 +1463,7 @@ And all the contributors:
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def __init__(
if psutil is not None:
mem_available = psutil.virtual_memory().available

assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage"
# disabling as this adds quite a bit of complexity
# https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
self.optimize_memory_usage = optimize_memory_usage
Expand Down
1 change: 0 additions & 1 deletion stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,6 @@ def __init__(self, reward_threshold: float, verbose: int = 0):

def _on_step(self) -> bool:
assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``"
# Convert np.bool_ to bool, otherwise callback() is False won't work
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
if self.verbose >= 1 and not continue_training:
print(
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def learn(
log_interval=log_interval,
)

if rollout.continue_training is False:
if not rollout.continue_training:
break

if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
Expand Down Expand Up @@ -556,7 +556,7 @@ def collect_rollouts(
# Give access to local variables
callback.update_locals(locals())
# Only stop training if return value is False, not when it is None.
if callback.on_step() is False:
if not callback.on_step():
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)

# Retrieve reward and episode length if using Monitor wrapper
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def collect_rollouts(

# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
if not callback.on_step():
return False

self._update_info_buffer(infos)
Expand Down Expand Up @@ -265,7 +265,7 @@ def learn(
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

if continue_training is False:
if not continue_training:
break

iteration += 1
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
# Automatically deactivate dtype and bounds checks
if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
self.features_extractor_kwargs.update(dict(normalized_image=True))

def _update_features_extractor(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import gymnasium as gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer
from stable_baselines3.common.callbacks import (
BaseCallback,
CallbackList,
CheckpointCallback,
EvalCallback,
Expand Down Expand Up @@ -123,6 +125,40 @@ def test_eval_callback_vec_env():
assert eval_callback.last_mean_reward == 100.0


class AlwaysFailCallback(BaseCallback):
def __init__(self, *args, callback_false_value, **kwargs):
super().__init__(*args, **kwargs)
self.callback_false_value = callback_false_value

def _on_step(self) -> bool:
return self.callback_false_value


@pytest.mark.parametrize(
"model_class,model_kwargs",
[
(A2C, dict(n_steps=1, stats_window_size=1)),
(
SAC,
dict(
learning_starts=1,
buffer_size=1,
batch_size=1,
),
),
],
)
@pytest.mark.parametrize("callback_false_value", [False, np.bool_(0), th.tensor(0, dtype=th.bool)])
def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_false_value):
assert not callback_false_value # Sanity check to ensure parametrized values are valid
env_id = select_env(model_class)
model = model_class("MlpPolicy", env_id, **model_kwargs, policy_kwargs=dict(net_arch=[2]))
alwaysfailcallback = AlwaysFailCallback(callback_false_value=callback_false_value)
model.learn(10, callback=alwaysfailcallback)

assert alwaysfailcallback.n_calls == 1


def test_eval_success_logging(tmp_path):
n_bits = 2
n_envs = 2
Expand Down

0 comments on commit 2ddf015

Please sign in to comment.