Add test and update doc
araffin committed Oct 27, 2023
1 parent cf8ad5d commit 4961b53
Showing 3 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/guide/vec_envs.rst
@@ -90,7 +90,7 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API:
   Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``).

 - the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
-  you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discared after each call to ``reset()``).
+  you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``).

 - methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``,
   ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.
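
As a rough illustration of the one-shot seed/options behaviour described in the ``reset()`` note, here is a toy model in plain Python. ``ToyVecEnv`` and its ``start`` option are hypothetical stand-ins, not SB3 classes; only the method names ``seed()``, ``set_options()`` and ``reset()`` mirror the documented API.

```python
import random
from typing import Any, Dict, Optional


class ToyVecEnv:
    """Hypothetical stand-in (not SB3's VecEnv) modelling one-shot seed/options."""

    def __init__(self) -> None:
        self._seed: Optional[int] = None
        self._options: Dict[str, Any] = {}
        self._rng = random.Random()

    def seed(self, seed: Optional[int] = None) -> None:
        # Only stores the seed; it takes effect at the next reset().
        self._seed = seed

    def set_options(self, options: Optional[Dict[str, Any]] = None) -> None:
        # Only stores the options; they take effect at the next reset().
        self._options = dict(options or {})

    def reset(self) -> float:
        if self._seed is not None:
            self._rng = random.Random(self._seed)
        # "start" is an invented option used purely for illustration.
        start = self._options.get("start", 0.0)
        obs = start + self._rng.random()
        # Seed and options are discarded after each call to reset().
        self._seed = None
        self._options = {}
        return obs


vec_env = ToyVecEnv()
vec_env.seed(seed=0)
vec_env.set_options({"start": 10.0})
first = vec_env.reset()   # uses seed=0 and start=10.0
second = vec_env.reset()  # seed and options were discarded: default start

vec_env.seed(seed=0)
vec_env.set_options({"start": 10.0})
third = vec_env.reset()   # same seed and options again, same observation
```

In this sketch ``first`` and ``third`` match because the seed and options are set again before the reset, while ``second`` falls back to the defaults.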
9 changes: 5 additions & 4 deletions docs/misc/changelog.rst
@@ -5,6 +5,7 @@ Changelog

 Release 2.2.0a9 (WIP)
 --------------------------
+**Support for options at reset, bug fixes and better error messages**

 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
@@ -16,7 +17,7 @@ New Features:
 - Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
 - Improved error message when mixing Gym API with VecEnv API (see GH#1694)
 - Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
-- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to off-policy algorithms
+- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO)


 Bug Fixes:
@@ -38,9 +39,9 @@ Bug Fixes:
 `RL Zoo`_
 ^^^^^^^^^

-`SBX`_
-^^^^^^^^^
-- Added ``DDPG`` and ``TD3``
+`SBX`_ (SB3 + Jax)
+^^^^^^^^^^^^^^^^^^
+- Added ``DDPG`` and ``TD3`` algorithms

 Deprecations:
 ^^^^^^^^^^^^^
14 changes: 14 additions & 0 deletions tests/test_buffers.py
@@ -4,6 +4,7 @@
 import torch as th
 from gymnasium import spaces

+from stable_baselines3 import A2C
 from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.env_util import make_vec_env
@@ -150,3 +151,16 @@ def test_device_buffer(replay_buffer_cls, device):
                 assert value[key].device.type == desired_device
         elif isinstance(value, th.Tensor):
             assert value.device.type == desired_device
+
+
+def test_custom_rollout_buffer():
+    A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict())
+
+    with pytest.raises(TypeError, match="unexpected keyword argument 'wrong_keyword'"):
+        A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(wrong_keyword=1))
+
+    with pytest.raises(TypeError, match="got multiple values for keyword argument 'gamma'"):
+        A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(gamma=1))
+
+    with pytest.raises(AssertionError, match="DictRolloutBuffer must be used with Dict obs space only"):
+        A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=DictRolloutBuffer)
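
The two ``TypeError`` cases in the test above are ordinary Python keyword-argument errors. A minimal sketch in plain Python, assuming the algorithm builds the buffer by passing ``gamma`` itself and then unpacking the user-supplied kwargs (``ToyBuffer`` and ``build_buffer`` are hypothetical names, not SB3 code):

```python
from typing import Any, Dict


class ToyBuffer:
    """Hypothetical buffer standing in for a rollout buffer class."""

    def __init__(self, buffer_size: int, gamma: float = 0.99) -> None:
        self.buffer_size = buffer_size
        self.gamma = gamma


def build_buffer(buffer_class: type, buffer_kwargs: Dict[str, Any]) -> Any:
    # Assumed pattern: the caller fixes gamma, then forwards the user kwargs.
    return buffer_class(5, gamma=0.9, **buffer_kwargs)


def error_message(buffer_kwargs: Dict[str, Any]) -> str:
    """Return the TypeError message raised for these kwargs, or '' if none."""
    try:
        build_buffer(ToyBuffer, buffer_kwargs)
    except TypeError as exc:
        return str(exc)
    return ""


ok_message = error_message({})                         # empty kwargs: no error
unknown_message = error_message({"wrong_keyword": 1})  # name the buffer does not accept
duplicate_message = error_message({"gamma": 1})        # clashes with the caller's gamma
print(unknown_message)
print(duplicate_message)
```

Under that assumption, an unknown key fails inside the buffer constructor, while a key such as ``gamma`` fails at the call site because it is supplied twice.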
