diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index ea99444d1..f3af49928 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -89,8 +89,8 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: if no mode is passed or ``mode="rgb_array"`` is passed when calling ``vec_env.render`` then we use the default mode, otherwise, we use the OpenCV display. Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). -- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator, - you should call ``vec_env.seed(seed=seed)`` and ``obs = vec_env.reset()`` afterward. +- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options, + you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discared after each call to ``reset()``). - methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``, ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fada24827..616c85904 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -15,6 +15,7 @@ New Features: ^^^^^^^^^^^^^ - Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined) - Improved error message when mixing Gym API with VecEnv API (see GH#1694) +- Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss) Bug Fixes: ^^^^^^^^^^ @@ -1465,6 +1466,6 @@ And all the contributors: @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 -@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong +@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 16518a102..5e06a5c0c 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,6 +1,7 @@ import inspect import warnings from abc import ABC, abstractmethod +from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import cloudpickle @@ -67,6 +68,8 @@ def __init__( self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)] # seeds to be used in the next call to env.reset() self._seeds: List[Optional[int]] = [None for _ in range(num_envs)] + # options to be used in the next call to env.reset() + self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)] try: render_modes = self.get_attr("render_mode") @@ -95,6 +98,12 @@ def _reset_seeds(self) -> None: """ self._seeds = [None for _ in range(self.num_envs)] + def _reset_options(self) -> None: + """ + Reset the options that are going to be used at the next reset. + """ + self._options = [{} for _ in range(self.num_envs)] + @abstractmethod def reset(self) -> VecEnvObs: """ @@ -283,6 +292,22 @@ def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: self._seeds = [seed + idx for idx in range(self.num_envs)] return self._seeds + def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None: + """ + Set environment options for all environments. + If a dict is passed instead of a list, the same options will be used for all environments. + WARNING: Those options will only be passed to the environment at the next reset. + + :param options: A dictionary of environment options to pass to each environment at the next reset. + """ + if options is None: + options = {} + # Use deepcopy to avoid side effects + if isinstance(options, dict): + self._options = deepcopy([options] * self.num_envs) + else: + self._options = deepcopy(options) + @property def unwrapped(self) -> "VecEnv": if isinstance(self, VecEnvWrapper): @@ -354,6 +379,9 @@ def step_wait(self) -> VecEnvStepReturn: def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: return self.venv.seed(seed) + def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None: + return self.venv.set_options(options) + def close(self) -> None: return self.venv.close() diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 290898167..15ecfb681 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -73,10 +73,12 @@ def step_wait(self) -> VecEnvStepReturn: def reset(self) -> VecEnvObs: for env_idx in range(self.num_envs): - obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx]) + maybe_options = {"options": self._options[env_idx]} if self._options[env_idx] else {} + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options) self._save_obs(env_idx, obs) - # Seeds are only used once + # Seeds and options are only used once self._reset_seeds() + self._reset_options() return self._obs_from_buf() def close(self) -> None: diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index dbc7002f0..83758841b 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -42,7 +42,8 @@ def _worker( observation, reset_info = env.reset() remote.send((observation, reward, done, info, reset_info)) elif cmd == "reset": - observation, reset_info = env.reset(seed=data) + maybe_options = {"options": data[1]} if data[1] else {} + observation, reset_info = env.reset(seed=data[0], **maybe_options) remote.send((observation, reset_info)) elif cmd == "render": remote.send(env.render()) @@ -132,11 +133,12 @@ def step_wait(self) -> VecEnvStepReturn: def reset(self) -> VecEnvObs: for env_idx, remote in enumerate(self.remotes): - remote.send(("reset", self._seeds[env_idx])) + remote.send(("reset", (self._seeds[env_idx], self._options[env_idx]))) results = [remote.recv() for remote in self.remotes] obs, self.reset_infos = zip(*results) - # Seeds are only used once + # Seeds and options are only used once self._reset_seeds() + self._reset_options() return _flatten_obs(obs, self.observation_space) def close(self) -> None: diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 61740c41d..a9516ae25 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -30,11 +30,13 @@ def __init__(self, space, render_mode: str = "rgb_array"): self.current_step = 0 self.ep_length = 4 self.render_mode = render_mode + self.current_options: Optional[Dict] = None def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): if seed is not None: self.seed(seed) self.current_step = 0 + self.current_options = options self._choose_next_state() return self.state, {} @@ -160,6 +162,25 @@ def make_env(): assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12] assert vec_env.get_attr("current_step", indices=[-1]) == [12] + # Checks that options are correctly passed + assert vec_env.get_attr("current_options")[0] is None + # Same options for all envs + options = {"hello": 1} + vec_env.set_options(options) + assert vec_env.get_attr("current_options")[0] is None + # Only effective at reset + vec_env.reset() + assert vec_env.get_attr("current_options") == [options] * N_ENVS + vec_env.reset() + # Options are reset + assert vec_env.get_attr("current_options")[0] is None + # Use a list of options, different for the first env + options = [{"hello": 1}] * N_ENVS + options[0] = {"other_option": 2} + vec_env.set_options(options) + vec_env.reset() + assert vec_env.get_attr("current_options") == options + vec_env.close() @@ -487,7 +508,14 @@ def make_env(): vec_env.seed(3) new_obs = vec_env.reset() assert np.allclose(new_obs, obs) - vec_env.close() + # Test with VecNormalize (VecEnvWrapper should call self.venv.seed()) + vec_normalize = VecNormalize(vec_env) + vec_normalize.seed(3) + obs = vec_env.reset() + vec_normalize.seed(3) + new_obs = vec_env.reset() + assert np.allclose(new_obs, obs) + vec_normalize.close() # Similar test but with make_vec_env vec_env_1 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0) vec_env_2 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0)