diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1eba5f79a..b1078cd28 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,8 +55,6 @@ jobs: - name: Type check run: | make type - # skip PyType, doesn't support 3.11 yet - if: "!(matrix.python-version == '3.11')" - name: Test with pytest run: | make pytest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 80366b2d2..d295269a9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -79,7 +79,7 @@ To run tests with `pytest`: make pytest ``` -Type checking with `pytype` and `mypy`: +Type checking with `mypy`: ``` make type diff --git a/Makefile b/Makefile index cb90f3170..fe9f6ae2e 100644 --- a/Makefile +++ b/Makefile @@ -4,9 +4,6 @@ LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py pytest: ./scripts/run_tests.sh -pytype: - pytype -j auto - mypy: mypy ${LINT_PATHS} @@ -16,7 +13,7 @@ missing-annotations: # missing docstrings # pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4 -type: pytype mypy +type: mypy lint: # stop the build if there are Python syntax errors or undefined names diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3b09fadb5..24175bcd1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -61,8 +61,11 @@ Others: - Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``) - Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints - Fixed ``stable_baselines3/common/distributions.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/vec_normalize.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/__init__.py`` type hints - Switched to PyTorch 2.1.0 in the CI (fixes type annotations) - Fixed ``stable_baselines3/common/policies.py`` type hints +- Switched to ``mypy`` only for checking types Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index b15e5156b..1195687f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,31 +24,12 @@ max-complexity = 15 [tool.black] line-length = 127 -[tool.pytype] -inputs = ["stable_baselines3"] -disable = ["pyi-error"] -# Checked with mypy -exclude = [ - "stable_baselines3/common/buffers.py", - "stable_baselines3/common/base_class.py", - "stable_baselines3/common/callbacks.py", - "stable_baselines3/common/on_policy_algorithm.py", - "stable_baselines3/common/vec_env/stacked_observations.py", - "stable_baselines3/common/vec_env/subproc_vec_env.py", - "stable_baselines3/common/vec_env/patch_gym.py", - "stable_baselines3/common/off_policy_algorithm.py", - "stable_baselines3/common/distributions.py", - "stable_baselines3/common/policies.py", -] - [tool.mypy] ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( - stable_baselines3/common/vec_env/__init__.py$ - | stable_baselines3/common/vec_env/vec_normalize.py$ - | tests/test_logger.py$ + tests/test_logger.py$ | tests/test_train_eval_mode.py$ )""" diff --git a/setup.py b/setup.py index 0a1ab3072..5e10ed66c 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,6 @@ "pytest-env", "pytest-xdist", # Type check - "pytype", "mypy", # Lint code and sort imports (flake8 and isort replacement) "ruff>=0.0.288", diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 4e0e47722..bbdba9a3d 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -7,7 +7,7 @@ from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn try: - import cv2 # pytype:disable=import-error + import cv2 cv2.ocl.setUseOpenCL(False) except ImportError: diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index f421a4df2..5253954e8 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -193,9 +193,7 @@ def __init__( mode = "w" if override_existing else "a" # Prevent newline issue on Windows, see GH issue #692 self.file_handler = open(filename, f"{mode}t", newline="\n") - self.logger = csv.DictWriter( - self.file_handler, fieldnames=("r", "l", "t", *extra_keys) - ) # pytype: disable=wrong-arg-types + self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys)) if override_existing: self.file_handler.write(f"#{json.dumps(header)}\n") self.logger.writeheader() diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index a24a4dc74..c460d0236 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -193,14 +193,14 @@ def _setup_model(self) -> None: device=self.device, n_envs=self.n_envs, optimize_memory_usage=self.optimize_memory_usage, - **replay_buffer_kwargs, # pytype:disable=wrong-keyword-args + **replay_buffer_kwargs, ) - self.policy = self.policy_class( # pytype:disable=not-instantiable + self.policy = self.policy_class( self.observation_space, self.action_space, self.lr_schedule, - **self.policy_kwargs, # pytype:disable=not-instantiable + **self.policy_kwargs, ) self.policy = self.policy.to(self.device) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 0d810effc..50be01c9e 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -176,7 +176,7 @@ def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "a saved_variables = th.load(path, map_location=device) # Create policy object - model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable + model = cls(**saved_variables["data"]) # Load weights model.load_state_dict(saved_variables["state_dict"]) model.to(device) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 4950822d4..3ff193786 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -536,7 +536,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: "Gymnasium": gym.__version__, } try: - import gym as openai_gym # pytype: disable=import-error + import gym as openai_gym env_info.update({"OpenAI Gym": openai_gym.__version__}) except ImportError: diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 2c036373a..5f73d3978 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,6 +1,5 @@ -import typing from copy import deepcopy -from typing import Optional, Type, Union +from typing import Optional, Type, TypeVar from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv @@ -14,18 +13,16 @@ from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder -# Avoid circular import -if typing.TYPE_CHECKING: - from stable_baselines3.common.type_aliases import GymEnv +VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper) -def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: +def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]: """ Retrieve a ``VecEnvWrapper`` object by recursively searching. - :param env: - :param vec_wrapper_class: - :return: + :param env: The ``VecEnv`` that is going to be unwrapped + :param vec_wrapper_class: The desired ``VecEnvWrapper`` class. + :return: The ``VecEnvWrapper`` object if the ``VecEnv`` is wrapped with the desired wrapper, None otherwise """ env_tmp = env while isinstance(env_tmp, VecEnvWrapper): @@ -35,36 +32,50 @@ def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[Vec return None -def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: +def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]: """ - :param env: - :return: + Retrieve a ``VecNormalize`` object by recursively searching. + + :param env: The VecEnv that is going to be unwrapped + :return: The ``VecNormalize`` object if the ``VecEnv`` is wrapped with ``VecNormalize``, None otherwise """ - return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type + return unwrap_vec_wrapper(env, VecNormalize) -def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: +def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool: """ - Check if an environment is already wrapped by a given ``VecEnvWrapper``. + Check if an environment is already wrapped in a given ``VecEnvWrapper``. - :param env: - :param vec_wrapper_class: - :return: + :param env: The VecEnv that is going to be checked + :param vec_wrapper_class: The desired ``VecEnvWrapper`` class. + :return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise """ return unwrap_vec_wrapper(env, vec_wrapper_class) is not None -# Define here to avoid circular import -def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: +def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None: """ - Sync eval env and train env when using VecNormalize + Synchronize the normalization statistics of an eval environment and train environment + when they are both wrapped in a ``VecNormalize`` wrapper. - :param env: - :param eval_env: + :param env: Training env + :param eval_env: Environment used for evaluation. """ env_tmp, eval_env_tmp = env, eval_env while isinstance(env_tmp, VecEnvWrapper): + assert isinstance(eval_env_tmp, VecEnvWrapper), ( + "Error while synchronizing normalization stats: expected the eval env to be " + f"a VecEnvWrapper but got {eval_env_tmp} instead. " + "This is probably due to the training env not being wrapped the same way as the evaluation env. " + f"Training env type: {env_tmp}." + ) if isinstance(env_tmp, VecNormalize): + assert isinstance(eval_env_tmp, VecNormalize), ( + "Error while synchronizing normalization stats: expected the eval env to be " + f"a VecNormalize but got {eval_env_tmp} instead. " + "This is probably due to the training env not being wrapped the same way as the evaluation env. " + f"Training env type: {env_tmp}." + ) # Only synchronize if observation normalization exists if hasattr(env_tmp, "obs_rms"): eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 5e06a5c0c..8e0c8cc69 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -258,7 +258,7 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: if mode == "human": # Display it using OpenCV - import cv2 # pytype:disable=import-error + import cv2 cv2.imshow("vecenv", bigimg[:, :, ::-1]) cv2.waitKey(1) diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 239666464..d412a96a2 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -38,6 +38,6 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments """ - observation = self.venv.reset() # pytype:disable=annotation-type-mismatch + observation = self.venv.reset() observation = self.stacked_obs.reset(observation) # type: ignore[arg-type] return observation diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 27c3d433a..391ce342d 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -29,6 +29,9 @@ class VecNormalize(VecEnvWrapper): If not specified, all keys will be normalized. """ + obs_spaces: Dict[str, spaces.Space] + old_obs: Union[np.ndarray, Dict[str, np.ndarray]] + def __init__( self, venv: VecEnv, @@ -47,11 +50,12 @@ def __init__( self.norm_obs_keys = norm_obs_keys # Check observation spaces if self.norm_obs: + # Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore... self._sanity_checks() if isinstance(self.observation_space, spaces.Dict): self.obs_spaces = self.observation_space.spaces - self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} + self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} # type: ignore[arg-type, union-attr] # Update observation space when using image # See explanation below and GH #1214 for key in self.obs_rms.keys(): @@ -64,8 +68,7 @@ def __init__( ) else: - self.obs_spaces = None - self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) + self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) # type: ignore[assignment, arg-type] # Update observation space when using image # See GH #1214 # This is to raise proper error when @@ -92,7 +95,6 @@ def __init__( self.training = training self.norm_obs = norm_obs self.norm_reward = norm_reward - self.old_obs = np.array([]) self.old_reward = np.array([]) def _sanity_checks(self) -> None: @@ -148,7 +150,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: state["norm_obs_keys"] = list(state["observation_space"].spaces.keys()) self.__dict__.update(state) assert "venv" not in state - self.venv = None + self.venv = None # type: ignore[assignment] def set_venv(self, venv: VecEnv) -> None: """ @@ -177,6 +179,7 @@ def step_wait(self) -> VecEnvStepReturn: where ``dones`` is a boolean vector indicating whether each element is new. """ obs, rewards, dones, infos = self.venv.step_wait() + assert isinstance(obs, (np.ndarray, dict)) # for mypy self.old_obs = obs self.old_reward = rewards @@ -235,10 +238,12 @@ def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[ obs_ = deepcopy(obs) if self.norm_obs: if isinstance(obs, dict) and isinstance(self.obs_rms, dict): + assert self.norm_obs_keys is not None # Only normalize the specified keys for key in self.norm_obs_keys: obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) else: + assert isinstance(self.obs_rms, RunningMeanStd) obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32) return obs_ @@ -256,9 +261,11 @@ def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Unio obs_ = deepcopy(obs) if self.norm_obs: if isinstance(obs, dict) and isinstance(self.obs_rms, dict): + assert self.norm_obs_keys is not None for key in self.norm_obs_keys: obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) else: + assert isinstance(self.obs_rms, RunningMeanStd) obs_ = self._unnormalize_obs(obs, self.obs_rms) return obs_ @@ -286,6 +293,7 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: :return: first observation of the episode """ obs = self.venv.reset() + assert isinstance(obs, (np.ndarray, dict)) self.old_obs = obs self.returns = np.zeros(self.num_envs) if self.training and self.norm_obs: @@ -293,6 +301,7 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: for key in self.obs_rms.keys(): self.obs_rms[key].update(obs[key]) else: + assert isinstance(self.obs_rms, RunningMeanStd) self.obs_rms.update(obs) return self.normalize_obs(obs)