[bug 🐛] weight decay incorrectly applied to LayerNorm and Mamba A, D parameters #5

Open
Niccolo-Ajroldi opened this issue Dec 5, 2024 · 0 comments · May be fixed by #6

Niccolo-Ajroldi commented Dec 5, 2024

Description

The current implementation applies weight decay to all model parameters.
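
For reference, the existing setup presumably builds the optimizer from a single parameter group, roughly like this (a hedged reconstruction, not the repo's exact code), which is why weight_decay reaches every parameter:

# Hedged reconstruction of the current behavior: a single parameter group,
# so weight decay is applied uniformly to every trainable parameter,
# including LayerNorm weights, biases, and Mamba's A_log / D.
optimizer = torch.optim.AdamW(
    self.model.parameters(),
    lr=self.mad_config.lr,
    weight_decay=self.mad_config.weight_decay,
)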

However:

  1. Mamba should not have weight decay on A_log and D: the reference implementation explicitly flags these parameters, e.g.
    self.A_log._no_weight_decay = True
    (see the sketch after this list).
  2. It's common practice not to apply weight decay to biases and normalization layers.
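
For context, here is a minimal sketch of how a Mamba-style mixer tags these parameters (shapes and names are illustrative, following the spirit of the reference mamba_ssm implementation, not this repo's code):

import torch
import torch.nn as nn

class MambaMixerSketch(nn.Module):
    """Illustrative only: shows how Mamba-style blocks flag A_log and D
    so that optimizer setup code can exclude them from weight decay."""
    def __init__(self, d_inner: int = 64, d_state: int = 16):
        super().__init__()
        # State matrix A is stored in log space; excluded from weight decay.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # Skip-connection parameter D, also excluded from weight decay.
        self.D = nn.Parameter(torch.ones(d_inner))
        self.D._no_weight_decay = True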

Fix

I think 1. is the more critical issue, but we should also address 2. to reflect standard practice in language modelling.

#6 implements a fix for both.

It creates two different param_groups for parameters with and without weight decay (see 5ab076a):

decay_params, no_decay_params = [], []
for n, p in self.model.named_parameters():
    if p.requires_grad:
        if not getattr(p, '_no_weight_decay', False) and ("bias" not in n) and ("norm" not in n):
            decay_params.append(p)
        else:
            no_decay_params.append(p)
param_groups = [
    {"params": decay_params, "weight_decay": self.mad_config.weight_decay},
    {"params": no_decay_params, "weight_decay": 0.0},
]

# optimizer:
if self.mad_config.optimizer == 'adamw':
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=self.mad_config.lr
    )
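
As a quick sanity check (not part of #6; just a sketch meant to run right after param_groups is built in the same method), one can verify that no bias, norm, or flagged parameter lands in the decay group:

# Hypothetical sanity check: the decay group must not contain any parameter
# that is flagged _no_weight_decay or whose name contains "bias" or "norm".
decay_ids = {id(p) for p in param_groups[0]["params"]}
for n, p in self.model.named_parameters():
    if getattr(p, '_no_weight_decay', False) or "bias" in n or "norm" in n:
        assert id(p) not in decay_ids, f"{n} should not receive weight decay"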

To distinguish normalization layers from other modules, I had to give them a name in the model initialization.
This is achieved by replacing:

self.unembed = nn.Sequential(norm(layer_cfg['dim']), nn.Linear(dim, vocab_size))

with:

self.unembed = nn.Sequential(OrderedDict([
    ('norm', norm(layer_cfg['dim'])),
    ('lm_head', nn.Linear(dim, vocab_size))
]))

and, analogously, wrapping each layer's norm in a named container:

self.model.append(nn.Sequential(OrderedDict([
    ('norm', norm(layer_cfg['dim'])),
    ('layer', layer(**layer_cfg))
])))
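
To see why the names matter for the "norm" not in n filter, here is a self-contained toy example (nn.LayerNorm stands in for the repo's norm): with the OrderedDict names, the substring "norm" shows up in every name yielded by named_parameters().

from collections import OrderedDict
import torch.nn as nn

# Toy module, not the repo's model: naming the submodule 'norm' makes
# "norm" appear in the parameter names the decay filter inspects.
unembed = nn.Sequential(OrderedDict([
    ('norm', nn.LayerNorm(16)),
    ('lm_head', nn.Linear(16, 100)),
]))
for name, _ in unembed.named_parameters():
    print(name)  # norm.weight, norm.bias, lm_head.weight, lm_head.bias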