From 2d64730e838605fac9f05a84f70774c14eef924b Mon Sep 17 00:00:00 2001
From: Plutonium-239
Date: Sun, 5 May 2024 20:13:17 +0530
Subject: [PATCH] format + minor

---
 experiments/paper_demo.py           | 10 +++--
 experiments/util/collect_results.py |  1 +
 experiments/util/models.py          | 64 ++++++++++++++++++-----------
 memsave_torch/nn/Linear.py          |  7 +++-
 memsave_torch/nn/__init__.py        |  7 +++-
 5 files changed, 57 insertions(+), 32 deletions(-)

diff --git a/experiments/paper_demo.py b/experiments/paper_demo.py
index 3672564..e8df7c1 100644
--- a/experiments/paper_demo.py
+++ b/experiments/paper_demo.py
@@ -57,9 +57,10 @@
 # ============== TRANSFORMER CONFIG ==============
 # Valid choices for models are in models.transformer_model_fns
 models = [
+    "transformer",
     "gpt2",
     "bert",
-    # "bart",
+    "bart",
     "roberta",
     "t5",
     "flan-t5",
@@ -109,9 +110,10 @@
         "no_grad_linear_weights",
         "no_grad_linear_bias",
     ],
-    [  # LINEAR
-        "no_grad_conv_weights",
-        "no_grad_conv_bias",
+    [  # LLM
+        "grad_input",
+        "no_grad_linear_weights",
+        "no_grad_linear_bias",
         "no_grad_norm_weights",
         "no_grad_norm_bias",
     ],
diff --git a/experiments/util/collect_results.py b/experiments/util/collect_results.py
index eeb930f..3cc2d6b 100644
--- a/experiments/util/collect_results.py
+++ b/experiments/util/collect_results.py
@@ -33,6 +33,7 @@
     "no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Conv",
     "no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias": "Norm",
     "no_grad_conv_weights + no_grad_conv_bias + no_grad_norm_weights + no_grad_norm_bias": "Linear",
+    "grad_input + no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "LLM",
 }
 
 
diff --git a/experiments/util/models.py b/experiments/util/models.py
index 4807b47..473696e 100644
--- a/experiments/util/models.py
+++ b/experiments/util/models.py
@@ -2,7 +2,8 @@
 
 import itertools
 import math
-from typing import List, Tuple
+from functools import partial
+from typing import Any, List, Tuple
 
 import torchvision.models as tvm
 from torch.nn import (
@@ -15,7 +16,14 @@
     Sequential,
     Transformer,
 )
