Add support for pre and post linear modules in create_mlp
araffin committed Jul 20, 2024
1 parent 1a69fc8 commit e036d7c
Showing 4 changed files with 99 additions and 11 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -3,14 +3,15 @@
Changelog
==========

Release 2.4.0a5 (WIP)
Release 2.4.0a6 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)

Bug Fixes:
^^^^^^^^^^
49 changes: 40 additions & 9 deletions stable_baselines3/common/torch_layers.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple, Type, Union
from typing import Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym
import torch as th
@@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module):
"""
Base class that represents a features extractor.
:param observation_space:
:param observation_space: The observation space of the environment
:param features_dim: Number of features extracted.
"""

Expand All @@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:

    @property
    def features_dim(self) -> int:
        """The number of features that the extractor outputs."""
        return self._features_dim


Expand All @@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
    Feature extractor that flattens the input.
    Used as a placeholder when feature extraction is not needed.

    :param observation_space:
    :param observation_space: The observation space of the environment
    """

    def __init__(self, observation_space: gym.Space) -> None:
Expand All @@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor):
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529-533.
:param observation_space:
:param observation_space: The observation space of the environment
:param features_dim: Number of features extracted.
This corresponds to the number of unit for the last layer.
:param normalized_image: Whether to assume that the image is already normalized
@@ -113,13 +114,15 @@ def create_mlp(
    activation_fn: Type[nn.Module] = nn.ReLU,
    squash_output: bool = False,
    with_bias: bool = True,
    pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
    post_linear_modules: Optional[List[Type[nn.Module]]] = None,
) -> List[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.
:param input_dim: Dimension of the input vector
:param output_dim:
:param output_dim: Dimension of the output (last layer, for instance, the number of actions)
:param net_arch: Architecture of the neural net
It represents the number of units per layer.
The length of this list is the number of layers.
@@ -128,20 +131,48 @@
    :param squash_output: Whether to squash the output using a Tanh
        activation function
    :param with_bias: If set to False, the layers will not learn an additive bias
    :return:
    :param pre_linear_modules: List of nn.Module to add before the linear layers,
        for instance, BatchNorm layers.
        Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
        The number of input features is passed to the module's constructor.
    :param post_linear_modules: List of nn.Module to add after the linear layers (and before the activation function),
        for instance, Dropout or LayerNorm layers.
        They are not used after the output layer (output_dim > 0).
        The number of input features is passed to the module's constructor.
    :return: The list of layers of the neural network
    """

    pre_linear_modules = pre_linear_modules or []
    post_linear_modules = post_linear_modules or []

    modules = []
    if len(net_arch) > 0:
        modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()]
    else:
        modules = []
        for module in pre_linear_modules:
            modules.append(module(input_dim))

        modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))

        for module in post_linear_modules:
            modules.append(module(net_arch[0]))

        modules.append(activation_fn())

    for idx in range(len(net_arch) - 1):
        for module in pre_linear_modules:
            modules.append(module(net_arch[idx]))

        modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))

        for module in post_linear_modules:
            modules.append(module(net_arch[idx + 1]))

        modules.append(activation_fn())

    if output_dim > 0:
        last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
        for module in pre_linear_modules:
            modules.append(module(last_layer_dim))

        modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
    if squash_output:
        modules.append(nn.Tanh())
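For reference, a minimal usage sketch of the new arguments (not part of this commit; the dimensions and module choices are illustrative): with nn.BatchNorm1d as a pre-linear module and nn.LayerNorm as a post-linear module, each hidden block becomes BatchNorm1d -> Linear -> LayerNorm -> ReLU, while the output layer only receives the pre-linear module, as documented above.

import torch as th
import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp

layers = create_mlp(
    input_dim=6,
    output_dim=1,
    net_arch=[256, 256],
    pre_linear_modules=[nn.BatchNorm1d],  # built with each layer's input dimension
    post_linear_modules=[nn.LayerNorm],  # built with each layer's output dimension
)
q_net = nn.Sequential(*layers)  # create_mlp returns a plain list of nn.Module

# BatchNorm1d expects batched input, so pass a batch of observations
print(q_net(th.randn(32, 6)).shape)  # torch.Size([32, 1])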
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
2.4.0a5
2.4.0a6
56 changes: 56 additions & 0 deletions tests/test_custom_policy.py
@@ -1,8 +1,10 @@
import pytest
import torch as th
import torch.nn as nn

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
from stable_baselines3.common.torch_layers import create_mlp


@pytest.mark.parametrize(
@@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer():
def test_dqn_custom_policy():
    policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
    _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)


def test_create_mlp():
    net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True)
    # We cannot compare the network directly because the modules have different ids
    # assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2),
    #                nn.Tanh()]
    assert len(net) == 6
    assert isinstance(net[0], nn.Linear)
    assert net[0].in_features == 4
    assert net[0].out_features == 16
    assert isinstance(net[1], nn.ReLU)
    assert isinstance(net[2], nn.Linear)
    assert isinstance(net[4], nn.Linear)
    assert net[4].in_features == 8
    assert net[4].out_features == 2
    assert isinstance(net[5], nn.Tanh)

    # Linear network
    net = create_mlp(4, -1, net_arch=[])
    assert net == []

    # No output layer, with custom activation function
    net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh)
    # assert net == [nn.Linear(6, 8), nn.Tanh()]
    assert len(net) == 2
    assert isinstance(net[0], nn.Linear)
    assert net[0].in_features == 6
    assert net[0].out_features == 8
    assert isinstance(net[1], nn.Tanh)

    # Using pre-linear and post-linear modules
    pre_linear = [nn.BatchNorm1d]
    post_linear = [nn.LayerNorm]
    net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear)
    # assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU(),
    #                nn.BatchNorm1d(8), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(),
    #                nn.BatchNorm1d(12), nn.Linear(12, 2)]  # Last layer does not have post_linear
    assert len(net) == 10
    assert isinstance(net[0], nn.BatchNorm1d)
    assert net[0].num_features == 6
    assert isinstance(net[1], nn.Linear)
    assert isinstance(net[2], nn.LayerNorm)
    assert isinstance(net[3], nn.ReLU)
    assert isinstance(net[4], nn.BatchNorm1d)
    assert isinstance(net[5], nn.Linear)
    assert net[5].in_features == 8
    assert net[5].out_features == 12
    assert isinstance(net[6], nn.LayerNorm)
    assert isinstance(net[7], nn.ReLU)
    assert isinstance(net[8], nn.BatchNorm1d)
    assert isinstance(net[-1], nn.Linear)
    assert net[-1].in_features == 12
    assert net[-1].out_features == 2
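A related caveat, shown here as a hedged sketch that is not part of this commit: every entry in post_linear_modules is instantiated with the layer's feature count, so a module like nn.Dropout, whose constructor expects a probability instead, needs a thin wrapper before it can be used for DroQ-style blocks (Linear -> Dropout -> LayerNorm -> ReLU); the wrapper class and dropout rate below are illustrative.

import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp


class FixedDropout(nn.Dropout):
    """Dropout that ignores the feature count passed in by create_mlp."""

    def __init__(self, num_features: int) -> None:
        # num_features is supplied by create_mlp; dropout only needs a probability
        super().__init__(p=0.01)


# DroQ-style hidden blocks: Linear -> Dropout -> LayerNorm -> ReLU
droq_layers = create_mlp(
    input_dim=6,
    output_dim=1,
    net_arch=[256, 256],
    post_linear_modules=[FixedDropout, nn.LayerNorm],
)
droq_net = nn.Sequential(*droq_layers)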
