Commit

Fix offpolicy algo type hints
araffin committed Nov 5, 2023
1 parent 294f2b4 commit bd31db5
Showing 3 changed files with 19 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -59,6 +59,8 @@ Others:
- Buffers do not call an additional ``.copy()`` when storing new transitions
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
+ - Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints


Documentation:
^^^^^^^^^^^^^^
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -35,7 +35,8 @@ exclude = [
"stable_baselines3/common/on_policy_algorithm.py",
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py"
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/off_policy_algorithm.py",
]

[tool.mypy]
@@ -44,7 +45,6 @@ follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/distributions.py$
- | stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
21 changes: 15 additions & 6 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -158,7 +158,7 @@ def _convert_train_freq(self) -> None:
train_freq = (train_freq, "step")

try:
- train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
+ train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) # type: ignore[assignment]
except ValueError as e:
raise ValueError(
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
@@ -167,7 +167,7 @@ def _convert_train_freq(self) -> None:
if not isinstance(train_freq[0], int):
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

- self.train_freq = TrainFreq(*train_freq)
+ self.train_freq = TrainFreq(*train_freq) # type: ignore[assignment,arg-type]

def _setup_model(self) -> None:
self._setup_lr_schedule()
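
Context for the two ignores above: _convert_train_freq re-binds train_freq with a tuple whose second element becomes a TrainFrequencyUnit instead of a str, which mypy reports as an incompatible assignment. A minimal, self-contained sketch of the same pattern, using hypothetical stand-in classes rather than the SB3 ones:

    from enum import Enum
    from typing import NamedTuple, Tuple, Union

    class TrainFrequencyUnit(Enum):
        STEP = "step"
        EPISODE = "episode"

    class TrainFreq(NamedTuple):
        frequency: int
        unit: TrainFrequencyUnit

    def convert_train_freq(train_freq: Union[int, Tuple[int, str]]) -> TrainFreq:
        # Normalize a bare int to the (frequency, "step") form.
        if isinstance(train_freq, int):
            train_freq = (train_freq, "step")
        # Re-binding the name with an (int, TrainFrequencyUnit) tuple is what mypy
        # flags as [assignment]: the declared type only allows (int, str).
        train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))  # type: ignore[assignment]
        return TrainFreq(*train_freq)  # type: ignore[arg-type]

    print(convert_train_freq(4))               # frequency=4, unit=STEP
    print(convert_train_freq((2, "episode")))  # frequency=2, unit=EPISODE
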
@@ -242,7 +242,7 @@ def load_replay_buffer(

if isinstance(self.replay_buffer, HerReplayBuffer):
assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
- self.replay_buffer.set_env(self.get_env())
+ self.replay_buffer.set_env(self.env)
if truncate_last_traj:
self.replay_buffer.truncate_last_trajectory()
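
The switch from self.get_env() to self.env matters for the type checker: get_env() is declared to return an optional environment, while the assert on the line above lets mypy narrow self.env to a non-None value before it is handed to set_env(). A small, self-contained illustration of that narrowing (toy classes, not the SB3 API):

    from typing import Optional

    class Env:
        pass

    class Buffer:
        def set_env(self, env: Env) -> None:
            self.env = env

    class Algo:
        def __init__(self, env: Optional[Env] = None) -> None:
            self.env = env
            self.buffer = Buffer()

        def get_env(self) -> Optional[Env]:
            return self.env

        def load(self) -> None:
            assert self.env is not None, "You must pass an environment at load time"
            # After the assert, mypy narrows self.env from Optional[Env] to Env,
            # so the call type-checks; get_env() would still return Optional[Env]
            # and be rejected by set_env(env: Env).
            self.buffer.set_env(self.env)

    Algo(Env()).load()
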

@@ -280,10 +280,12 @@ def _setup_learn(
"You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
"to avoid that issue."
)
+ assert replay_buffer is not None # for mypy
# Go to the previous index
pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
replay_buffer.dones[pos] = True

assert self.env is not None, "You must set the environment before calling _setup_learn()"
# Vectorize action noise if needed
if (
self.action_noise is not None
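
The "go to the previous index" line relies on Python's modulo to wrap around the circular replay buffer, so marking the most recent transition as done also works when the write cursor has just wrapped to position 0. For example:

    buffer_size = 5
    for pos in range(3):
        print(pos, (pos - 1) % buffer_size)
    # 0 -> 4  (wraps to the last slot of the circular buffer)
    # 1 -> 0
    # 2 -> 1
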
@@ -319,6 +321,9 @@ def learn(

callback.on_training_start(locals(), globals())

assert self.env is not None, "You must set the environment before calling learn()"
assert isinstance(self.train_freq, TrainFreq) # check done in _setup_learn()

while self.num_timesteps < total_timesteps:
rollout = self.collect_rollouts(
self.env,
@@ -381,6 +386,7 @@ def _sample_action(
# Note: when using continuous actions,
# we assume that the policy uses tanh to scale the action
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
+ assert self._last_obs is not None, "self._last_obs was not set"
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

# Rescale the action from [low, high] to [-1, 1]
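
The rescaling mentioned in the comment is the usual affine map between the environment's action bounds and the policy's squashed [-1, 1] output. A standalone sketch of that math (the actual SB3 helper names and signatures may differ):

    import numpy as np

    def scale_action(action: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
        # Map from the environment's [low, high] box to [-1, 1].
        return 2.0 * ((action - low) / (high - low)) - 1.0

    def unscale_action(scaled: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
        # Inverse map, from [-1, 1] back to [low, high].
        return low + 0.5 * (scaled + 1.0) * (high - low)

    low, high = np.array([0.0, -2.0]), np.array([1.0, 2.0])
    action = np.array([0.25, 1.0])
    scaled = scale_action(action, low, high)
    print(scaled)                             # [-0.5  0.5]
    print(unscale_action(scaled, low, high))  # [0.25 1.  ]
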
@@ -404,6 +410,9 @@ def _dump_logs(self) -> None:
"""
Write log.
"""
+ assert self.ep_info_buffer is not None
+ assert self.ep_success_buffer is not None

time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
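
The surrounding FPS computation divides by wall-clock time measured with time.time_ns(); the sys.float_info.epsilon floor only guards against a zero elapsed time. A stripped-down version of the same arithmetic:

    import sys
    import time

    start_time = time.time_ns()
    num_timesteps, num_timesteps_at_start = 10_000, 0

    # Nanoseconds -> seconds; the epsilon floor avoids dividing by zero when the
    # log call lands in the same clock tick as start_time.
    time_elapsed = max((time.time_ns() - start_time) / 1e9, sys.float_info.epsilon)
    fps = int((num_timesteps - num_timesteps_at_start) / time_elapsed)
    print(fps)
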
@@ -481,8 +490,8 @@ def _store_transition(
next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])

replay_buffer.add(
- self._last_original_obs,
- next_obs,
+ self._last_original_obs, # type: ignore[arg-type]
+ next_obs, # type: ignore[arg-type]
buffer_action,
reward_,
dones,
@@ -563,7 +572,7 @@ def collect_rollouts(
self._update_info_buffer(infos, dones)

# Store data in replay buffer (normalized action and unnormalized observation)
- self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
+ self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos) # type: ignore[arg-type]

self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
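
The # type: ignore[arg-type] comments added in _store_transition and collect_rollouts cover the common case where an attribute's declared type is broader (e.g. Optional, or a union of array and dict observations) than the parameter it is passed to, while runtime invariants guarantee the value is valid. A minimal illustration of that error code, unrelated to the SB3 buffers themselves:

    from typing import Optional

    import numpy as np

    def add_to_buffer(obs: np.ndarray) -> None:
        print(obs.shape)

    last_original_obs: Optional[np.ndarray] = np.zeros(3)

    # mypy reports [arg-type] here: the declared type is Optional[np.ndarray] but
    # add_to_buffer() expects np.ndarray. The ignore records that, by the time this
    # runs, the value is known to have been set (e.g. during the environment reset).
    add_to_buffer(last_original_obs)  # type: ignore[arg-type]
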

