Fix VecEnv type hints (#1736)
* Fix VecNormalize type hints

* Fix VecEnv utils type annotations

* Apply suggestions from code review

Co-authored-by: M. Ernestus <[email protected]>

* Remove PyType

---------

Co-authored-by: M. Ernestus <[email protected]>
araffin and ernestum authored Nov 8, 2023
1 parent d671402 commit b413f4c
Showing 15 changed files with 63 additions and 67 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
@@ -55,8 +55,6 @@ jobs:
- name: Type check
run: |
make type
# skip PyType, doesn't support 3.11 yet
if: "!(matrix.python-version == '3.11')"
- name: Test with pytest
run: |
make pytest
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -79,7 +79,7 @@ To run tests with `pytest`:
make pytest
```

Type checking with `pytype` and `mypy`:
Type checking with `mypy`:

```
make type
5 changes: 1 addition & 4 deletions Makefile
@@ -4,9 +4,6 @@ LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
pytest:
./scripts/run_tests.sh

pytype:
pytype -j auto

mypy:
mypy ${LINT_PATHS}

@@ -16,7 +13,7 @@ missing-annotations:
# missing docstrings
# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4

type: pytype mypy
type: mypy

lint:
# stop the build if there are Python syntax errors or undefined names
3 changes: 3 additions & 0 deletions docs/misc/changelog.rst
@@ -61,8 +61,11 @@ Others:
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
- Fixed ``stable_baselines3/common/distributions.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_normalize.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hints
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed ``stable_baselines3/common/policies.py`` type hints
- Switched to ``mypy`` only for checking types

Documentation:
^^^^^^^^^^^^^^
21 changes: 1 addition & 20 deletions pyproject.toml
@@ -24,31 +24,12 @@ max-complexity = 15
[tool.black]
line-length = 127

[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]
# Checked with mypy
exclude = [
"stable_baselines3/common/buffers.py",
"stable_baselines3/common/base_class.py",
"stable_baselines3/common/callbacks.py",
"stable_baselines3/common/on_policy_algorithm.py",
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/off_policy_algorithm.py",
"stable_baselines3/common/distributions.py",
"stable_baselines3/common/policies.py",
]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| tests/test_logger.py$
tests/test_logger.py$
| tests/test_train_eval_mode.py$
)"""

1 change: 0 additions & 1 deletion setup.py
@@ -118,7 +118,6 @@
"pytest-env",
"pytest-xdist",
# Type check
"pytype",
"mypy",
# Lint code and sort imports (flake8 and isort replacement)
"ruff>=0.0.288",
2 changes: 1 addition & 1 deletion stable_baselines3/common/atari_wrappers.py
@@ -7,7 +7,7 @@
from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn

try:
import cv2 # pytype:disable=import-error
import cv2

cv2.ocl.setUseOpenCL(False)
except ImportError:
4 changes: 1 addition & 3 deletions stable_baselines3/common/monitor.py
@@ -193,9 +193,7 @@ def __init__(
mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692
self.file_handler = open(filename, f"{mode}t", newline="\n")
self.logger = csv.DictWriter(
self.file_handler, fieldnames=("r", "l", "t", *extra_keys)
) # pytype: disable=wrong-arg-types
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
if override_existing:
self.file_handler.write(f"#{json.dumps(header)}\n")
self.logger.writeheader()
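For context, a self-contained sketch (not part of this commit) of what the `DictWriter` above writes; in `Monitor` the `r`, `l` and `t` fields hold the episode reward, episode length and elapsed time. The in-memory buffer stands in for the monitor CSV file handle:

```python
import csv
import io

file_handler = io.StringIO()  # stand-in for the monitor CSV file
logger = csv.DictWriter(file_handler, fieldnames=("r", "l", "t"))
logger.writeheader()
logger.writerow({"r": 200.0, "l": 200, "t": 3.5})
print(file_handler.getvalue())
# r,l,t
# 200.0,200,3.5
```
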
6 changes: 3 additions & 3 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -193,14 +193,14 @@ def _setup_model(self) -> None:
device=self.device,
n_envs=self.n_envs,
optimize_memory_usage=self.optimize_memory_usage,
**replay_buffer_kwargs, # pytype:disable=wrong-keyword-args
**replay_buffer_kwargs,
)

self.policy = self.policy_class( # pytype:disable=not-instantiable
self.policy = self.policy_class(
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs, # pytype:disable=not-instantiable
**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)

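A minimal sketch of the call shape whose suppressions were removed above: a class stored in an annotated variable is instantiated with keyword arguments, which mypy accepts without extra directives. The `Buffer` class here is a simplified stand-in, not the real SB3 replay buffer.

```python
from typing import Any, Dict, Type


class Buffer:
    def __init__(self, buffer_size: int, n_envs: int = 1) -> None:
        self.buffer_size = buffer_size
        self.n_envs = n_envs


# Same shape as the replay_buffer_class(..., **replay_buffer_kwargs) call above
buffer_class: Type[Buffer] = Buffer
buffer_kwargs: Dict[str, Any] = {"n_envs": 4}
replay_buffer = buffer_class(buffer_size=1_000, **buffer_kwargs)
print(replay_buffer.buffer_size, replay_buffer.n_envs)
```
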
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
@@ -176,7 +176,7 @@ def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "a
saved_variables = th.load(path, map_location=device)

# Create policy object
model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
model = cls(**saved_variables["data"])
# Load weights
model.load_state_dict(saved_variables["state_dict"])
model.to(device)
2 changes: 1 addition & 1 deletion stable_baselines3/common/utils.py
@@ -536,7 +536,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
"Gymnasium": gym.__version__,
}
try:
import gym as openai_gym # pytype: disable=import-error
import gym as openai_gym

env_info.update({"OpenAI Gym": openai_gym.__version__})
except ImportError:
57 changes: 34 additions & 23 deletions stable_baselines3/common/vec_env/__init__.py
@@ -1,6 +1,5 @@
import typing
from copy import deepcopy
from typing import Optional, Type, Union
from typing import Optional, Type, TypeVar

from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
@@ -14,18 +13,16 @@
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder

# Avoid circular import
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymEnv
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)


def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
:param env:
:param vec_wrapper_class:
:return:
:param env: The ``VecEnv`` that is going to be unwrapped
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: The ``VecEnvWrapper`` object if the ``VecEnv`` is wrapped with the desired wrapper, None otherwise
"""
env_tmp = env
while isinstance(env_tmp, VecEnvWrapper):
@@ -35,36 +32,50 @@ def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[Vec
return None


def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
"""
:param env:
:return:
Retrieve a ``VecNormalize`` object by recursively searching.
:param env: The VecEnv that is going to be unwrapped
:return: The ``VecNormalize`` object if the ``VecEnv`` is wrapped with ``VecNormalize``, None otherwise
"""
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
return unwrap_vec_wrapper(env, VecNormalize)


def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
"""
Check if an environment is already wrapped by a given ``VecEnvWrapper``.
Check if an environment is already wrapped in a given ``VecEnvWrapper``.
:param env:
:param vec_wrapper_class:
:return:
:param env: The VecEnv that is going to be checked
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise
"""
return unwrap_vec_wrapper(env, vec_wrapper_class) is not None


# Define here to avoid circular import
def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
"""
Sync eval env and train env when using VecNormalize
Synchronize the normalization statistics of an eval environment and train environment
when they are both wrapped in a ``VecNormalize`` wrapper.
:param env:
:param eval_env:
:param env: Training env
:param eval_env: Environment used for evaluation.
"""
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
assert isinstance(eval_env_tmp, VecEnvWrapper), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecEnvWrapper but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
if isinstance(env_tmp, VecNormalize):
assert isinstance(eval_env_tmp, VecNormalize), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecNormalize but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
# Only synchronize if observation normalization exists
if hasattr(env_tmp, "obs_rms"):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
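A minimal usage sketch of the new signatures (not part of this commit): because `unwrap_vec_wrapper` is now generic over `VecEnvWrapperT`, it returns the wrapper type that was requested, and `unwrap_vec_normalize` no longer needs a suppression comment. The `CartPole-v1` environment is only illustrative.

```python
import gymnasium as gym

from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    VecNormalize,
    sync_envs_normalization,
    unwrap_vec_normalize,
    unwrap_vec_wrapper,
)

# A VecNormalize-wrapped vectorized env (CartPole-v1 is just an example)
venv = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
eval_venv = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]), training=False)

# mypy now infers Optional[VecNormalize] here, no cast or ignore needed
normalize_wrapper = unwrap_vec_wrapper(venv, VecNormalize)
assert normalize_wrapper is not None
assert unwrap_vec_normalize(venv) is normalize_wrapper

# Copy the running normalization statistics from the training env to the eval env
sync_envs_normalization(venv, eval_venv)
```
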
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/base_vec_env.py
@@ -258,7 +258,7 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:

if mode == "human":
# Display it using OpenCV
import cv2 # pytype:disable=import-error
import cv2

cv2.imshow("vecenv", bigimg[:, :, ::-1])
cv2.waitKey(1)
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/vec_frame_stack.py
@@ -38,6 +38,6 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
"""
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
observation = self.venv.reset()
observation = self.stacked_obs.reset(observation) # type: ignore[arg-type]
return observation
19 changes: 14 additions & 5 deletions stable_baselines3/common/vec_env/vec_normalize.py
@@ -29,6 +29,9 @@ class VecNormalize(VecEnvWrapper):
If not specified, all keys will be normalized.
"""

obs_spaces: Dict[str, spaces.Space]
old_obs: Union[np.ndarray, Dict[str, np.ndarray]]

def __init__(
self,
venv: VecEnv,
@@ -47,11 +50,12 @@ def __init__(
self.norm_obs_keys = norm_obs_keys
# Check observation spaces
if self.norm_obs:
# Note: mypy doesn't take into account the sanity checks, which leads to several type: ignore...
self._sanity_checks()

if isinstance(self.observation_space, spaces.Dict):
self.obs_spaces = self.observation_space.spaces
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} # type: ignore[arg-type, union-attr]
# Update observation space when using image
# See explanation below and GH #1214
for key in self.obs_rms.keys():
@@ -64,8 +68,7 @@ def __init__(
)

else:
self.obs_spaces = None
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) # type: ignore[assignment, arg-type]
# Update observation space when using image
# See GH #1214
# This is to raise proper error when
@@ -92,7 +95,6 @@ def __init__(
self.training = training
self.norm_obs = norm_obs
self.norm_reward = norm_reward
self.old_obs = np.array([])
self.old_reward = np.array([])

def _sanity_checks(self) -> None:
@@ -148,7 +150,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
self.__dict__.update(state)
assert "venv" not in state
self.venv = None
self.venv = None # type: ignore[assignment]

def set_venv(self, venv: VecEnv) -> None:
"""
@@ -177,6 +179,7 @@ def step_wait(self) -> VecEnvStepReturn:
where ``dones`` is a boolean vector indicating whether each element is new.
"""
obs, rewards, dones, infos = self.venv.step_wait()
assert isinstance(obs, (np.ndarray, dict)) # for mypy
self.old_obs = obs
self.old_reward = rewards

@@ -235,10 +238,12 @@ def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
# Only normalize the specified keys
for key in self.norm_obs_keys:
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
return obs_

@@ -256,9 +261,11 @@ def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Unio
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
for key in self.norm_obs_keys:
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._unnormalize_obs(obs, self.obs_rms)
return obs_

@@ -286,13 +293,15 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
:return: first observation of the episode
"""
obs = self.venv.reset()
assert isinstance(obs, (np.ndarray, dict))
self.old_obs = obs
self.returns = np.zeros(self.num_envs)
if self.training and self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
for key in self.obs_rms.keys():
self.obs_rms[key].update(obs[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
self.obs_rms.update(obs)
return self.normalize_obs(obs)

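The replacement for the `pytype` suppressions above is explicit class-level attribute annotations (`obs_spaces`, `old_obs`) plus `assert isinstance(...)` checks that narrow the `Union` types for mypy. A self-contained sketch of that narrowing idiom, assuming `RunningMeanStd` lives in `stable_baselines3.common.running_mean_std` (the helper used by `VecNormalize`):

```python
from typing import Dict, Union

import numpy as np

from stable_baselines3.common.running_mean_std import RunningMeanStd


def update_stats(obs_rms: Union[RunningMeanStd, Dict[str, RunningMeanStd]], obs: np.ndarray) -> None:
    # The assert narrows the union, mirroring the asserts added in
    # normalize_obs/unnormalize_obs/reset above, so mypy accepts .update()
    assert isinstance(obs_rms, RunningMeanStd)
    obs_rms.update(obs)


update_stats(RunningMeanStd(shape=(4,)), np.random.randn(8, 4))
```
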
