Skip to content

Commit

Permalink
Add support for setting options at reset with VecEnv (#1606)
Browse files Browse the repository at this point in the history
* Update signatures, and test with options

* Update changelog and black formatting

* Finish implementation (fixes, doc, tests)

* Use deepcopy to avoid side effects (modif by reference)

* Fix for mypy

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
ReHoss and araffin authored Oct 23, 2023
1 parent 2ddf015 commit aab5459
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 9 deletions.
4 changes: 2 additions & 2 deletions docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API:
if no mode is passed or ``mode="rgb_array"`` is passed when calling ``vec_env.render`` then we use the default mode, otherwise, we use the OpenCV display.
Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``).

- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator,
you should call ``vec_env.seed(seed=seed)`` and ``obs = vec_env.reset()`` afterward.
- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``).

- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``,
``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ New Features:
^^^^^^^^^^^^^
- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
- Added support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. As with seeds, options are discarded after each call to ``reset()`` (@ReHoss)

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -1465,6 +1466,6 @@ And all the contributors:
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
28 changes: 28 additions & 0 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union

import cloudpickle
Expand Down Expand Up @@ -67,6 +68,8 @@ def __init__(
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
# options to be used in the next call to env.reset()
self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)]

try:
render_modes = self.get_attr("render_mode")
Expand Down Expand Up @@ -95,6 +98,12 @@ def _reset_seeds(self) -> None:
"""
self._seeds = [None for _ in range(self.num_envs)]

def _reset_options(self) -> None:
    """
    Discard the options stored for the next reset.

    The stored list is replaced by one fresh, empty dict per environment.
    """
    env_count = self.num_envs
    self._options = [dict() for _ in range(env_count)]

@abstractmethod
def reset(self) -> VecEnvObs:
"""
Expand Down Expand Up @@ -283,6 +292,22 @@ def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds

def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
    """
    Set environment options for all environments.
    If a dict is passed instead of a list, the same options will be used for all environments.
    WARNING: Those options will only be passed to the environment at the next reset.

    :param options: A dictionary of environment options to pass to each environment at the next reset,
        or a list containing one such dictionary per environment. ``None`` is treated as an empty dict.
    """
    if options is None:
        options = {}
    # Deep copy to avoid side effects on the caller's objects.
    # The copy is made once per environment on purpose: ``deepcopy([options] * n)``
    # maps the single source dict to a single copy (deepcopy memo), so every slot
    # would alias the same dict and a mutation made for one env would show up in all.
    if isinstance(options, dict):
        self._options = [deepcopy(options) for _ in range(self.num_envs)]
    else:
        # Copy element by element so that a list built with ``[d] * n``
        # also ends up with independent dicts.
        self._options = [deepcopy(env_options) for env_options in options]

@property
def unwrapped(self) -> "VecEnv":
if isinstance(self, VecEnvWrapper):
Expand Down Expand Up @@ -354,6 +379,9 @@ def step_wait(self) -> VecEnvStepReturn:
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
    """
    Seed the wrapped vectorized environment.

    :param seed: value forwarded unchanged to ``self.venv.seed()``
    :return: whatever ``self.venv.seed()`` returns
    """
    wrapped_seeds = self.venv.seed(seed)
    return wrapped_seeds

def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
    """
    Set the reset options on the wrapped vectorized environment.

    :param options: value forwarded unchanged to ``self.venv.set_options()``
    """
    wrapped_env = self.venv
    return wrapped_env.set_options(options)

def close(self) -> None:
    """
    Close the wrapped vectorized environment by calling ``self.venv.close()``.
    """
    wrapped_env = self.venv
    return wrapped_env.close()

Expand Down
6 changes: 4 additions & 2 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ def step_wait(self) -> VecEnvStepReturn:

