Fix offpolicy algo type hints #1734

Merged: 7 commits, Nov 6, 2023

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -32,7 +32,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
-pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html
+pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu

# Install Atari Roms
pip install autorom
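
For reference, a quick way to confirm that the CPU-only wheel from that index is the one installed (a sanity-check sketch, not part of this diff):

```python
import torch

# On the CI runner this should report a CPU-only build with no CUDA devices.
print(torch.__version__)          # e.g. "2.1.0+cpu"
print(torch.cuda.is_available())  # expected: False
```
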
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -59,7 +59,9 @@ Others:
- Buffers do not call an additional ``.copy()`` when storing new transitions
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
+- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
- Fixed ``stable_baselines3/common/distributions.py`` type hints
+- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)

Documentation:
^^^^^^^^^^^^^^
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -36,6 +36,7 @@ exclude = [
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/off_policy_algorithm.py",
"stable_baselines3/common/distributions.py",
]

@@ -44,8 +45,7 @@ ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
-stable_baselines3/common/off_policy_algorithm.py$
-| stable_baselines3/common/policies.py$
+stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| tests/test_logger.py$
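
For context, mypy's `exclude` here is a verbose regular expression matched against file paths, so dropping the `off_policy_algorithm.py` alternative is what re-enables type checking for that module. A rough illustration with Python's `re` (an approximation of the matching, not mypy's actual internals):

```python
import re

# Two alternatives that remain excluded above; off_policy_algorithm.py is no
# longer listed, so it no longer matches.
exclude = re.compile(
    r"""(?x)(
    stable_baselines3/common/policies.py$
    | stable_baselines3/common/vec_env/__init__.py$
    )"""
)

print(bool(exclude.search("stable_baselines3/common/off_policy_algorithm.py")))  # False -> now checked
print(bool(exclude.search("stable_baselines3/common/policies.py")))              # True  -> still excluded
```
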
21 changes: 15 additions & 6 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -158,7 +158,7 @@ def _convert_train_freq(self) -> None:
train_freq = (train_freq, "step")

try:
-train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
+train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) # type: ignore[assignment]
except ValueError as e:
raise ValueError(
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
@@ -167,7 +167,7 @@ def _convert_train_freq(self) -> None:
if not isinstance(train_freq[0], int):
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")

-self.train_freq = TrainFreq(*train_freq)
+self.train_freq = TrainFreq(*train_freq) # type: ignore[assignment,arg-type]

def _setup_model(self) -> None:
self._setup_lr_schedule()
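
The two `# type: ignore[...]` comments above are needed because `train_freq` is re-bound from an `(int, str)` tuple to an `(int, TrainFrequencyUnit)` tuple under the same name, which mypy flags as an incompatible assignment. A small self-contained sketch of the same conversion, using simplified stand-ins for SB3's `TrainFreq`/`TrainFrequencyUnit` (not the library code itself):

```python
from enum import Enum
from typing import NamedTuple, Tuple, Union


class TrainFrequencyUnit(Enum):
    STEP = "step"
    EPISODE = "episode"


class TrainFreq(NamedTuple):
    frequency: int
    unit: TrainFrequencyUnit


def convert_train_freq(train_freq: Union[int, Tuple[int, str]]) -> TrainFreq:
    if not isinstance(train_freq, tuple):
        train_freq = (train_freq, "step")
    # Unpacking into fresh names sidesteps the re-binding that the real method
    # silences with `# type: ignore[assignment]`.
    frequency, unit = train_freq[0], TrainFrequencyUnit(train_freq[1])
    return TrainFreq(frequency, unit)


print(convert_train_freq(4))               # TrainFreq(frequency=4, unit=<TrainFrequencyUnit.STEP: 'step'>)
print(convert_train_freq((2, "episode")))  # TrainFreq(frequency=2, unit=<TrainFrequencyUnit.EPISODE: 'episode'>)
```
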
@@ -242,7 +242,7 @@ def load_replay_buffer(

if isinstance(self.replay_buffer, HerReplayBuffer):
assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
-self.replay_buffer.set_env(self.get_env())
+self.replay_buffer.set_env(self.env)
if truncate_last_traj:
self.replay_buffer.truncate_last_trajectory()

@@ -280,10 +280,12 @@ def _setup_learn(
"You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
"to avoid that issue."
)
+assert replay_buffer is not None # for mypy
# Go to the previous index
pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
replay_buffer.dones[pos] = True

assert self.env is not None, "You must set the environment before calling _setup_learn()"
# Vectorize action noise if needed
if (
self.action_noise is not None
@@ -319,6 +321,9 @@ def learn(

callback.on_training_start(locals(), globals())

+assert self.env is not None, "You must set the environment before calling learn()"
+assert isinstance(self.train_freq, TrainFreq) # check done in _setup_learn()
+
while self.num_timesteps < total_timesteps:
rollout = self.collect_rollouts(
self.env,
@@ -381,6 +386,7 @@ def _sample_action(
# Note: when using continuous actions,
# we assume that the policy uses tanh to scale the action
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
+assert self._last_obs is not None, "self._last_obs was not set"
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

# Rescale the action from [low, high] to [-1, 1]
@@ -404,6 +410,9 @@ def _dump_logs(self) -> None:
"""
Write log.
"""
+assert self.ep_info_buffer is not None
+assert self.ep_success_buffer is not None
+
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
@@ -481,8 +490,8 @@ def _store_transition(
next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])

replay_buffer.add(
-self._last_original_obs,
-next_obs,
+self._last_original_obs, # type: ignore[arg-type]
+next_obs, # type: ignore[arg-type]
buffer_action,
reward_,
dones,
@@ -563,7 +572,7 @@ def collect_rollouts(
self._update_info_buffer(infos, dones)

# Store data in replay buffer (normalized action and unnormalized observation)
-self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
+self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos) # type: ignore[arg-type]

self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

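Taken together, the changes in this file follow one pattern: add runtime `assert`s so mypy can narrow `Optional` attributes (the env, the episode buffers, the last observation) at the point of use, and keep `# type: ignore[...]` with explicit error codes only where the argument types genuinely diverge. A minimal, self-contained sketch of that narrowing pattern (illustrative class names, not SB3's actual API):

```python
from typing import Optional


class ToyVecEnv:
    """Illustrative stand-in for a vectorized environment."""

    def reset(self) -> None:
        print("reset called")


class ToyOffPolicyAlgo:
    def __init__(self, env: Optional[ToyVecEnv] = None) -> None:
        # Optional because the environment can be attached after construction.
        self.env: Optional[ToyVecEnv] = env

    def learn(self) -> None:
        # Without this assert, mypy reports that the Optional value may be
        # None; with it, `self.env` is narrowed to `ToyVecEnv` below.
        assert self.env is not None, "You must set the environment before calling learn()"
        self.env.reset()


ToyOffPolicyAlgo(ToyVecEnv()).learn()
```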