Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pre and post linear modules in create_mlp #1975

Merged
merged 5 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 37 additions & 35 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [master]
pull_request:
branches: [ master ]
branches: [master]

jobs:
build:
Expand All @@ -23,38 +23,40 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu

# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# Do not run for python 3.8 (mypy internal error)
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
Changelog
==========

Release 2.4.0a5 (WIP)
Release 2.4.0a6 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)

Bug Fixes:
^^^^^^^^^^
Expand Down
32 changes: 17 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ ignore = ["B028", "RUF013"]

[tool.ruff.lint.per-file-ignores]
# Default implementation in abstract methods
"./stable_baselines3/common/callbacks.py"= ["B027"]
"./stable_baselines3/common/noise.py"= ["B027"]
"./stable_baselines3/common/callbacks.py" = ["B027"]
"./stable_baselines3/common/noise.py" = ["B027"]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py"= ["RUF012", "RUF013"]
"./tests/*.py" = ["RUF012", "RUF013"]


[tool.ruff.lint.mccabe]
Expand All @@ -37,33 +37,35 @@ exclude = """(?x)(

[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
env = [
"PYTHONHASHSEED=0"
]
env = ["PYTHONHASHSEED=0"]

filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gymnasium warnings
"ignore::UserWarning:gymnasium",
# tqdm warning about rich being experimental
"ignore:rich is experimental"
"ignore:rich is experimental",
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
]

[tool.coverage.run]
disable_warnings = ["couldnt-parse"]
branch = false
omit = [
"tests/*",
"setup.py",
# Require graphical interface
"stable_baselines3/common/results_plotter.py",
# Require ffmpeg
"stable_baselines3/common/vec_env/vec_video_recorder.py",
"tests/*",
"setup.py",
# Require graphical interface
"stable_baselines3/common/results_plotter.py",
# Require ffmpeg
"stable_baselines3/common/vec_env/vec_video_recorder.py",
]

[tool.coverage.report]
exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
exclude_lines = [
"pragma: no cover",
"raise NotImplementedError()",
"if typing.TYPE_CHECKING:",
]
49 changes: 40 additions & 9 deletions stable_baselines3/common/torch_layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple, Type, Union
from typing import Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym
import torch as th
Expand All @@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module):
"""
Base class that represents a features extractor.

:param observation_space:
:param observation_space: The observation space of the environment
:param features_dim: Number of features extracted.
"""

Expand All @@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:

@property
def features_dim(self) -> int:
"""The number of features that the extractor outputs."""
return self._features_dim


Expand All @@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
Feature extract that flatten the input.
Used as a placeholder when feature extraction is not needed.

:param observation_space:
:param observation_space: The observation space of the environment
"""

def __init__(self, observation_space: gym.Space) -> None:
Expand All @@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor):
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529-533.

