From 9084a25b4c0332afb8d85834a3b09af09d5e7baa Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Thu, 2 May 2024 17:31:37 +0530 Subject: [PATCH] fix for encoder-decoder models + avoid loading transformers in `memsave_torch.nn` --- experiments/paper_demo.py | 4 +-- experiments/util/models.py | 48 +++++++++++++++++++++--------------- memsave_torch/nn/Linear.py | 10 ++++++-- memsave_torch/nn/__init__.py | 11 +++++++-- 4 files changed, 47 insertions(+), 26 deletions(-) diff --git a/experiments/paper_demo.py b/experiments/paper_demo.py index 5767851..3672564 100644 --- a/experiments/paper_demo.py +++ b/experiments/paper_demo.py @@ -59,7 +59,7 @@ models = [ "gpt2", "bert", - "bart", + # "bart", "roberta", "t5", "flan-t5", @@ -68,7 +68,7 @@ # "llama3-8b", ] models = prefix_in_pairs("memsave_", models) -batch_size = 8 +batch_size = 64 input_channels = 2048 input_HW = 256 num_classes = 5000 diff --git a/experiments/util/models.py b/experiments/util/models.py index 331a5c5..4807b47 100644 --- a/experiments/util/models.py +++ b/experiments/util/models.py @@ -94,8 +94,8 @@ def get_transformers_config(model_name: str) -> AutoConfig: """ if model_name.startswith("memsave_"): model_name = model_name.split("memsave_")[1] - model_hf_name = hf_transformers_models_map[model_name] - return AutoConfig.from_pretrained(model_hf_name) + model_hf_name, kwargs = hf_transformers_models_map[model_name] + return AutoConfig.from_pretrained(model_hf_name, **kwargs) # CONV @@ -296,17 +296,18 @@ def forward(self, loss_dict): # TRANSFORMER transformer_input_shape: Tuple[int, int] = (1, 1) # (vocab_dim, embed_dim) +# 'shortname': ('hf_repo_name', {extra_kwargs}) hf_transformers_models_map = { - "gpt2": "gpt2", - "vit": "facebook/vit-mae-base", - "bert": "google-bert/bert-base-uncased", - "bart": "facebook/bart-base", - "roberta": "FacebookAI/roberta-base", - "t5": "google-t5/t5-base", - "flan-t5": "google/flan-t5-base", - "xlm-roberta": "FacebookAI/xlm-roberta-base", - # "mistral-7b": "mistralai/Mistral-7B-v0.1", need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm - # "llama3-8b": "meta-llama/Meta-Llama-3-8B", GATED + "gpt2": ("gpt2", {}), + "vit": ("facebook/vit-mae-base", {}), + "bert": ("google-bert/bert-base-uncased", {'is_decoder': True}), + "bart": ("facebook/bart-base", {}), + "roberta": ("FacebookAI/roberta-base", {'is_decoder': True}), + "t5": ("google-t5/t5-base", {}), + "flan-t5": ("google/flan-t5-base", {}), + # "xlm-roberta": ("FacebookAI/xlm-roberta-base", {}), Needs work + # "mistral-7b": ("mistralai/Mistral-7B-v0.1", {}), need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm + # "llama3-8b": ("meta-llama/Meta-Llama-3-8B", {}), GATED } hf_transformers_models = list(hf_transformers_models_map.keys()) hf_transformers_models = prefix_in_pairs("memsave_", hf_transformers_models) @@ -314,23 +315,25 @@ def forward(self, loss_dict): transformer_model_fns = { "transformer": lambda: TorchTransformer(), "memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()), - "vit": lambda: AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit']), + "vit": lambda: AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0]), "memsave_vit": lambda: convert_to_memory_saving( - AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit']) + AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0]) ), } from functools import partial -fused = lambda name: convert_to_memory_saving(AutoModelForCausalLM.from_pretrained(name)) +fused = lambda name, kwargs: convert_to_memory_saving(AutoModelForCausalLM.from_pretrained(name, **kwargs)) for m in hf_transformers_models: + if m in transformer_model_fns: + continue # Can't use lambdas in loops :') if not m.startswith('memsave_'): - hf_name = hf_transformers_models_map[m] - transformer_model_fns[m] = partial(AutoModelForCausalLM.from_pretrained, hf_name) + hf_name, kwargs = hf_transformers_models_map[m] + transformer_model_fns[m] = partial(AutoModelForCausalLM.from_pretrained, hf_name, **kwargs) else: - hf_name = hf_transformers_models_map[m.split('memsave_', 1)[1]] - transformer_model_fns[m] = partial(fused, hf_name) + hf_name, kwargs = hf_transformers_models_map[m.split('memsave_', 1)[1]] + transformer_model_fns[m] = partial(fused, hf_name, kwargs) class TorchTransformer(Module): @@ -364,6 +367,7 @@ def __init__(self, model_fn) -> None: """Init""" super().__init__() self.model = model_fn() + self.dec = self.model.config.is_encoder_decoder def forward(self, x): """Forward @@ -374,7 +378,11 @@ def forward(self, x): Returns: output: model output """ - return self.model(inputs_embeds=x, use_cache=False)["logits"].permute(0, 2, 1) + if self.dec: + out = self.model(inputs_embeds=x, decoder_inputs_embeds=x, use_cache=False) + else: + out = self.model(inputs_embeds=x, use_cache=False) + return out.logits.permute(0, 2, 1) # LINEAR diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py index 6da394d..a59b3dc 100644 --- a/memsave_torch/nn/Linear.py +++ b/memsave_torch/nn/Linear.py @@ -4,10 +4,14 @@ """ import torch.nn as nn -import transformers +import sys from memsave_torch.nn.functional import linearMemSave +transformers_imported = False +if 'transformers' in sys.modules: + import transformers + transformers_imported = True class MemSaveLinear(nn.Linear): """MemSaveLinear.""" @@ -45,7 +49,9 @@ def from_nn_Linear(cls, linear: nn.Linear): Returns: obj: The MemSaveLinear object """ - isTransformersConv1D = isinstance(linear, transformers.Conv1D) + isTransformersConv1D = False + if transformers_imported: + isTransformersConv1D = isinstance(linear, transformers.Conv1D) if isTransformersConv1D: # it only saves output features in the model (linear.nf); need to take input features from weight anyway # weight and bias are still defined diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py index fa96c1d..6f26d80 100644 --- a/memsave_torch/nn/__init__.py +++ b/memsave_torch/nn/__init__.py @@ -7,7 +7,7 @@ """ import torch.nn as nn -import transformers +import sys from memsave_torch.nn import functional # noqa: F401 from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d @@ -19,6 +19,10 @@ from memsave_torch.nn.MaxPool import MemSaveMaxPool2d from memsave_torch.nn.ReLU import MemSaveReLU +transformers_imported = False +if 'transformers' in sys.modules: + import transformers + transformers_imported = True def convert_to_memory_saving( model: nn.Module, @@ -55,10 +59,13 @@ def convert_to_memory_saving( Returns: memsavemodel (nn.Module): The converted memory saving model """ + linear_cls = nn.Linear + if transformers_imported: + linear_cls = (nn.Linear, transformers.Conv1D) layers = [ { "allowed": linear, - "cls": (nn.Linear, transformers.Conv1D), + "cls": linear_cls, "convert_fn": MemSaveLinear.from_nn_Linear, }, {"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU},