From b9b7ce25e050c5588f4117fa229708622788bf95 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 22 Jul 2024 11:39:11 +0200 Subject: [PATCH] Add some comments --- stable_baselines3/common/torch_layers.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index f519bc9f5..234b91551 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -135,10 +135,10 @@ def create_mlp( These modules should maintain the input tensor dimension (e.g. BatchNorm). The number of input features is passed to the module's constructor. Compared to post_linear_modules, they are used before the output layer (output_dim > 0). - :param post_linear_modules: List of nn.Module to add after the linear layers - (and before the activation function). These modules should maintain the input - tensor dimension (e.g. Dropout, LayerNorm). They are not used after the - output layer (output_dim > 0). The number of input features is passed to + :param post_linear_modules: List of nn.Module to add after the linear layers + (and before the activation function). These modules should maintain the input + tensor dimension (e.g. Dropout, LayerNorm). They are not used after the + output layer (output_dim > 0). The number of input features is passed to the module's constructor. :return: The list of layers of the neural network """ @@ -148,11 +148,13 @@ def create_mlp( modules = [] if len(net_arch) > 0: + # BatchNorm maintains input dim for module in pre_linear_modules: modules.append(module(input_dim)) modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias)) + # LayerNorm, Dropout maintain output dim for module in post_linear_modules: modules.append(module(net_arch[0])) @@ -171,6 +173,7 @@ def create_mlp( if output_dim > 0: last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim + # Only add BatchNorm before output layer for module in pre_linear_modules: modules.append(module(last_layer_dim))