:param observation_space:
:param observation_space: The observation space of the environment
:param features_dim: Number of features extracted.
This corresponds to the number of unit for the last layer.
:param normalized_image: Whether to assume that the image is already normalized
Expand Down Expand Up @@ -113,13 +114,15 @@ def create_mlp(
activation_fn: Type[nn.Module] = nn.ReLU,
squash_output: bool = False,
with_bias: bool = True,
pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
post_linear_modules: Optional[List[Type[nn.Module]]] = None,
) -> List[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.

:param input_dim: Dimension of the input vector
:param output_dim:
:param output_dim: Dimension of the output (last layer, for instance, the number of actions)
:param net_arch: Architecture of the neural net
It represents the number of units per layer.
The length of this list is the number of layers.
Expand All @@ -128,20 +131,48 @@ def create_mlp(
:param squash_output: Whether to squash the output using a Tanh
activation function
:param with_bias: If set to False, the layers will not learn an additive bias
:return:
:param pre_linear_modules: List of nn.Module to add before the linear layers,
for instance, BatchNorm layers.
Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
The number of input features is passed to the module's constructor.
:param post_linear_modules: List of nn.Module to add after the linear layers (and before the activation function),
for instance, Dropout or LayerNorm layers.
They are not used after the output layer (output_dim > 0).
The number of input features is passed to the module's constructor.
:return: The list of layers of the neural network
araffin marked this conversation as resolved.
Show resolved Hide resolved
"""

pre_linear_modules = pre_linear_modules or []
post_linear_modules = post_linear_modules or []

modules = []
if len(net_arch) > 0:
modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()]
else:
modules = []
for module in pre_linear_modules:
modules.append(module(input_dim))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the input dim the same for all modules?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I get it. It only allows modules that have the same input/output dimension.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it is clearer with this test:

# assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU()
# nn.BatchNorm1d(6), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(),
# nn.BatchNorm1d(12), nn.Linear(12, 2)] # Last layer does not have post_linear

shall I add a comment to avoid confusion?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! I've also suggested a new documentation: https://github.com/DLR-RM/stable-baselines3/pull/1975/files#r1686213647

modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))

for module in post_linear_modules:
modules.append(module(net_arch[0]))

modules.append(activation_fn())

for idx in range(len(net_arch) - 1):
for module in pre_linear_modules:
modules.append(module(net_arch[idx]))

modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))

for module in post_linear_modules:
modules.append(module(net_arch[idx + 1]))

modules.append(activation_fn())

if output_dim > 0:
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
for module in pre_linear_modules:
modules.append(module(last_layer_dim))

modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
if squash_output:
modules.append(nn.Tanh())
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a5
2.4.0a6
56 changes: 56 additions & 0 deletions tests/test_custom_policy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import torch as th
import torch.nn as nn

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
from stable_baselines3.common.torch_layers import create_mlp


@pytest.mark.parametrize(
Expand Down Expand Up @@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer():
def test_dqn_custom_policy():
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
_ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)


def test_create_mlp():
net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True)
# We cannot compare the network directly because the modules have different ids
# assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2),
# nn.Tanh()]
assert len(net) == 6
assert isinstance(net[0], nn.Linear)
assert net[0].in_features == 4
assert net[0].out_features == 16
assert isinstance(net[1], nn.ReLU)
assert isinstance(net[2], nn.Linear)
assert isinstance(net[4], nn.Linear)
assert net[4].in_features == 8
assert net[4].out_features == 2
assert isinstance(net[5], nn.Tanh)

# Linear network
net = create_mlp(4, -1, net_arch=[])
assert net == []

# No output layer, with custom activation function
net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh)
# assert net == [nn.Linear(6, 8), nn.Tanh()]
assert len(net) == 2
assert isinstance(net[0], nn.Linear)
assert net[0].in_features == 6
assert net[0].out_features == 8
assert isinstance(net[1], nn.Tanh)

# Using pre-linear and post-linear modules
pre_linear = [nn.BatchNorm1d]
post_linear = [nn.LayerNorm]
net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear)
# assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU()
# nn.BatchNorm1d(6), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(),
# nn.BatchNorm1d(12), nn.Linear(12, 2)] # Last layer does not have post_linear
assert len(net) == 10
assert isinstance(net[0], nn.BatchNorm1d)
assert net[0].num_features == 6
assert isinstance(net[1], nn.Linear)
assert isinstance(net[2], nn.LayerNorm)
assert isinstance(net[3], nn.ReLU)
assert isinstance(net[4], nn.BatchNorm1d)
assert isinstance(net[5], nn.Linear)
assert net[5].in_features == 8
assert net[5].out_features == 12
assert isinstance(net[6], nn.LayerNorm)
assert isinstance(net[7], nn.ReLU)
assert isinstance(net[8], nn.BatchNorm1d)
assert isinstance(net[-1], nn.Linear)
assert net[-1].in_features == 12
assert net[-1].out_features == 2
Loading