Update ruff and documentation for hf sb3 (#1866)
* Update ruff

* Only load weights with `torch.load()` to avoid security issues

* Update doc about HF integration and remote code execution

* Fix doc build

* Revert weights_only=True for policies
araffin authored Mar 11, 2024
1 parent f375cc3 commit 8b3723c
Showing 7 changed files with 33 additions and 12 deletions.
8 changes: 4 additions & 4 deletions Makefile
@@ -18,19 +18,19 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
- ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
+ ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
- ruff ${LINT_PATHS} --exit-zero
+ ruff check ${LINT_PATHS} --exit-zero

format:
# Sort imports
- ruff --select I ${LINT_PATHS} --fix
+ ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
- ruff --select I ${LINT_PATHS}
+ ruff check --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

11 changes: 10 additions & 1 deletion docs/guide/integrations.rst
@@ -70,8 +70,10 @@ Installation

.. code-block:: bash
# Download model and save it into the logs/ folder
- python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
+ # Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization)
+ TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
# Test the agent
python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/
# Push model, config and hyperparameters to the hub
@@ -86,12 +88,19 @@ For instance ``sb3/demo-hf-CartPole-v1``:

.. code-block:: python
+ import os
import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
+ # Allow the use of `pickle.load()` when downloading model from the hub
+ # Please make sure that the organization from which you download can be trusted
+ os.environ["TRUST_REMOTE_CODE"] = "True"
# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
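
The opt-in described in this hunk can be pictured as an environment-variable gate in front of `pickle.load()`. The following is a minimal sketch under stated assumptions, not the actual `huggingface_sb3` implementation: the helper names `remote_code_trusted` and `load_pickled_file` are hypothetical, and the accepted value (`"True"`, compared case-insensitively) is assumed from the documented usage.

```python
import io
import os
import pickle


def remote_code_trusted() -> bool:
    """Hypothetical helper: True only if the user explicitly opted in."""
    return os.environ.get("TRUST_REMOTE_CODE", "False").strip().lower() == "true"


def load_pickled_file(file_obj):
    """Hypothetical loader: refuse to unpickle unless the opt-in flag is set."""
    if not remote_code_trusted():
        raise PermissionError(
            "Unpickling can execute arbitrary code. "
            "Set TRUST_REMOTE_CODE=True only for sources you trust."
        )
    return pickle.load(file_obj)


# Stand-in for a file downloaded from the hub.
payload = pickle.dumps({"algo": "a2c", "env": "LunarLander-v2"})

# Without the flag, loading is refused.
os.environ.pop("TRUST_REMOTE_CODE", None)
try:
    load_pickled_file(io.BytesIO(payload))
    blocked = False
except PermissionError:
    blocked = True

# With the flag, the same payload is unpickled.
os.environ["TRUST_REMOTE_CODE"] = "True"
loaded = load_pickled_file(io.BytesIO(payload))
```

The gate does not make unpickling safe. It only forces the caller to state that the source is trusted before any pickled data is read.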
16 changes: 13 additions & 3 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

- Release 2.3.0a3 (WIP)
+ Release 2.3.0a4 (WIP)
--------------------------

Breaking Changes:
@@ -33,6 +33,11 @@ Breaking Changes:
# SB3 >= 2.3.0:
model = DQN("MlpPolicy", env, learning_starts=100)
+ - For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors,
+   policy ``load()`` still uses ``weights_only=False`` as gymnasium imports are required for it to work
+ - When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub,
+   as ``pickle.load`` is not safe.


New Features:
^^^^^^^^^^^^^
@@ -48,14 +53,19 @@ Bug Fixes:

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
+ - Added support for ``MultiDiscrete`` and ``MultiBinary`` action spaces to PPO
+ - Added support for large values for gradient_steps to SAC, TD3, and TQC
+ - Fix ``train()`` signature and update type hints
+ - Fix replay buffer device at load time
+ - Added flatten layer

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Updated black from v23 to v24
- - Updated ruff to >= v0.2.2
+ - Updated ruff to >= v0.3.1
- Updated env checker for (multi)discrete spaces with non-zero start.

Documentation:
@@ -66,7 +76,7 @@ Documentation:
- Added video link to "Practical Tips for Reliable Reinforcement Learning" video
- Added ``render_mode="human"`` in the README example (@marekm4)
- Fixed docstring signature for sum_independent_dims (@stagoverflow)
- - Updated docstring description for ``log_interval`` in the base class (@rushitnshah).
+ - Updated docstring description for ``log_interval`` in the base class (@rushitnshah).

Release 2.2.1 (2023-11-17)
--------------------------
2 changes: 1 addition & 1 deletion setup.py
@@ -120,7 +120,7 @@
# Type check
"mypy",
# Lint code and sort imports (flake8 and isort replacement)
- "ruff>=0.2.2",
+ "ruff>=0.3.1",
# Reformat
"black>=24.2.0,<25",
],
4 changes: 3 additions & 1 deletion stable_baselines3/common/policies.py
@@ -173,7 +173,9 @@ def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "a
:return:
"""
device = get_device(device)
- saved_variables = th.load(path, map_location=device)
+ # Note(antonin): we cannot use `weights_only=True` here because we need to allow
+ # gymnasium imports for the policy to be loaded successfully
+ saved_variables = th.load(path, map_location=device, weights_only=False)

# Create policy object
model = cls(**saved_variables["data"])
2 changes: 1 addition & 1 deletion stable_baselines3/common/save_util.py
@@ -447,7 +447,7 @@ def load_from_zip_file(
file_content.seek(0)
# Load the parameters with the right ``map_location``.
# Remove ".pth" ending with splitext
- th_object = th.load(file_content, map_location=device)
+ th_object = th.load(file_content, map_location=device, weights_only=True)
# "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
# PyTorch variables (not state_dicts)
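
With `weights_only=True`, `torch.load()` uses a restricted unpickler that accepts tensors, primitive types and a small allow-list of objects. That restriction is why the policy loader in `policies.py` keeps `weights_only=False`. The allow-list idea can be illustrated with the standard library alone. This is a sketch of the general technique, not PyTorch's implementation: `AllowListUnpickler`, `safe_loads`, `ALLOWED` and `Malicious` are hypothetical names chosen for illustration.

```python
import io
import pickle
from collections import OrderedDict

# Hypothetical allow-list: the only globals the unpickler may resolve.
ALLOWED = {("collections", "OrderedDict")}


class AllowListUnpickler(pickle.Unpickler):
    """Unpickler that refuses any global not explicitly allow-listed."""

    def find_class(self, module, name):
        if (module, name) in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def safe_loads(data: bytes):
    """Deserialize `data` using the allow-list unpickler."""
    return AllowListUnpickler(io.BytesIO(data)).load()


class Malicious:
    """Pickles to a call of `eval`, standing in for an attacker payload."""

    def __reduce__(self):
        return (eval, ("1 + 1",))


# Plain containers and allow-listed types load normally.
plain = safe_loads(pickle.dumps(OrderedDict(lr=0.001, layers=[64, 64])))

# A payload that references a non-allow-listed callable is rejected
# before the callable is ever invoked.
try:
    safe_loads(pickle.dumps(Malicious()))
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```

The same trade-off appears in this commit: objects that need arbitrary imports to be reconstructed cannot pass through such a filter.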
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
- 2.3.0a3
+ 2.3.0a4
