diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0efc16e56..822e0cb3f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -5,9 +5,9 @@ name: CI
 
 on:
   push:
-    branches: [ master ]
+    branches: [master]
   pull_request:
-    branches: [ master ]
+    branches: [master]
 
 jobs:
   build:
@@ -23,38 +23,40 @@ jobs:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
 
     steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        # cpu version of pytorch
-        pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
 
-        # Install Atari Roms
-        pip install autorom
-        wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-        base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-        AutoROM --accept-license --source-file Roms.tar.gz
+          # Install Atari Roms
+          pip install autorom
+          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
+          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+          AutoROM --accept-license --source-file Roms.tar.gz
 
-        pip install .[extra_no_roms,tests,docs]
-        # Use headless version
-        pip install opencv-python-headless
-    - name: Lint with ruff
-      run: |
-        make lint
-    - name: Build the doc
-      run: |
-        make doc
-    - name: Check codestyle
-      run: |
-        make check-codestyle
-    - name: Type check
-      run: |
-        make type
-    - name: Test with pytest
-      run: |
-        make pytest
+          pip install .[extra_no_roms,tests,docs]
+          # Use headless version
+          pip install opencv-python-headless
+      - name: Lint with ruff
+        run: |
+          make lint
+      - name: Build the doc
+        run: |
+          make doc
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+        # Do not run for python 3.8 (mypy internal error)
+        if: matrix.python-version != '3.8'
+      - name: Test with pytest
+        run: |
+          make pytest
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 78eb2bd0e..31ff99d09 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.4.0a5 (WIP)
+Release 2.4.0a6 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -11,6 +11,7 @@ Breaking Changes:
 
 New Features:
 ^^^^^^^^^^^^^
+- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
 
 Bug Fixes:
 ^^^^^^^^^^
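The new-feature entry above translates into a one-liner at call sites. A minimal sketch (the dimensions, ``net_arch``, and the CrossQ-style use of ``nn.BatchNorm1d`` are illustrative, not part of this diff; the actual test lives in ``tests/test_custom_policy.py`` below)::

    import torch.nn as nn

    from stable_baselines3.common.torch_layers import create_mlp

    # Insert BatchNorm before every linear layer (CrossQ-style).
    # create_mlp instantiates each module class with the current feature
    # dimension, so nn.BatchNorm1d(6), nn.BatchNorm1d(256), ... are built.
    layers = create_mlp(
        input_dim=6,
        output_dim=1,
        net_arch=[256, 256],
        pre_linear_modules=[nn.BatchNorm1d],
    )
    q_net = nn.Sequential(*layers)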
["RUF012", "RUF013"] [tool.ruff.lint.mccabe] @@ -37,9 +37,7 @@ exclude = """(?x)( [tool.pytest.ini_options] # Deterministic ordering for tests; useful for pytest-xdist. -env = [ - "PYTHONHASHSEED=0" -] +env = ["PYTHONHASHSEED=0"] filterwarnings = [ # Tensorboard warnings @@ -47,23 +45,27 @@ filterwarnings = [ # Gymnasium warnings "ignore::UserWarning:gymnasium", # tqdm warning about rich being experimental - "ignore:rich is experimental" + "ignore:rich is experimental", ] markers = [ - "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" + "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')", ] [tool.coverage.run] disable_warnings = ["couldnt-parse"] branch = false omit = [ - "tests/*", - "setup.py", - # Require graphical interface - "stable_baselines3/common/results_plotter.py", - # Require ffmpeg - "stable_baselines3/common/vec_env/vec_video_recorder.py", + "tests/*", + "setup.py", + # Require graphical interface + "stable_baselines3/common/results_plotter.py", + # Require ffmpeg + "stable_baselines3/common/vec_env/vec_video_recorder.py", ] [tool.coverage.report] -exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"] +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError()", + "if typing.TYPE_CHECKING:", +] diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index bb3ba5de8..234b91551 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import gymnasium as gym import torch as th @@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module): """ Base class that represents a features extractor. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. """ @@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None: @property def features_dim(self) -> int: + """The number of features that the extractor outputs.""" return self._features_dim @@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor): Feature extract that flatten the input. Used as a placeholder when feature extraction is not needed. - :param observation_space: + :param observation_space: The observation space of the environment """ def __init__(self, observation_space: gym.Space) -> None: @@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor): "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. This corresponds to the number of unit for the last layer. :param normalized_image: Whether to assume that the image is already normalized @@ -113,13 +114,15 @@ def create_mlp( activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, with_bias: bool = True, + pre_linear_modules: Optional[List[Type[nn.Module]]] = None, + post_linear_modules: Optional[List[Type[nn.Module]]] = None, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. 
diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py
index bb3ba5de8..234b91551 100644
--- a/stable_baselines3/common/torch_layers.py
+++ b/stable_baselines3/common/torch_layers.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import gymnasium as gym
 import torch as th
@@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module):
     """
     Base class that represents a features extractor.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     :param features_dim: Number of features extracted.
     """
 
@@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
 
     @property
     def features_dim(self) -> int:
+        """The number of features that the extractor outputs."""
         return self._features_dim
 
 
@@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
     Feature extract that flatten the input.
     Used as a placeholder when feature extraction is not needed.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     """
 
     def __init__(self, observation_space: gym.Space) -> None:
@@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor):
     "Human-level control through deep reinforcement learning."
     Nature 518.7540 (2015): 529-533.
 
-    :param observation_space:
+    :param observation_space: The observation space of the environment
     :param features_dim: Number of features extracted.
         This corresponds to the number of unit for the last layer.
     :param normalized_image: Whether to assume that the image is already normalized
@@ -113,13 +114,15 @@ def create_mlp(
     activation_fn: Type[nn.Module] = nn.ReLU,
     squash_output: bool = False,
     with_bias: bool = True,
+    pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
+    post_linear_modules: Optional[List[Type[nn.Module]]] = None,
 ) -> List[nn.Module]:
     """
     Create a multi layer perceptron (MLP), which is
     a collection of fully-connected layers each followed by an activation function.
 
     :param input_dim: Dimension of the input vector
-    :param output_dim:
+    :param output_dim: Dimension of the output (last layer, for instance, the number of actions)
     :param net_arch: Architecture of the neural net
         It represents the number of units per layer.
         The length of this list is the number of layers.
@@ -128,20 +131,52 @@ def create_mlp(
     :param squash_output: Whether to squash the output using a Tanh
         activation function
     :param with_bias: If set to False, the layers will not learn an additive bias
-    :return:
+    :param pre_linear_modules: List of nn.Module to add before the linear layers.
+        These modules should maintain the input tensor dimension (e.g. BatchNorm).
+        The number of input features is passed to the module's constructor.
+        Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
+    :param post_linear_modules: List of nn.Module to add after the linear layers
+        (and before the activation function). These modules should maintain the input
+        tensor dimension (e.g. Dropout, LayerNorm). They are not used after the
+        output layer (output_dim > 0). The number of input features is passed to
+        the module's constructor.
+    :return: The list of layers of the neural network
     """
 
+    pre_linear_modules = pre_linear_modules or []
+    post_linear_modules = post_linear_modules or []
+
+    modules = []
     if len(net_arch) > 0:
-        modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()]
-    else:
-        modules = []
+        # BatchNorm maintains input dim
+        for module in pre_linear_modules:
+            modules.append(module(input_dim))
+
+        modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))
+
+        # LayerNorm, Dropout maintain output dim
+        for module in post_linear_modules:
+            modules.append(module(net_arch[0]))
+
+        modules.append(activation_fn())
 
     for idx in range(len(net_arch) - 1):
+        for module in pre_linear_modules:
+            modules.append(module(net_arch[idx]))
+
         modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))
+
+        for module in post_linear_modules:
+            modules.append(module(net_arch[idx + 1]))
+
         modules.append(activation_fn())
 
     if output_dim > 0:
         last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
+        # Only add BatchNorm before output layer
+        for module in pre_linear_modules:
+            modules.append(module(last_layer_dim))
+
         modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
     if squash_output:
         modules.append(nn.Tanh())
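A nuance of the ``create_mlp`` changes above: every entry of ``pre_linear_modules``/``post_linear_modules`` is instantiated with the current feature dimension. That suits ``nn.BatchNorm1d`` and ``nn.LayerNorm``, whose first constructor argument is a feature count, but ``nn.Dropout`` takes a probability as its first argument, so a DroQ-style setup needs a small adapter. A hedged sketch (``FixedDropout`` and the 0.01 rate are made up for illustration, not part of this diff)::

    import torch as th
    import torch.nn as nn

    from stable_baselines3.common.torch_layers import create_mlp


    class FixedDropout(nn.Module):
        # Hypothetical adapter: accept (and ignore) the feature dimension
        # that create_mlp passes in, and apply a fixed dropout rate instead.
        def __init__(self, num_features: int, p: float = 0.01) -> None:
            super().__init__()
            self.dropout = nn.Dropout(p=p)

        def forward(self, x: th.Tensor) -> th.Tensor:
            return self.dropout(x)


    # DroQ-style stack: Dropout + LayerNorm between each linear layer
    # and its activation; the output layer is left untouched.
    layers = create_mlp(6, 1, net_arch=[256, 256], post_linear_modules=[FixedDropout, nn.LayerNorm])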
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index a1fd35b5f..464a5c4dc 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a5
+2.4.0a6
diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py
index 1f89b23d6..e92ffe8b7 100644
--- a/tests/test_custom_policy.py
+++ b/tests/test_custom_policy.py
@@ -1,8 +1,10 @@
 import pytest
 import torch as th
+import torch.nn as nn
 
 from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
 from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
+from stable_baselines3.common.torch_layers import create_mlp
 
 
 @pytest.mark.parametrize(
@@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer():
 def test_dqn_custom_policy():
     policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
     _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)
+
+
+def test_create_mlp():
+    net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True)
+    # We cannot compare the network directly because the modules have different ids
+    # assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2),
+    #                nn.Tanh()]
+    assert len(net) == 6
+    assert isinstance(net[0], nn.Linear)
+    assert net[0].in_features == 4
+    assert net[0].out_features == 16
+    assert isinstance(net[1], nn.ReLU)
+    assert isinstance(net[2], nn.Linear)
+    assert isinstance(net[4], nn.Linear)
+    assert net[4].in_features == 8
+    assert net[4].out_features == 2
+    assert isinstance(net[5], nn.Tanh)
+
+    # Linear network
+    net = create_mlp(4, -1, net_arch=[])
+    assert net == []
+
+    # No output layer, with custom activation function
+    net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh)
+    # assert net == [nn.Linear(6, 8), nn.Tanh()]
+    assert len(net) == 2
+    assert isinstance(net[0], nn.Linear)
+    assert net[0].in_features == 6
+    assert net[0].out_features == 8
+    assert isinstance(net[1], nn.Tanh)
+
+    # Using pre-linear and post-linear modules
+    pre_linear = [nn.BatchNorm1d]
+    post_linear = [nn.LayerNorm]
+    net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear)
+    # assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU(),
+    #                nn.BatchNorm1d(8), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(),
+    #                nn.BatchNorm1d(12), nn.Linear(12, 2)]  # Last layer does not have post_linear
+    assert len(net) == 10
+    assert isinstance(net[0], nn.BatchNorm1d)
+    assert net[0].num_features == 6
+    assert isinstance(net[1], nn.Linear)
+    assert isinstance(net[2], nn.LayerNorm)
+    assert isinstance(net[3], nn.ReLU)
+    assert isinstance(net[4], nn.BatchNorm1d)
+    assert isinstance(net[5], nn.Linear)
+    assert net[5].in_features == 8
+    assert net[5].out_features == 12
+    assert isinstance(net[6], nn.LayerNorm)
+    assert isinstance(net[7], nn.ReLU)
+    assert isinstance(net[8], nn.BatchNorm1d)
+    assert isinstance(net[-1], nn.Linear)
+    assert net[-1].in_features == 12
+    assert net[-1].out_features == 2
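The test above checks the produced layer list structurally; a runtime shape check is a natural complement (a sketch, not part of this diff; note that ``nn.BatchNorm1d`` needs a batch size greater than 1 in training mode to compute batch statistics)::

    import torch as th
    import torch.nn as nn

    from stable_baselines3.common.torch_layers import create_mlp

    modules = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=[nn.BatchNorm1d], post_linear_modules=[nn.LayerNorm])
    net = nn.Sequential(*modules)
    out = net(th.randn(32, 6))  # batch of 32 observations, 6 features each
    assert out.shape == (32, 2)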