
Commit

fix for encoder-decoder models + avoid loading transformers in `memsave_torch.nn`
plutonium-239 committed May 2, 2024
1 parent a4ab2bd commit 9084a25
Showing 4 changed files with 47 additions and 26 deletions.
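The second half of the commit message is easy to check after the fact; a minimal sketch (hypothetical interactive session with memsave_torch installed and transformers not yet imported) of what the import guard introduced below is meant to guarantee:

```python
# After this commit, importing memsave_torch.nn should not pull transformers into
# the process; the package only touches transformers if the caller imported it first.
import sys

import memsave_torch.nn  # noqa: F401

print("transformers" in sys.modules)  # False unless the user imported it beforehand
```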
4 changes: 2 additions & 2 deletions experiments/paper_demo.py
@@ -59,7 +59,7 @@
models = [
"gpt2",
"bert",
"bart",
# "bart",
"roberta",
"t5",
"flan-t5",
@@ -68,7 +68,7 @@
# "llama3-8b",
]
models = prefix_in_pairs("memsave_", models)
batch_size = 8
batch_size = 64
input_channels = 2048
input_HW = 256
num_classes = 5000
48 changes: 28 additions & 20 deletions experiments/util/models.py
@@ -94,8 +94,8 @@ def get_transformers_config(model_name: str) -> AutoConfig:
"""
if model_name.startswith("memsave_"):
model_name = model_name.split("memsave_")[1]
model_hf_name = hf_transformers_models_map[model_name]
return AutoConfig.from_pretrained(model_hf_name)
model_hf_name, kwargs = hf_transformers_models_map[model_name]
return AutoConfig.from_pretrained(model_hf_name, **kwargs)


# CONV
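`get_transformers_config` now unpacks a `(hf_repo_name, extra_kwargs)` tuple and forwards the kwargs to `AutoConfig.from_pretrained`; a short sketch (Hugging Face Hub access assumed) of how such overrides behave:

```python
# Keyword overrides passed to AutoConfig.from_pretrained replace attributes on the
# loaded config; this is how the {'is_decoder': True} entries in the map below make
# BERT-style encoders loadable as decoders for causal-LM benchmarking.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google-bert/bert-base-uncased", is_decoder=True)
print(cfg.is_decoder)  # True
```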
@@ -296,41 +296,44 @@ def forward(self, loss_dict):
# TRANSFORMER
transformer_input_shape: Tuple[int, int] = (1, 1) # (vocab_dim, embed_dim)

# 'shortname': ('hf_repo_name', {extra_kwargs})
hf_transformers_models_map = {
"gpt2": "gpt2",
"vit": "facebook/vit-mae-base",
"bert": "google-bert/bert-base-uncased",
"bart": "facebook/bart-base",
"roberta": "FacebookAI/roberta-base",
"t5": "google-t5/t5-base",
"flan-t5": "google/flan-t5-base",
"xlm-roberta": "FacebookAI/xlm-roberta-base",
# "mistral-7b": "mistralai/Mistral-7B-v0.1", need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm
# "llama3-8b": "meta-llama/Meta-Llama-3-8B", GATED
"gpt2": ("gpt2", {}),
"vit": ("facebook/vit-mae-base", {}),
"bert": ("google-bert/bert-base-uncased", {'is_decoder': True}),
"bart": ("facebook/bart-base", {}),
"roberta": ("FacebookAI/roberta-base", {'is_decoder': True}),
"t5": ("google-t5/t5-base", {}),
"flan-t5": ("google/flan-t5-base", {}),
# "xlm-roberta": ("FacebookAI/xlm-roberta-base", {}), Needs work
# "mistral-7b": ("mistralai/Mistral-7B-v0.1", {}), need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm
# "llama3-8b": ("meta-llama/Meta-Llama-3-8B", {}), GATED
}
hf_transformers_models = list(hf_transformers_models_map.keys())
hf_transformers_models = prefix_in_pairs("memsave_", hf_transformers_models)

transformer_model_fns = {
"transformer": lambda: TorchTransformer(),
"memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()),
"vit": lambda: AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit']),
"vit": lambda: AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0]),
"memsave_vit": lambda: convert_to_memory_saving(
AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'])
AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0])
),
}

from functools import partial
fused = lambda name: convert_to_memory_saving(AutoModelForCausalLM.from_pretrained(name))
fused = lambda name, kwargs: convert_to_memory_saving(AutoModelForCausalLM.from_pretrained(name, **kwargs))

for m in hf_transformers_models:
if m in transformer_model_fns:
continue
# Can't use lambdas in loops :')
if not m.startswith('memsave_'):
hf_name = hf_transformers_models_map[m]
transformer_model_fns[m] = partial(AutoModelForCausalLM.from_pretrained, hf_name)
hf_name, kwargs = hf_transformers_models_map[m]
transformer_model_fns[m] = partial(AutoModelForCausalLM.from_pretrained, hf_name, **kwargs)
else:
hf_name = hf_transformers_models_map[m.split('memsave_', 1)[1]]
transformer_model_fns[m] = partial(fused, hf_name)
hf_name, kwargs = hf_transformers_models_map[m.split('memsave_', 1)[1]]
transformer_model_fns[m] = partial(fused, hf_name, kwargs)


class TorchTransformer(Module):
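The `# Can't use lambdas in loops` comment above refers to Python's late binding of closures, which is why the loop builds the model constructors with `functools.partial`; a tiny illustration of the pitfall (hypothetical names):

```python
# Lambdas close over the loop variable itself, not its value at definition time,
# so every entry would end up loading the last model in the list.
fns = {}
for m in ["gpt2", "bert"]:
    fns[m] = lambda: m      # late binding: all lambdas share the final value of m
print(fns["gpt2"]())        # prints 'bert' -- functools.partial avoids this
```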
@@ -364,6 +367,7 @@ def __init__(self, model_fn) -> None:
"""Init"""
super().__init__()
self.model = model_fn()
self.dec = self.model.config.is_encoder_decoder

def forward(self, x):
"""Forward
@@ -374,7 +378,11 @@
Returns:
output: model output
"""
return self.model(inputs_embeds=x, use_cache=False)["logits"].permute(0, 2, 1)
if self.dec:
out = self.model(inputs_embeds=x, decoder_inputs_embeds=x, use_cache=False)
else:
out = self.model(inputs_embeds=x, use_cache=False)
return out.logits.permute(0, 2, 1)


# LINEAR
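The new `forward` branches on `self.model.config.is_encoder_decoder`, since seq2seq models also need decoder inputs when driven purely by embeddings; a small sketch (Hub access assumed) of that flag for two models from the map:

```python
# Decoder-only models (GPT-2) accept inputs_embeds alone, while encoder-decoder
# models (T5, BART, ...) additionally require decoder_inputs_embeds, which the
# branch above supplies by reusing x.
from transformers import AutoConfig

for name in ("gpt2", "google-t5/t5-base"):
    cfg = AutoConfig.from_pretrained(name)
    print(name, cfg.is_encoder_decoder)  # gpt2 -> False, t5-base -> True
```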
10 changes: 8 additions & 2 deletions memsave_torch/nn/Linear.py
@@ -4,10 +4,14 @@
"""

import torch.nn as nn
import transformers
import sys

from memsave_torch.nn.functional import linearMemSave

transformers_imported = False
if 'transformers' in sys.modules:
import transformers
transformers_imported = True

class MemSaveLinear(nn.Linear):
"""MemSaveLinear."""
@@ -45,7 +49,9 @@ def from_nn_Linear(cls, linear: nn.Linear):
Returns:
obj: The MemSaveLinear object
"""
isTransformersConv1D = isinstance(linear, transformers.Conv1D)
isTransformersConv1D = False
if transformers_imported:
isTransformersConv1D = isinstance(linear, transformers.Conv1D)
if isTransformersConv1D:
# it only saves output features in the model (linear.nf); need to take input features from weight anyway
# weight and bias are still defined
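`from_nn_Linear` reads the input features from the weight because GPT-2-style `transformers.Conv1D` only stores its output features (`nf`) and keeps the weight transposed relative to `nn.Linear`; a quick sketch (transformers installed) of that layout:

```python
# transformers.Conv1D(nf, nx): nf output features, nx input features, weight stored
# as (nx, nf) -- the transpose of nn.Linear's (out_features, in_features).
import torch.nn as nn
import transformers

c = transformers.Conv1D(nf=12, nx=4)          # 4 input features -> 12 output features
print(c.nf, tuple(c.weight.shape))            # 12 (4, 12)  -- only nf is stored directly
print(tuple(nn.Linear(4, 12).weight.shape))   # (12, 4)
```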
11 changes: 9 additions & 2 deletions memsave_torch/nn/__init__.py
@@ -7,7 +7,7 @@
"""

import torch.nn as nn
import transformers
import sys

from memsave_torch.nn import functional # noqa: F401
from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
@@ -19,6 +19,10 @@
from memsave_torch.nn.MaxPool import MemSaveMaxPool2d
from memsave_torch.nn.ReLU import MemSaveReLU

transformers_imported = False
if 'transformers' in sys.modules:
import transformers
transformers_imported = True

def convert_to_memory_saving(
model: nn.Module,
@@ -55,10 +59,13 @@ def convert_to_memory_saving(
Returns:
memsavemodel (nn.Module): The converted memory saving model
"""
linear_cls = nn.Linear
if transformers_imported:
linear_cls = (nn.Linear, transformers.Conv1D)
layers = [
{
"allowed": linear,
"cls": (nn.Linear, transformers.Conv1D),
"cls": linear_cls,
"convert_fn": MemSaveLinear.from_nn_Linear,
},
{"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU},
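Putting both `__init__.py` changes together, converting a plain PyTorch model no longer touches transformers at all; a usage sketch (assuming the default conversion flags enable linear and ReLU layers):

```python
# Conversion of a transformers-free model: the sys.modules guard above means only
# nn.Linear is checked for the linear conversion, and no optional import happens.
import torch
import torch.nn as nn

from memsave_torch.nn import convert_to_memory_saving

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
memsave_model = convert_to_memory_saving(model)
print(memsave_model(torch.randn(8, 16)).shape)  # torch.Size([8, 10])
```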