def reset(self) -> VecEnvObs:
    """
    Reset all the environments, using the pending seeds and options.

    Each environment is reset exactly once, with its own seed and, when some
    were set, its own options. The pending seeds and options are cleared
    afterwards, so they only apply to this call.

    :return: the value of ``self._obs_from_buf()`` once every observation was saved
    """
    for env_idx in range(self.num_envs):
        # An empty options dict means "no options": the keyword is then omitted entirely
        maybe_options = {"options": self._options[env_idx]} if self._options[env_idx] else {}
        obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
        self._save_obs(env_idx, obs)
    # Seeds and options are only used once
    self._reset_seeds()
    self._reset_options()
    return self._obs_from_buf()

def close(self) -> None:
Expand Down
8 changes: 5 additions & 3 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def _worker(
observation, reset_info = env.reset()
remote.send((observation, reward, done, info, reset_info))
elif cmd == "reset":
observation, reset_info = env.reset(seed=data)
maybe_options = {"options": data[1]} if data[1] else {}
observation, reset_info = env.reset(seed=data[0], **maybe_options)
remote.send((observation, reset_info))
elif cmd == "render":
remote.send(env.render())
Expand Down Expand Up @@ -132,11 +133,12 @@ def step_wait(self) -> VecEnvStepReturn:

def reset(self) -> VecEnvObs:
    """
    Reset all the environments, using the pending seeds and options.

    A single ``("reset", (seed, options))`` message is sent to every remote
    and exactly one reply is read back from each of them, so requests and
    replies stay paired. The pending seeds and options are cleared
    afterwards, so they only apply to this call.

    :return: the result of ``_flatten_obs()`` applied to the received observations
    """
    for env_idx, remote in enumerate(self.remotes):
        # Seed and options travel together in one tuple payload
        remote.send(("reset", (self._seeds[env_idx], self._options[env_idx])))
    # Each reply is an (observation, reset_info) pair
    results = [remote.recv() for remote in self.remotes]
    obs, self.reset_infos = zip(*results)
    # Seeds and options are only used once
    self._reset_seeds()
    self._reset_options()
    return _flatten_obs(obs, self.observation_space)

def close(self) -> None:
Expand Down
30 changes: 29 additions & 1 deletion tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ def __init__(self, space, render_mode: str = "rgb_array"):
self.current_step = 0
self.ep_length = 4
self.render_mode = render_mode
self.current_options: Optional[Dict] = None

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
    """
    Start a new episode.

    :param seed: when not None, passed to ``self.seed()`` before anything else
    :param options: stored as-is in ``self.current_options``
    :return: a ``(state, info)`` tuple where ``info`` is an empty dict
    """
    if seed is not None:
        self.seed(seed)
    # Keep the received options around so they can be inspected after the reset
    self.current_options = options
    self.current_step = 0
    self._choose_next_state()
    initial_state, reset_info = self.state, {}
    return initial_state, reset_info

Expand Down Expand Up @@ -160,6 +162,25 @@ def make_env():
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
assert vec_env.get_attr("current_step", indices=[-1]) == [12]

# Checks that options are correctly passed
assert vec_env.get_attr("current_options")[0] is None
# Same options for all envs
options = {"hello": 1}
vec_env.set_options(options)
assert vec_env.get_attr("current_options")[0] is None
# Only effective at reset
vec_env.reset()
assert vec_env.get_attr("current_options") == [options] * N_ENVS
vec_env.reset()
# Options are reset
assert vec_env.get_attr("current_options")[0] is None
# Use a list of options, different for the first env
options = [{"hello": 1}] * N_ENVS
options[0] = {"other_option": 2}
vec_env.set_options(options)
vec_env.reset()
assert vec_env.get_attr("current_options") == options

vec_env.close()


Expand Down Expand Up @@ -487,7 +508,14 @@ def make_env():
vec_env.seed(3)
new_obs = vec_env.reset()
assert np.allclose(new_obs, obs)
vec_env.close()
# Test with VecNormalize (VecEnvWrapper should call self.venv.seed())
vec_normalize = VecNormalize(vec_env)
vec_normalize.seed(3)
obs = vec_env.reset()
vec_normalize.seed(3)
new_obs = vec_env.reset()
assert np.allclose(new_obs, obs)
vec_normalize.close()
# Similar test but with make_vec_env
vec_env_1 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0)
vec_env_2 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0)
Expand Down

0 comments on commit aab5459

Please sign in to comment.