Skip to content

Commit

Permalink
Fix offpolicy algo type hints (#1734)
Browse files Browse the repository at this point in the history
* Fix offpolicy algo type hints

* Update PyTorch to have latest type hints

* Fix pip argument

* Try PyTorch 2.0.1

* Revert "Try PyTorch 2.0.1"

This reverts commit 0e0ead4.

* Update changelog
  • Loading branch information
araffin authored Nov 6, 2023
1 parent 018ea5a commit a35c08c
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
Expand Down
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ Others:
- Buffers do no call an additional ``.copy()`` when storing new transitions
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
- Fixed ``stable_baselines3/common/distributions.py`` type hints
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)

Documentation:
^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ exclude = [
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/off_policy_algorithm.py",
"stable_baselines3/common/distributions.py",
]

Expand All @@ -44,8 +45,7 @@ ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| tests/test_logger.py$
Expand Down
21 changes: 15 additions & 6 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _convert_train_freq(self) -> None:
train_freq = (train_freq, "step")

try:
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) # type: ignore[assignment]
except ValueError as e:
raise ValueError(
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
Expand All @@ -167,7 +167,7 @@ def _convert_train_freq(self) -> None:
if not isinstance(train_freq[0], int):
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

self.train_freq = TrainFreq(*train_freq)
self.train_freq = TrainFreq(*train_freq) # type: ignore[assignment,arg-type]

def _setup_model(self) -> None:
self._setup_lr_schedule()
Expand Down Expand Up @@ -242,7 +242,7 @@ def load_replay_buffer(

if isinstance(self.replay_buffer, HerReplayBuffer):
assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
self.replay_buffer.set_env(self.get_env())
self.replay_buffer.set_env(self.env)
if truncate_last_traj:
self.replay_buffer.truncate_last_trajectory()

Expand Down Expand Up @@ -280,10 +280,12 @@ def _setup_learn(
"You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
"to avoid that issue."
)
assert replay_buffer is not None # for mypy
# Go to the previous index
pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
replay_buffer.dones[pos] = True

assert self.env is not None, "You must set the environment before calling _setup_learn()"
# Vectorize action noise if needed
if (
self.action_noise is not None
Expand Down Expand Up @@ -319,6 +321,9 @@ def learn(

callback.on_training_start(locals(), globals())

assert self.env is not None, "You must set the environment before calling learn()"
assert isinstance(self.train_freq, TrainFreq) # check done in _setup_learn()

while self.num_timesteps < total_timesteps:
rollout = self.collect_rollouts(
self.env,
Expand Down Expand Up @@ -381,6 +386,7 @@ def _sample_action(
# Note: when using continuous actions,
# we assume that the policy uses tanh to scale the action
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
assert self._last_obs is not None, "self._last_obs was not set"
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

# Rescale the action from [low, high] to [-1, 1]
Expand All @@ -404,6 +410,9 @@ def _dump_logs(self) -> None:
"""
Write log.
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None

time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
Expand Down Expand Up @@ -481,8 +490,8 @@ def _store_transition(
next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])

replay_buffer.add(
self._last_original_obs,
next_obs,
self._last_original_obs, # type: ignore[arg-type]
next_obs, # type: ignore[arg-type]
buffer_action,
reward_,
dones,
Expand Down Expand Up @@ -563,7 +572,7 @@ def collect_rollouts(
self._update_info_buffer(infos, dones)

# Store data in replay buffer (normalized action and unnormalized observation)
self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos) # type: ignore[arg-type]

self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

Expand Down

0 comments on commit a35c08c

Please sign in to comment.