Skip to content

Commit

Permalink
Update env checker for spaces with non-zero start (#1845)
Browse files Browse the repository at this point in the history
* Update ruff

* Update env checker for non-zero start
  • Loading branch information
araffin authored Feb 19, 2024
1 parent 1cba1bb commit a8e9059
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero

Expand Down
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.3.0a2 (WIP)
Release 2.3.0a3 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -55,6 +55,8 @@ Deprecations:
Others:
^^^^^^^
- Updated black from v23 to v24
- Updated ruff to >= v0.2.2
- Updated env checker for (multi)discrete spaces with non-zero start.

Documentation:
^^^^^^^^^^^^^^
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
line-length = 127
# Assume Python 3.8
target-version = "py38"

[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
# B028: Ignore explicit stacklevel
# RUF013: Too many false positives (implicit optional)
ignore = ["B028", "RUF013"]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Default implementation in abstract methods
"./stable_baselines3/common/callbacks.py"= ["B027"]
"./stable_baselines3/common/noise.py"= ["B027"]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py"= ["RUF012", "RUF013"]


[tool.ruff.mccabe]
[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
# Type check
"mypy",
# Lint code and sort imports (flake8 and isort replacement)
"ruff>=0.0.288",
"ruff>=0.2.2",
# Reformat
"black>=24.2.0,<25",
],
Expand Down
44 changes: 28 additions & 16 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,37 @@ def _is_numpy_array_space(space: spaces.Space) -> bool:
return not isinstance(space, (spaces.Dict, spaces.Tuple))


def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool:
    """
    Tell whether every start index of a (Multi)Discrete space is zero.

    :param space: Discrete or MultiDiscrete space whose ``start`` attribute is inspected
    :return: True when all entries of ``space.start`` are (numerically) zero,
        False as soon as one of them is not.
    """
    start = space.start
    # Compare element-wise against zeros of the same shape/dtype as ``start``
    # (same tolerance-based comparison as ``np.allclose``).
    close_to_zero = np.isclose(start, np.zeros_like(start))
    return bool(np.all(close_to_zero))


def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
    """
    Emit a warning when a (Multi)Discrete space does not start at zero.

    Spaces of any other type, and (Multi)Discrete spaces starting at zero,
    are silently accepted.

    :param space: Observation or action space
    :param space_type: information about whether it is an observation or action space
        (for the warning message)
    :param key: When the observation space comes from a Dict space, we pass the
        corresponding key to have more precise warning messages. Defaults to "".
    """
    # Only (Multi)Discrete spaces have a start index to validate.
    if not isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)):
        return
    if _starts_at_zero(space):
        return

    # Mention the Dict key only when one was provided.
    key_info = f"(key='{key}')" if key else ""
    space_name = type(space).__name__
    message = (
        f"{space_name} {space_type} space {key_info} with a non-zero start (start={space.start}) "
        "is not supported by Stable-Baselines3. "
        f"You can use a wrapper or update your {space_type} space."
    )
    warnings.warn(message)


def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input will be compatible with Stable-Baselines
when the observation is apparently an image.
:param observation_space: Observation space
:key: When the observation space comes from a Dict space, we pass the
:param key: When the observation space comes from a Dict space, we pass the
corresponding key to have more precise warning messages. Defaults to "".
"""
if observation_space.dtype != np.uint8:
Expand Down Expand Up @@ -63,11 +87,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
for key, space in observation_space.spaces.items():
if isinstance(space, spaces.Dict):
nested_dict = True
if isinstance(space, spaces.Discrete) and space.start != 0:
warnings.warn(
f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. "
"You can use a wrapper or update your observation space."
)
_check_non_zero_start(space, "observation", key)

if nested_dict:
warnings.warn(
Expand All @@ -87,11 +107,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
"which is supported by SB3."
)

if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0:
warnings.warn(
"Discrete observation space with a non-zero start is not supported by Stable-Baselines3. "
"You can use a wrapper or update your observation space."
)
_check_non_zero_start(observation_space, "observation")

if isinstance(observation_space, spaces.Sequence):
warnings.warn(
Expand All @@ -100,11 +116,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
"Note: The checks for returned values are skipped."
)

if isinstance(action_space, spaces.Discrete) and action_space.start != 0:
warnings.warn(
"Discrete action space with a non-zero start is not supported by Stable-Baselines3. "
"You can use a wrapper or update your action space."
)
_check_non_zero_start(action_space, "action")

if not _is_numpy_array_space(action_space):
warnings.warn(
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.3.0a2
2.3.0a3
6 changes: 5 additions & 1 deletion tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def patched_step(_action):
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
# Non zero start index
spaces.Discrete(3, start=-1),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
# Non zero start index inside a Dict
spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
],
Expand Down Expand Up @@ -164,6 +166,8 @@ def patched_step(_action):
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
# Non zero start index
spaces.Discrete(3, start=-1),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
],
)
def test_non_default_action_spaces(new_action_space):
Expand All @@ -179,7 +183,7 @@ def test_non_default_action_spaces(new_action_space):
env.action_space = new_action_space

# Discrete or MultiDiscrete action space: check_env is expected to emit a warning
if isinstance(new_action_space, spaces.Discrete):
if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)):
with pytest.warns(UserWarning):
check_env(env)
return
Expand Down

0 comments on commit a8e9059

Please sign in to comment.