Commit
MemSaveRMSNorm backward + RMSNorm tests + fix for different model dtypes
plutonium-239 committed May 21, 2024
1 parent 336a5d5 commit 6f86fb5
Showing 5 changed files with 72 additions and 13 deletions.
4 changes: 3 additions & 1 deletion experiments/paper_demo.py
@@ -132,7 +132,9 @@
collector.clear_file(estimate)
for case in cases:
pbar.update()
case_display = collect_results.case_mapping[collect_results.make_case_str(case)]
case_display = collect_results.case_mapping[
collect_results.make_case_str(case)
]
case_str = f"--case {' '.join(case)}" if case is not None else ""
pbar.set_description(f"{model} {estimate} case {case_display}")
cmd = (
3 changes: 3 additions & 0 deletions experiments/util/models.py
@@ -433,6 +433,9 @@ def forward(self, x):
Returns:
output: model output
"""
if self.model.dtype != torch.float32:
x = x.to(self.model.dtype)
# HF takes care of converting logits to float32
if self.dec:
out = self.model(inputs_embeds=x, decoder_inputs_embeds=x, **self.cache_kw)
else:
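The two added lines above cast the input embeddings to the model's own dtype before the forward call, so that a model loaded in fp16/bf16 does not receive fp32 inputs (per the comment, HF converts the logits back to float32 afterwards). A minimal sketch of the same idea in isolation; the wrapper class here is hypothetical and only assumes an HF-style model that accepts inputs_embeds:

import torch


class CastInputsWrapper(torch.nn.Module):
    """Hypothetical wrapper: match the input dtype to the wrapped model's dtype."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor):
        model_dtype = next(self.model.parameters()).dtype
        if model_dtype != torch.float32:
            # e.g. a bf16 model: cast the fp32 embeddings down before the call
            x = x.to(model_dtype)
        return self.model(inputs_embeds=x)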
2 changes: 1 addition & 1 deletion memsave_torch/nn/LayerNorm.py
@@ -168,7 +168,7 @@ def from_existing(cls, ln):
Returns:
obj: The MemSaveRMSLayerNorm object
"""
if ln.variance_epsilon is not None: # T5LayerNorm
if getattr(ln, "variance_epsilon", None) is not None: # T5LayerNorm
ln.eps = ln.variance_epsilon
obj = cls(
ln.weight.shape,
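The getattr guard above makes from_existing work with norm layers that store their epsilon under variance_epsilon (as T5LayerNorm does) as well as with layers that only expose eps. The attribute-lookup pattern in isolation, as a small hypothetical helper:

def resolve_eps(norm, default: float = 1e-6) -> float:
    """Hypothetical helper: prefer `variance_epsilon`, fall back to `eps`."""
    eps = getattr(norm, "variance_epsilon", None)
    if eps is None:
        eps = getattr(norm, "eps", default)
    return eps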
27 changes: 20 additions & 7 deletions memsave_torch/nn/functional/LayerNorm.py
@@ -84,24 +84,24 @@ def layer_normMemSave(

class _MemSaveRMSLayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, variance_epsilon):
def forward(ctx, x, weight, eps):
# T5 uses a layer_norm which only scales and doesn't shift, also known as Root Mean Square
# Layer Normalization (https://arxiv.org/abs/1910.07467): the variance is computed without
# subtracting the mean and there is no bias. Additionally, we want to make sure that the
# accumulation for half-precision inputs is done in fp32.

rms_norm_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + variance_epsilon)
rms_norm_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

need_grad = []
if ctx.needs_input_grad[0]:
need_grad.append(weight)
if ctx.needs_input_grad[1]:
if any(ctx.needs_input_grad):
need_grad.append(x)
ctx.rms_norm_inv = rms_norm_inv
# ctx.rms_norm_inv = rms_norm_inv
ctx.eps = eps

ctx.save_for_backward(*need_grad)

ctx.hidden_size = weight.shape
# import ipdb; ipdb.set_trace()

return weight * x * rms_norm_inv
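For reference, the gradients that the backward pass below has to reproduce (the standard RMSNorm derivation, stated in LaTeX with the code's symbols: r for rms_norm_inv, n for the size of the normalized dimension, g for grad_output):

y_k = w_k x_k r, \qquad r = \Big( \tfrac{1}{n} \sum_j x_j^2 + \varepsilon \Big)^{-1/2}

\frac{\partial L}{\partial x_k} = g_k w_k r - \frac{x_k r^3}{n} \sum_i g_i w_i x_i,
\qquad
\frac{\partial L}{\partial w_k} = g_k x_k r

The weight gradient is additionally summed over any batch dimensions that the weight was broadcast across.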
@@ -113,14 +113,27 @@ def backward(ctx, grad_output):
grad_x, grad_weight = None, None

current_idx = 0
rms_norm_inv = None
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
grad_x = grad_output * weight * ctx.rms_norm_inv
x = ctx.saved_tensors[current_idx]
rms_norm_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps)
# d/dx of (weight * x * r) with r = rsqrt(mean(x^2) + eps):
# grad_x = g*w*r - x * r^3 * mean(g*w*x, dim=-1)
grad_x = rms_norm_inv * (
grad_output * weight
- x
* rms_norm_inv.pow(2)
* (grad_output * weight * x).mean(dim=-1, keepdim=True)
)
if ctx.needs_input_grad[1]:
x = ctx.saved_tensors[current_idx]
if rms_norm_inv is None:
rms_norm_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps)
current_idx += 1
grad_weight = grad_output * x * ctx.rms_norm_inv
grad_weight = grad_output * x * rms_norm_inv

return grad_x, grad_weight, None

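A quick way to sanity-check the recomputation-based backward is torch.autograd.gradcheck in double precision; a minimal sketch, assuming the class can be imported from the module path shown above:

import torch

from memsave_torch.nn.functional.LayerNorm import _MemSaveRMSLayerNorm

x = torch.rand(4, 8, dtype=torch.double, requires_grad=True)
weight = torch.rand(8, dtype=torch.double, requires_grad=True)

# Compares the analytic grads returned by backward() against finite differences.
assert torch.autograd.gradcheck(_MemSaveRMSLayerNorm.apply, (x, weight, 1e-6))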
49 changes: 45 additions & 4 deletions test/test_layers.py
@@ -4,6 +4,7 @@

import pytest
import torch
import transformers

import memsave_torch

@@ -14,44 +15,82 @@
torch.manual_seed(239)

cases = [
{"layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 3)},
{
"name": "Linear1dims",
"layer_fn": lambda: torch.nn.Linear(3, 5),
"data_fn": lambda: torch.rand(7, 3),
},
{
"name": "Linear2dims",
"layer_fn": lambda: torch.nn.Linear(3, 5),
"data_fn": lambda: torch.rand(7, 12, 3), # weight sharing
},
{
"name": "Linear3dims",
"layer_fn": lambda: torch.nn.Linear(3, 5),
"data_fn": lambda: torch.rand(7, 12, 12, 3), # weight sharing
},
{
"name": "Conv2d",
"layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"name": "Conv1d",
"layer_fn": lambda: torch.nn.Conv1d(3, 5, 3),
"data_fn": lambda: torch.rand(7, 3, 12),
},
{
"name": "BatchNorm2d",
"layer_fn": lambda: torch.nn.BatchNorm2d(3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"name": "LayerNorm",
"layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"name": "RMSLayerNorm",
"layer_fn": lambda: memsave_torch.nn.RMSLayerNorm([3, 12, 12]),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"name": "T5LayerNorm",
"layer_fn": lambda: transformers.models.t5.modeling_t5.T5LayerNorm([3, 12, 12]),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"name": "MistralRMSNorm",
"layer_fn": lambda: transformers.models.mistral.modeling_mistral.MistralRMSNorm(
[3, 12, 12]
),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
# TODO: add testing for dropout (save and load rng state)
# {
# "name": "Dropout"
# "layer_fn": lambda: torch.nn.Dropout(),
# "data_fn": lambda: torch.rand(7, 3, 12, 12),
# },
{
"name": "MaxPool2d",
"layer_fn": lambda: torch.nn.MaxPool2d(3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{"layer_fn": lambda: torch.nn.ReLU(), "data_fn": lambda: torch.rand(7, 3, 12, 12)},
{
"name": "ReLU",
"layer_fn": lambda: torch.nn.ReLU(),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
]


@pytest.mark.quick
@pytest.mark.parametrize("case", cases)
@pytest.mark.parametrize("case", cases, ids=[case["name"] for case in cases])
@pytest.mark.parametrize("device", devices)
def test_single_layer(
case: Dict[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]],
case: Dict[str, Union[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]]],
device: str,
):
"""Runs tests for the layer_cls defined by `layer`
@@ -84,6 +123,8 @@ def test_single_layer(
elif device == "cuda":
atol = 1e-5
rtol = 1e-4
if "RMS" in case["name"] or case["name"] == "T5LayerNorm":
atol = 1e-4
assert torch.allclose(y1, y2, rtol=rtol, atol=atol)
assert torch.allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
for p1, p2 in zip(layer.parameters(), memsave_layer.parameters()):
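With the explicit ids added to the parametrization, individual layer cases can be selected by name; a small usage sketch (the -k expression is only an example):

import pytest

# Run only the RMS-norm related cases from this test file by their new ids.
pytest.main(["-q", "-k", "RMSLayerNorm or T5LayerNorm or MistralRMSNorm", "test/test_layers.py"])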
