-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for pre and post linear modules in create_mlp
#1975
Conversation
else: | ||
modules = [] | ||
for module in pre_linear_modules: | ||
modules.append(module(input_dim)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the input dim the same for all modules?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I get it. It only allows modules that have the same input/output dimension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe it is clearer with this test:
stable-baselines3/tests/test_custom_policy.py
Lines 102 to 104 in 3b84f71
# assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU() | |
# nn.BatchNorm1d(6), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(), | |
# nn.BatchNorm1d(12), nn.Linear(12, 2)] # Last layer does not have post_linear |
shall I add a comment to avoid confusion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! I've also suggested a new documentation: https://github.com/DLR-RM/stable-baselines3/pull/1975/files#r1686213647
Co-authored-by: Quentin Gallouédec <[email protected]>
) * Add support for pre and post linear modules in `create_mlp` * Disable mypy for python 3.8 * Reformat toml file * Update docstring Co-authored-by: Quentin Gallouédec <[email protected]> * Add some comments --------- Co-authored-by: Quentin Gallouédec <[email protected]>
Description
Related to #1069 and #1036 and Stable-Baselines-Team/stable-baselines3-contrib#243 (DroQ and CrossQ)
I also added missing tests for
create_mlp
.Motivation and Context
closes #1069
Types of changes
Checklist
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)make doc
(required)Note: You can run most of the checks using
make commit-checks
.Note: we are using a maximum length of 127 characters per line