-from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForPreTraining, AutoModel
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
+    AutoModelForPreTraining,
+    AutoModelForSeq2SeqLM,
+    BartForConditionalGeneration,
+)
 
 from memsave_torch.nn import (
     MemSaveConv2d,
@@ -296,18 +304,25 @@ def forward(self, loss_dict):
 
 # TRANSFORMER
 transformer_input_shape: Tuple[int, int] = (1, 1)  # (vocab_dim, embed_dim)
-# 'shortname': ('hf_repo_name', {extra_kwargs})
+
+
+class _HF_model:
+    def __init__(self, hf_name: str, extra_kwargs: dict = {}, model_cls: Any = AutoModelForCausalLM) -> None:
+        self.hf_name = hf_name
+        self.extra_kwargs = extra_kwargs
+        self.model_cls = model_cls
+
 hf_transformers_models_map = {
-    "gpt2": ("gpt2", {}),
-    "vit": ("facebook/vit-mae-base", {}),
-    "bert": ("google-bert/bert-base-uncased", {'is_decoder': True}),
-    "bart": ("facebook/bart-base", {}),
-    "roberta": ("FacebookAI/roberta-base", {'is_decoder': True}),
-    "t5": ("google-t5/t5-base", {}),
-    "flan-t5": ("google/flan-t5-base", {}),
-    # "xlm-roberta": ("FacebookAI/xlm-roberta-base", {}), Needs work
-    # "mistral-7b": ("mistralai/Mistral-7B-v0.1", {}), need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm
-    # "llama3-8b": ("meta-llama/Meta-Llama-3-8B", {}), GATED
+    "gpt2": _HF_model("gpt2", {}),
+    "vit": _HF_model("facebook/vit-mae-base", {}, AutoModelForPreTraining),
"bert": _HF_model("google-bert/bert-base-uncased", {"is_decoder": True}), + "bart": _HF_model("facebook/bart-base", {}, BartForConditionalGeneration), + "roberta": _HF_model("FacebookAI/roberta-base", {"is_decoder": True}), + "t5": _HF_model("google-t5/t5-base", {}, AutoModelForSeq2SeqLM), + "flan-t5": _HF_model("google/flan-t5-base", {}, AutoModelForSeq2SeqLM), + "xlm-roberta": _HF_model("FacebookAI/xlm-roberta-base", {}, AutoModelForMaskedLM), + # "mistral-7b": _HF_model("mistralai/Mistral-7B-v0.1", {}), need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm + # "llama3-8b": _HF_model("meta-llama/Meta-Llama-3-8B", {}), GATED } hf_transformers_models = list(hf_transformers_models_map.keys()) hf_transformers_models = prefix_in_pairs("memsave_", hf_transformers_models) @@ -315,25 +330,26 @@ def forward(self, loss_dict): transformer_model_fns = { "transformer": lambda: TorchTransformer(), "memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()), - "vit": lambda: AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0]), - "memsave_vit": lambda: convert_to_memory_saving( - AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0]) - ), } -from functools import partial -fused = lambda name, kwargs: convert_to_memory_saving(AutoModelForCausalLM.from_pretrained(name, **kwargs)) +fused = lambda cls, name, kwargs: convert_to_memory_saving( # noqa: E731 + cls.from_pretrained(name, **kwargs) +) for m in hf_transformers_models: if m in transformer_model_fns: continue # Can't use lambdas in loops :') - if not m.startswith('memsave_'): - hf_name, kwargs = hf_transformers_models_map[m] - transformer_model_fns[m] = partial(AutoModelForCausalLM.from_pretrained, hf_name, **kwargs) + if not m.startswith("memsave_"): + props = hf_transformers_models_map[m] + transformer_model_fns[m] = partial( + props.model_cls.from_pretrained, props.hf_name, **props.extra_kwargs + ) else: - hf_name, kwargs = hf_transformers_models_map[m.split('memsave_', 1)[1]] - transformer_model_fns[m] = partial(fused, hf_name, kwargs) + props = hf_transformers_models_map[m.split("memsave_", 1)[1]] + transformer_model_fns[m] = partial( + fused, props.model_cls, props.hf_name, **props.extra_kwargs + ) class TorchTransformer(Module): diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py index a59b3dc..1ba78c6 100644 --- a/memsave_torch/nn/Linear.py +++ b/memsave_torch/nn/Linear.py @@ -3,16 +3,19 @@ This is done by not saving the inputs/weights if weight/inputs dont require grad. 
""" -import torch.nn as nn import sys +import torch.nn as nn + from memsave_torch.nn.functional import linearMemSave transformers_imported = False -if 'transformers' in sys.modules: +if "transformers" in sys.modules: import transformers + transformers_imported = True + class MemSaveLinear(nn.Linear): """MemSaveLinear.""" diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py index 6f26d80..4fbc756 100644 --- a/memsave_torch/nn/__init__.py +++ b/memsave_torch/nn/__init__.py @@ -6,9 +6,10 @@ - BatchNorm2d """ -import torch.nn as nn import sys +import torch.nn as nn + from memsave_torch.nn import functional # noqa: F401 from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d from memsave_torch.nn.Conv1d import MemSaveConv1d @@ -20,10 +21,12 @@ from memsave_torch.nn.ReLU import MemSaveReLU transformers_imported = False -if 'transformers' in sys.modules: +if "transformers" in sys.modules: import transformers + transformers_imported = True + def convert_to_memory_saving( model: nn.Module, linear=True,