
Commit

Add some comments
araffin committed Jul 22, 2024
1 parent 73984f2 commit b9b7ce2
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions stable_baselines3/common/torch_layers.py
@@ -135,10 +135,10 @@ def create_mlp(
        These modules should maintain the input tensor dimension (e.g. BatchNorm).
        The number of input features is passed to the module's constructor.
        Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
-    :param post_linear_modules: List of nn.Module to add after the linear layers
-        (and before the activation function). These modules should maintain the input
-        tensor dimension (e.g. Dropout, LayerNorm). They are not used after the
-        output layer (output_dim > 0). The number of input features is passed to
+    :param post_linear_modules: List of nn.Module to add after the linear layers
+        (and before the activation function). These modules should maintain the input
+        tensor dimension (e.g. Dropout, LayerNorm). They are not used after the
+        output layer (output_dim > 0). The number of input features is passed to
        the module's constructor.
    :return: The list of layers of the neural network
    """
@@ -148,11 +148,13 @@

    modules = []
    if len(net_arch) > 0:
+        # BatchNorm maintains input dim
        for module in pre_linear_modules:
            modules.append(module(input_dim))

        modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))

+        # LayerNorm, Dropout maintain output dim
        for module in post_linear_modules:
            modules.append(module(net_arch[0]))

@@ -171,6 +173,7 @@

    if output_dim > 0:
        last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
+        # Only add BatchNorm before output layer
        for module in pre_linear_modules:
            modules.append(module(last_layer_dim))


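A minimal usage sketch (not part of the commit), assuming only the create_mlp signature and behavior shown above; the dimensions and module choices are illustrative. Per the diff, pre_linear_modules (e.g. BatchNorm1d) are constructed with the current feature size and inserted before every linear layer, including the output one, while post_linear_modules (e.g. LayerNorm) are constructed with the layer's output size and inserted after each hidden linear layer, before the activation, but not after the output layer:

import torch
import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp

# create_mlp returns a list of layers; wrap it in nn.Sequential to get a module.
layers = create_mlp(
    input_dim=8,
    output_dim=2,
    net_arch=[64, 64],
    pre_linear_modules=[nn.BatchNorm1d],  # BatchNorm maintains input dim
    post_linear_modules=[nn.LayerNorm],   # LayerNorm maintains output dim
)
mlp = nn.Sequential(*layers)
out = mlp(torch.randn(16, 8))  # batch of 16 inputs -> output shape (16, 2)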