Merge branch 'master' into bugfix/fix-dummyvecenv-reset-args
qgallouedec committed Apr 17, 2024
2 parents db14104 + 5623d98 commit 55db969
Showing 21 changed files with 325 additions and 92 deletions.
8 changes: 4 additions & 4 deletions Makefile
@@ -18,19 +18,19 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero

format:
# Sort imports
ruff --select I ${LINT_PATHS} --fix
ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
ruff --select I ${LINT_PATHS}
ruff check --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

2 changes: 1 addition & 1 deletion README.md
@@ -127,7 +127,7 @@ import gymnasium as gym

from stable_baselines3 import PPO

env = gym.make("CartPole-v1")
env = gym.make("CartPole-v1", render_mode="human")

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
85 changes: 48 additions & 37 deletions docs/guide/export.rst
@@ -31,53 +31,52 @@ to do inference in another framework.
Export to ONNX
-----------------

As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/issues/3033>`_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers.

The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:

For PPO, assuming a shared feature extractor.

.. warning::

The following example is for continuous actions only.
When using discrete or binary actions, you must do some `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/f3a35aa786ee41ffff599b99fa1607c067e89074/stable_baselines3/common/policies.py#L621-L637>`_
to obtain the action (e.g., convert action logits to action).
The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
(clip or unscale the action to the correct space).


.. code-block:: python
import torch as th
from typing import Tuple
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy
class OnnxablePolicy(th.nn.Module):
def __init__(self, extractor, action_net, value_net):
class OnnxableSB3Policy(th.nn.Module):
def __init__(self, policy: BasePolicy):
super().__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
self.policy = policy
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
# NOTE: Preprocessing is included, but postprocessing
# (clipping/unscaling actions) is not.
# If needed, you also need to transpose the images so that they are channel first
# use deterministic=False if you want to export the stochastic policy
# policy() returns `actions, values, log_prob` for PPO
return self.policy(observation, deterministic=True)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(
model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)
onnx_policy = OnnxableSB3Policy(model.policy)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
onnx_policy,
dummy_input,
"my_ppo_model.onnx",
opset_version=9,
opset_version=17,
input_names=["input"],
)
@@ -93,7 +92,13 @@ For PPO, assuming a shared feature extractor.
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {"input": observation})
actions, values, log_prob = ort_sess.run(None, {"input": observation})
print(actions, values, log_prob)
# Check that the predictions are the same
with th.no_grad():
print(model.policy(th.as_tensor(observation), deterministic=True))
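
Continuing from the snippet above (``actions``, ``model`` and ``np`` are already in scope), a minimal sketch of the clipping step mentioned in the warning, assuming a continuous ``Box`` action space; ``clipped_actions`` is just an illustrative name:

.. code-block:: python
# The ONNX output is the unclipped action: clip it to the action space
# bounds, mirroring what ``model.predict()`` does for Box action spaces
clipped_actions = np.clip(actions, model.action_space.low, model.action_space.high)
print(clipped_actions)
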
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
@@ -108,23 +113,16 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
class OnnxablePolicy(th.nn.Module):
def __init__(self, actor: th.nn.Module):
super().__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = th.nn.Sequential(
actor.latent_pi,
actor.mu,
# For gSDE
# th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
# Squash the output
th.nn.Tanh(),
)
self.actor = actor
def forward(self, observation: th.Tensor) -> th.Tensor:
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
# NOTE: You may have to postprocess (unnormalize) actions
# to the correct bounds (see commented code below)
return self.actor(observation, deterministic=True)
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
@@ -134,7 +132,7 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=9,
opset_version=17,
input_names=["input"],
)
@@ -147,10 +145,23 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action = ort_sess.run(None, {"input": observation})
scaled_action = ort_sess.run(None, {"input": observation})[0]
print(scaled_action)
# Post-process: rescale to correct space
# Rescale the action from [-1, 1] to [low, high]
# low, high = model.action_space.low, model.action_space.high
# post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
# Check that the predictions are the same
with th.no_grad():
print(model.actor(th.as_tensor(observation), deterministic=True))
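
Continuing from the snippet above, a short sketch that applies the rescaling from the commented lines, assuming a ``Box`` action space such as ``Pendulum-v1``; ``post_processed_action`` is an illustrative name:

.. code-block:: python
low, high = model.action_space.low, model.action_space.high
# Rescale the tanh-squashed action from [-1, 1] to [low, high]
post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
print(post_processed_action)
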
For more discussion around the topic, please refer to `GH#383 <https://github.com/DLR-RM/stable-baselines3/issues/383>`_ and `GH#1349 <https://github.com/DLR-RM/stable-baselines3/issues/1349>`_.


For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_

Trace/Export to C++
-------------------
11 changes: 10 additions & 1 deletion docs/guide/integrations.rst
@@ -70,8 +70,10 @@ Installation

.. code-block:: bash
# Download model and save it into the logs/ folder
python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
# Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization)
TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
# Test the agent
python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/
# Push model, config and hyperparameters to the hub
@@ -86,12 +88,19 @@ For instance ``sb3/demo-hf-CartPole-v1``:

.. code-block:: python
import os
import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
# Allow the use of `pickle.load()` when downloading model from the hub
# Please make sure that the organization from which you download can be trusted
os.environ["TRUST_REMOTE_CODE"] = "True"
# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
8 changes: 7 additions & 1 deletion docs/guide/rl_tips.rst
@@ -252,6 +252,12 @@ A better solution would be to use a squashing function (cf ``SAC``) or a Beta di
Tips and Tricks when implementing an RL algorithm
=================================================

.. note::

We have a `video on YouTube about reliable RL <https://www.youtube.com/watch?v=7-PUg9EAa3Y>`_ that covers
this section in more detail. You can also find the `slides online <https://araffin.github.io/slides/tips-reliable-rl/>`_.


When you try to reproduce an RL paper by implementing the algorithm, the `nuts and bolts of RL research <http://joschu.net/docs/nuts-and-bolts.pdf>`_
by John Schulman are quite useful (`video <https://www.youtube.com/watch?v=8EcdaCk9KaQ>`_).

@@ -282,4 +288,4 @@ in RL with discrete actions:
3. Pong (one of the easiest Atari games)
4. other Atari games (e.g. Breakout)

.. _SBX: https://github.com/araffin/sbx
.. _SBX: https://github.com/araffin/sbx
86 changes: 79 additions & 7 deletions docs/misc/changelog.rst
@@ -3,9 +3,48 @@
Changelog
==========

Release 2.3.0a1 (WIP)

Release 2.4.0a1 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^
- Fixed seed / options argument passing to environment resets in ``vec_env.reset()``
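
  A minimal usage sketch for this fix (assuming the ``VecEnv.seed()`` and ``VecEnv.set_options()`` setters; the options dict is purely illustrative):

.. code-block:: python
from stable_baselines3.common.env_util import make_vec_env
vec_env = make_vec_env("CartPole-v1", n_envs=2)
vec_env.seed(seed=42)  # stored and passed to env.reset(seed=...) at the next reset
vec_env.set_options({"example": 1})  # illustrative options, passed to env.reset(options=...)
obs = vec_env.reset()  # seeds and options are forwarded to each wrapped env here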

`SB3-Contrib`_
^^^^^^^^^^^^^^

`RL Zoo`_
^^^^^^^^^

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Expanded the description for vec_env.reset seed and options passing




Release 2.3.0 (2024-03-31)
--------------------------

**New default hyperparameters for DDPG, TD3 and DQN**


Breaking Changes:
^^^^^^^^^^^^^^^^^
- The default hyperparameters of ``TD3`` and ``DDPG`` have been changed to be more consistent with ``SAC``
@@ -19,48 +58,80 @@ Breaking Changes:
.. note::

Two inconsistencies remains: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-Influence-of-policy-net--Vmlldzo2NDg1Mzk3>`_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-RL-Zoo-v2-3-0a0-vs-SB3-TD3-RL-Zoo-2-2-1---Vmlldzo2MjUyNTQx>`_)
Two inconsistencies remain: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-Influence-of-policy-net--Vmlldzo2NDg1Mzk3>`_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr <https://wandb.ai/openrlbenchmark/sbx/reports/SBX-TD3-RL-Zoo-v2-3-0a0-vs-SB3-TD3-RL-Zoo-2-2-1---Vmlldzo2MjUyNTQx>`_)



- The default ``leanrning_starts`` parameter of ``DQN`` have been changed to be consistent with the other offpolicy algorithms
- The default ``learning_starts`` parameter of ``DQN`` has been changed to be consistent with the other off-policy algorithms


.. code-block:: python
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters
# model = DQN("MlpPolicy", env, learning_start=50_000)
# model = DQN("MlpPolicy", env, learning_starts=50_000)
# SB3 >= 2.3.0:
model = DQN("MlpPolicy", env, learning_start=100)
model = DQN("MlpPolicy", env, learning_starts=100)
- For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors;
policy ``load()`` still uses ``weights_only=False``, as gymnasium imports are required for it to work
- When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub, as ``pickle.load`` is not safe.


New Features:
^^^^^^^^^^^^^
- Log success rate ``rollout/success_rate`` when available for on-policy algorithms (@corentinlger)

Bug Fixes:
^^^^^^^^^^
- Fixed seed / options argument passing to environment resets in ``vec_env.reset()``

- Fixed ``monitor_wrapper`` argument that was not passed to the parent class, and dones argument that wasn't passed to ``_update_into_buffer`` (@corentinlger)


`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO
- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl)
- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations
- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations
- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations
- Added some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered)


`RL Zoo`_
^^^^^^^^^
- Updated default hyperparameters for TD3/DDPG to be more consistent with SAC
- Upgraded MuJoCo envs hyperparameters to v4 (pre-trained agents need to be updated)
- Added test dependencies to `setup.py` (@power-edge)
- Simplified dependencies of `requirements.txt` (removed duplicates from `setup.py`)

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added support for ``MultiDiscrete`` and ``MultiBinary`` action spaces to PPO
- Added support for large values for gradient_steps to SAC, TD3, and TQC
- Fixed ``train()`` signature and updated type hints
- Fixed replay buffer device at load time
- Added flatten layer
- Added ``CrossQ``

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Updated black from v23 to v24
- Updated ruff to >= v0.3.1
- Updated env checker for (multi)discrete spaces with non-zero start.

Documentation:
^^^^^^^^^^^^^^
- Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano)
- Updated callback code example
- Expanded the description for vec_env.reset seed and options passing
- Updated export to ONNX documentation; it is now much simpler to export SB3 models with newer ONNX Opset!
- Added a link to the "Practical Tips for Reliable Reinforcement Learning" video
- Added ``render_mode="human"`` in the README example (@marekm4)
- Fixed docstring signature for sum_independent_dims (@stagoverflow)
- Updated docstring description for ``log_interval`` in the base class (@rushitnshah).


Release 2.2.1 (2023-11-17)
--------------------------
@@ -1561,3 +1632,4 @@ And all the contributors:
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah
2 changes: 1 addition & 1 deletion docs/modules/ppo.rst
@@ -23,7 +23,7 @@ Notes

- Original paper: https://arxiv.org/abs/1707.06347
- Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
- OpenAI blog post: https://openai.com/research/openai-baselines-ppo
- Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html
- 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
