diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 634e712c1..1eba5f79a 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -32,7 +32,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           # cpu version of pytorch
-          pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
 
           # Install Atari Roms
           pip install autorom
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 78d42eba3..53209cbac 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -59,7 +59,9 @@ Others:
 - Buffers do not call an additional ``.copy()`` when storing new transitions
 - Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
 - Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
+- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
 - Fixed ``stable_baselines3/common/distributions.py`` type hints
+- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
 
 Documentation:
 ^^^^^^^^^^^^^^
diff --git a/pyproject.toml b/pyproject.toml
index 9476868c4..9c3489ae1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ exclude = [
     "stable_baselines3/common/vec_env/stacked_observations.py",
     "stable_baselines3/common/vec_env/subproc_vec_env.py",
     "stable_baselines3/common/vec_env/patch_gym.py",
+    "stable_baselines3/common/off_policy_algorithm.py",
     "stable_baselines3/common/distributions.py",
 ]
 
@@ -44,8 +45,7 @@ ignore_missing_imports = true
 follow_imports = "silent"
 show_error_codes = true
 exclude = """(?x)(
-    stable_baselines3/common/off_policy_algorithm.py$
-    | stable_baselines3/common/policies.py$
+    stable_baselines3/common/policies.py$
     | stable_baselines3/common/vec_env/__init__.py$
     | stable_baselines3/common/vec_env/vec_normalize.py$
     | tests/test_logger.py$
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index e8dcac4a4..a24a4dc74 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -158,7 +158,7 @@ def _convert_train_freq(self) -> None:
                 train_freq = (train_freq, "step")
 
             try:
-                train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
+                train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))  # type: ignore[assignment]
             except ValueError as e:
                 raise ValueError(
                     f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
@@ -167,7 +167,7 @@ def _convert_train_freq(self) -> None:
             if not isinstance(train_freq[0], int):
                 raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
 
-            self.train_freq = TrainFreq(*train_freq)
+            self.train_freq = TrainFreq(*train_freq)  # type: ignore[assignment,arg-type]
 
     def _setup_model(self) -> None:
         self._setup_lr_schedule()
@@ -242,7 +242,7 @@ def load_replay_buffer(
 
         if isinstance(self.replay_buffer, HerReplayBuffer):
             assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
-            self.replay_buffer.set_env(self.get_env())
+            self.replay_buffer.set_env(self.env)
             if truncate_last_traj:
                 self.replay_buffer.truncate_last_trajectory()
 
@@ -280,10 +280,12 @@ def _setup_learn(
                 "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
                 "to avoid that issue."
             )
+            assert replay_buffer is not None  # for mypy
             # Go to the previous index
             pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
             replay_buffer.dones[pos] = True
 
+        assert self.env is not None, "You must set the environment before calling _setup_learn()"
         # Vectorize action noise if needed
         if (
             self.action_noise is not None
@@ -319,6 +321,9 @@ def learn(
 
         callback.on_training_start(locals(), globals())
 
+        assert self.env is not None, "You must set the environment before calling learn()"
+        assert isinstance(self.train_freq, TrainFreq)  # check done in _setup_learn()
+
         while self.num_timesteps < total_timesteps:
             rollout = self.collect_rollouts(
                 self.env,
@@ -381,6 +386,7 @@ def _sample_action(
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
             # We use non-deterministic action in the case of SAC, for TD3, it does not matter
+            assert self._last_obs is not None, "self._last_obs was not set"
             unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
 
         # Rescale the action from [low, high] to [-1, 1]
@@ -404,6 +410,9 @@ def _dump_logs(self) -> None:
         """
         Write log.
         """
+        assert self.ep_info_buffer is not None
+        assert self.ep_success_buffer is not None
+
         time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
         fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
         self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
@@ -481,8 +490,8 @@ def _store_transition(
                         next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])
 
         replay_buffer.add(
-            self._last_original_obs,
-            next_obs,
+            self._last_original_obs,  # type: ignore[arg-type]
+            next_obs,  # type: ignore[arg-type]
             buffer_action,
             reward_,
             dones,
@@ -563,7 +572,7 @@ def collect_rollouts(
             self._update_info_buffer(infos, dones)
 
             # Store data in replay buffer (normalized action and unnormalized observation)
-            self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
+            self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)  # type: ignore[arg-type]
 
             self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
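For context (not part of the patch itself): most of the `assert ... is not None` lines added above exist only to narrow Optional attributes for mypy. A minimal, self-contained sketch of that pattern, using a made-up `Trainer` class rather than the real `OffPolicyAlgorithm`:

```python
from typing import Optional


class Trainer:
    """Toy stand-in for OffPolicyAlgorithm: the env attribute is Optional until set."""

    def __init__(self) -> None:
        self.env: Optional[str] = None  # in SB3 this would be an Optional VecEnv, not a str

    def learn(self) -> None:
        # mypy treats the assert as a type guard: after it, `self.env` is `str`
        # rather than `Optional[str]`, so the attribute access below type-checks.
        assert self.env is not None, "You must set the environment before calling learn()"
        print(f"training on {self.env.upper()}")


trainer = Trainer()
trainer.env = "CartPole-v1"
trainer.learn()
```

The `# type: ignore[...]` comments in the hunks above play the complementary role: they silence specific mypy error codes at call sites where an assert or cast would be more intrusive than it is worth.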