From 6f86fb5edf02a52e7dd155c19ae58098e9103bcc Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Tue, 21 May 2024 14:43:45 +0530 Subject: [PATCH] MemSaveRMSNorm backward + RMSNorm tests + fix for different model dtypes --- experiments/paper_demo.py | 4 +- experiments/util/models.py | 3 ++ memsave_torch/nn/LayerNorm.py | 2 +- memsave_torch/nn/functional/LayerNorm.py | 27 +++++++++---- test/test_layers.py | 49 ++++++++++++++++++++++-- 5 files changed, 72 insertions(+), 13 deletions(-) diff --git a/experiments/paper_demo.py b/experiments/paper_demo.py index ffba1b6..3bf8460 100644 --- a/experiments/paper_demo.py +++ b/experiments/paper_demo.py @@ -132,7 +132,9 @@ collector.clear_file(estimate) for case in cases: pbar.update() - case_display = collect_results.case_mapping[collect_results.make_case_str(case)] + case_display = collect_results.case_mapping[ + collect_results.make_case_str(case) + ] case_str = f"--case {' '.join(case)}" if case is not None else "" pbar.set_description(f"{model} {estimate} case {case_display}") cmd = ( diff --git a/experiments/util/models.py b/experiments/util/models.py index 4901502..ff52fbb 100644 --- a/experiments/util/models.py +++ b/experiments/util/models.py @@ -433,6 +433,9 @@ def forward(self, x): Returns: output: model output """ + if self.model.dtype != torch.float32: + x = x.to(self.model.dtype) + # HF takes care of converting logits to float32 if self.dec: out = self.model(inputs_embeds=x, decoder_inputs_embeds=x, **self.cache_kw) else: diff --git a/memsave_torch/nn/LayerNorm.py b/memsave_torch/nn/LayerNorm.py index 6a65f0d..345ed2b 100644 --- a/memsave_torch/nn/LayerNorm.py +++ b/memsave_torch/nn/LayerNorm.py @@ -168,7 +168,7 @@ def from_existing(cls, ln): Returns: obj: The MemSaveRMSLayerNorm object """ - if ln.variance_epsilon is not None: # T5LayerNorm + if getattr(ln, "variance_epsilon", None) is not None: # T5LayerNorm ln.eps = ln.variance_epsilon obj = cls( ln.weight.shape, diff --git a/memsave_torch/nn/functional/LayerNorm.py b/memsave_torch/nn/functional/LayerNorm.py index afef4b1..960aab1 100644 --- a/memsave_torch/nn/functional/LayerNorm.py +++ b/memsave_torch/nn/functional/LayerNorm.py @@ -84,24 +84,24 @@ def layer_normMemSave( class _MemSaveRMSLayerNorm(torch.autograd.Function): @staticmethod - def forward(ctx, x, weight, variance_epsilon): + def forward(ctx, x, weight, eps): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 - rms_norm_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + variance_epsilon) + rms_norm_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) need_grad = [] if ctx.needs_input_grad[0]: need_grad.append(weight) - if ctx.needs_input_grad[1]: + if any(ctx.needs_input_grad): need_grad.append(x) - ctx.rms_norm_inv = rms_norm_inv + # ctx.rms_norm_inv = rms_norm_inv + ctx.eps = eps ctx.save_for_backward(*need_grad) - ctx.hidden_size = weight.shape # import ipdb; ipdb.set_trace() return weight * x * rms_norm_inv @@ -113,14 +113,27 @@ def backward(ctx, grad_output): grad_x, grad_weight = None, None current_idx = 0 + rms_norm_inv = None if ctx.needs_input_grad[0]: weight = ctx.saved_tensors[current_idx] current_idx += 1 - grad_x = grad_output * weight * ctx.rms_norm_inv + x = ctx.saved_tensors[current_idx] + x_sq_sum = x.pow(2).sum(-1, keepdims=True) + n = x.shape[-1] + rms_norm_inv = torch.rsqrt(x_sq_sum / n + ctx.eps) + grad_x = ( + grad_output + * weight + * (1 / n) + * rms_norm_inv.pow(3) + * (x_sq_sum - x * (weight * x).sum(dim=-1, keepdims=True)) + ) if ctx.needs_input_grad[1]: x = ctx.saved_tensors[current_idx] + if rms_norm_inv is None: + rms_norm_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps) current_idx += 1 - grad_weight = grad_output * x * ctx.rms_norm_inv + grad_weight = grad_output * x * rms_norm_inv return grad_x, grad_weight, None diff --git a/test/test_layers.py b/test/test_layers.py index 549e7cc..55e1078 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -4,6 +4,7 @@ import pytest import torch +import transformers import memsave_torch @@ -14,44 +15,82 @@ torch.manual_seed(239) cases = [ - {"layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 3)}, { + "name": "Linear1dims", + "layer_fn": lambda: torch.nn.Linear(3, 5), + "data_fn": lambda: torch.rand(7, 3), + }, + { + "name": "Linear2dims", "layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 12, 3), # weight sharing }, { + "name": "Linear3dims", "layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 12, 12, 3), # weight sharing }, { + "name": "Conv2d", "layer_fn": lambda: torch.nn.Conv2d(3, 5, 3), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, { + "name": "Conv1d", "layer_fn": lambda: torch.nn.Conv1d(3, 5, 3), "data_fn": lambda: torch.rand(7, 3, 12), }, { + "name": "BatchNorm2d", "layer_fn": lambda: torch.nn.BatchNorm2d(3), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, { + "name": "LayerNorm", "layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, { + "name": "RMSLayerNorm", + "layer_fn": lambda: memsave_torch.nn.RMSLayerNorm([3, 12, 12]), + "data_fn": lambda: torch.rand(7, 3, 12, 12), + }, + { + "name": "T5LayerNorm", + "layer_fn": lambda: transformers.models.t5.modeling_t5.T5LayerNorm([3, 12, 12]), + "data_fn": lambda: torch.rand(7, 3, 12, 12), + }, + { + "name": "MistralRMSNorm", + "layer_fn": lambda: transformers.models.mistral.modeling_mistral.MistralRMSNorm( + [3, 12, 12] + ), + "data_fn": lambda: torch.rand(7, 3, 12, 12), + }, + # TODO: add testing for dropout (save and load rng state) + # { + # "name": "Dropout" + # "layer_fn": lambda: torch.nn.Dropout(), + # "data_fn": lambda: torch.rand(7, 3, 12, 12), + # }, + { + "name": "MaxPool2d", "layer_fn": lambda: torch.nn.MaxPool2d(3), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, - {"layer_fn": lambda: torch.nn.ReLU(), "data_fn": 
lambda: torch.rand(7, 3, 12, 12)}, + { + "name": "ReLU", + "layer_fn": lambda: torch.nn.ReLU(), + "data_fn": lambda: torch.rand(7, 3, 12, 12), + }, ] @pytest.mark.quick -@pytest.mark.parametrize("case", cases) +@pytest.mark.parametrize("case", cases, ids=[case["name"] for case in cases]) @pytest.mark.parametrize("device", devices) def test_single_layer( - case: Dict[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]], + case: Dict[str, Union[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]]], device: str, ): """Runs tests for the layer_cls defined by `layer` @@ -84,6 +123,8 @@ def test_single_layer( elif device == "cuda": atol = 1e-5 rtol = 1e-4 + if "RMS" in case["name"] or case["name"] == "T5LayerNorm": + atol = 1e-4 assert torch.allclose(y1, y2, rtol=rtol, atol=atol) assert torch.allclose(x1.grad, x2.grad, rtol=rtol, atol=atol) for p1, p2 in zip(layer.parameters(), memsave_layer.parameters()):
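
Note (illustrative, not part of the diff above): the inputs_embeds cast added in
experiments/util/models.py amounts to matching the wrapper's input dtype to the
model's parameter dtype before the forward call. A minimal sketch, assuming a
Hugging Face model that exposes a `.dtype` property; the wrapper name is
hypothetical:

import torch

def call_with_matching_dtype(model, x, **cache_kw):
    # Half-precision models expect half-precision inputs_embeds;
    # HF takes care of returning float32 logits itself.
    if model.dtype != torch.float32:
        x = x.to(model.dtype)
    return model(inputs_embeds=x, **cache_kw)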
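
For reference, a minimal self-contained sketch of the recompute-in-backward idea
behind _MemSaveRMSLayerNorm: the forward stores only the input and weight, and
the backward rebuilds the inverse-RMS statistic and applies the standard RMSNorm
vector-Jacobian product. The names (_RMSNormRecompute, rms_norm_fn) are
illustrative, the weight is assumed 1-D over the last dimension as in
T5LayerNorm and MistralRMSNorm, and this is a sketch rather than the
memsave_torch implementation itself:

import torch


class _RMSNormRecompute(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, eps):
        # y = weight * x * rsqrt(mean(x^2, dim=-1) + eps); the rsqrt statistic
        # is cheap to rebuild, so only x and weight are saved for backward.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        ctx.eps = eps
        ctx.save_for_backward(x, weight)
        return weight * x * inv_rms

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        n = x.shape[-1]
        # Recompute the statistic instead of having stored it in forward.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps)
        grad_x = grad_weight = None
        if ctx.needs_input_grad[0]:
            # Standard RMSNorm VJP with r = (mean(x^2) + eps)^(-1/2):
            # dL/dx_j = g_j*w_j*r - (x_j/n) * r^3 * sum_i(g_i*w_i*x_i)
            gw = grad_output * weight
            grad_x = gw * inv_rms - x * inv_rms.pow(3) / n * (gw * x).sum(
                -1, keepdim=True
            )
        if ctx.needs_input_grad[1]:
            # dL/dw_i accumulates g_i*x_i*r over all leading (batch) dims.
            grad_weight = (grad_output * x * inv_rms).reshape(-1, n).sum(0)
        return grad_x, grad_weight, None


def rms_norm_fn(x, weight, eps=1e-6):
    return _RMSNormRecompute.apply(x, weight, eps)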
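
Beyond the forward/backward allclose comparison in test/test_layers.py,
torch.autograd.gradcheck can validate the analytic backward of such a custom
Function against finite differences in double precision. A sketch, using the
hypothetical rms_norm_fn from above:

import torch
from torch.autograd import gradcheck

torch.manual_seed(0)
# gradcheck compares analytic and numerical Jacobians; it needs float64
# inputs with requires_grad=True to be reliable.
x = torch.rand(4, 8, dtype=torch.float64, requires_grad=True)
w = torch.rand(8, dtype=torch.float64, requires_grad=True)

assert gradcheck(lambda xx, ww: rms_norm_fn(xx, ww, 1e-6), (x, w), atol=1e-4)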