format + minor

plutonium-239 committed May 5, 2024
1 parent 9084a25 commit 2d64730
Showing 5 changed files with 57 additions and 32 deletions.
10 changes: 6 additions & 4 deletions experiments/paper_demo.py
@@ -57,9 +57,10 @@
# ============== TRANSFORMER CONFIG ==============
# Valid choices for models are in models.transformer_model_fns
models = [
    "transformer",
    "gpt2",
    "bert",
    # "bart",
    "bart",
    "roberta",
    "t5",
    "flan-t5",
@@ -109,9 +110,10 @@
"no_grad_linear_weights",
"no_grad_linear_bias",
],
[ # LINEAR
"no_grad_conv_weights",
"no_grad_conv_bias",
[ # LLM
"grad_input",
"no_grad_linear_weights",
"no_grad_linear_bias",
"no_grad_norm_weights",
"no_grad_norm_bias",
],
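The replaced LINEAR case targeted conv-net parameters; the new LLM case keeps the gradient with respect to the input while freezing linear and norm parameters, which is the pattern relevant to transformer runs. A minimal sketch of what these flags plausibly amount to, assuming they map to standard requires_grad toggling; the real flag handling lives elsewhere in experiments/util, so apply_llm_case is a hypothetical name:

import torch
from torch import nn


def apply_llm_case(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Illustrative only: freeze linear/norm parameters, keep grads w.r.t. the input."""
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.LayerNorm)):
            for p in module.parameters(recurse=False):
                p.requires_grad_(False)  # no_grad_linear_* / no_grad_norm_*
    return inputs.requires_grad_(True)  # grad_input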
1 change: 1 addition & 0 deletions experiments/util/collect_results.py
@@ -33,6 +33,7 @@
"no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Conv",
"no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias": "Norm",
"no_grad_conv_weights + no_grad_conv_bias + no_grad_norm_weights + no_grad_norm_bias": "Linear",
"grad_input + no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "LLM",
}


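The new entry labels the LLM case added in paper_demo.py; keys in this table are the case's flag list joined with " + ". A small sketch of the lookup under that assumption, where case_labels stands in for the dict above (its real name is outside this hunk):

from typing import List


def label_for_case(case: List[str], case_labels: dict) -> str:
    """Map a list of grad/no-grad flags to a short label such as 'LLM'."""
    return case_labels.get(" + ".join(case), "Unknown")


# label_for_case(
#     ["grad_input", "no_grad_linear_weights", "no_grad_linear_bias",
#      "no_grad_norm_weights", "no_grad_norm_bias"],
#     case_labels,
# ) == "LLM"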
64 changes: 40 additions & 24 deletions experiments/util/models.py
@@ -2,7 +2,8 @@

import itertools
import math
from typing import List, Tuple
from functools import partial
from typing import Any, List, Tuple

import torchvision.models as tvm
from torch.nn import (
@@ -15,7 +16,14 @@
    Sequential,
    Transformer,
)
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForPreTraining, AutoModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForPreTraining,
    AutoModelForSeq2SeqLM,
    BartForConditionalGeneration,
)

from memsave_torch.nn import (
    MemSaveConv2d,
@@ -296,44 +304,52 @@ def forward(self, loss_dict):
# TRANSFORMER
transformer_input_shape: Tuple[int, int] = (1, 1) # (vocab_dim, embed_dim)

# 'shortname': ('hf_repo_name', {extra_kwargs})

class _HF_model:
    def __init__(
        self,
        hf_name: str,
        extra_kwargs: dict = {},
        model_cls: Any = AutoModelForCausalLM,
    ) -> None:
        self.hf_name = hf_name
        self.extra_kwargs = extra_kwargs
        self.model_cls = model_cls


hf_transformers_models_map = {
    "gpt2": ("gpt2", {}),
    "vit": ("facebook/vit-mae-base", {}),
    "bert": ("google-bert/bert-base-uncased", {'is_decoder': True}),
    "bart": ("facebook/bart-base", {}),
    "roberta": ("FacebookAI/roberta-base", {'is_decoder': True}),
    "t5": ("google-t5/t5-base", {}),
    "flan-t5": ("google/flan-t5-base", {}),
    # "xlm-roberta": ("FacebookAI/xlm-roberta-base", {}), Needs work
    # "mistral-7b": ("mistralai/Mistral-7B-v0.1", {}), need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm
    # "llama3-8b": ("meta-llama/Meta-Llama-3-8B", {}), GATED
    "gpt2": _HF_model("gpt2", {}),
    "vit": _HF_model("facebook/vit-mae-base", {}, AutoModelForPreTraining),
    "bert": _HF_model("google-bert/bert-base-uncased", {"is_decoder": True}),
    "bart": _HF_model("facebook/bart-base", {}, BartForConditionalGeneration),
    "roberta": _HF_model("FacebookAI/roberta-base", {"is_decoder": True}),
    "t5": _HF_model("google-t5/t5-base", {}, AutoModelForSeq2SeqLM),
    "flan-t5": _HF_model("google/flan-t5-base", {}, AutoModelForSeq2SeqLM),
    "xlm-roberta": _HF_model("FacebookAI/xlm-roberta-base", {}, AutoModelForMaskedLM),
    # "mistral-7b": _HF_model("mistralai/Mistral-7B-v0.1", {}), need to add transformers.models.mistral.modeling_mistral.MistralRMSNorm
    # "llama3-8b": _HF_model("meta-llama/Meta-Llama-3-8B", {}), GATED
}
hf_transformers_models = list(hf_transformers_models_map.keys())
hf_transformers_models = prefix_in_pairs("memsave_", hf_transformers_models)

transformer_model_fns = {
    "transformer": lambda: TorchTransformer(),
    "memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()),
    "vit": lambda: AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0]),
    "memsave_vit": lambda: convert_to_memory_saving(
        AutoModelForPreTraining.from_pretrained(hf_transformers_models_map['vit'][0])
    ),
}

from functools import partial
fused = lambda name, kwargs: convert_to_memory_saving(AutoModelForCausalLM.from_pretrained(name, **kwargs))
fused = lambda cls, name, kwargs: convert_to_memory_saving(  # noqa: E731
    cls.from_pretrained(name, **kwargs)
)

for m in hf_transformers_models:
    if m in transformer_model_fns:
        continue
    # Can't use lambdas in loops :')
    if not m.startswith('memsave_'):
        hf_name, kwargs = hf_transformers_models_map[m]
        transformer_model_fns[m] = partial(AutoModelForCausalLM.from_pretrained, hf_name, **kwargs)
    if not m.startswith("memsave_"):
        props = hf_transformers_models_map[m]
        transformer_model_fns[m] = partial(
            props.model_cls.from_pretrained, props.hf_name, **props.extra_kwargs
        )
    else:
        props = hf_transformers_models_map[m.split("memsave_", 1)[1]]
        transformer_model_fns[m] = partial(
            fused, props.model_cls, props.hf_name, **props.extra_kwargs
        )


class TorchTransformer(Module):
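The loop registers factories with functools.partial rather than lambdas because of Python's late-binding closures, as the "Can't use lambdas in loops" comment hints: a lambda defined in a loop looks up the loop variable when called, not when defined, so every entry would end up building the last model in the list. A self-contained demonstration of the pitfall and the fix:

from functools import partial

makers = {}
for name in ["gpt2", "bert"]:
    makers[name] = lambda: name  # late binding: every lambda sees the final value
print(makers["gpt2"]())  # prints "bert"

makers = {}
for name in ["gpt2", "bert"]:
    makers[name] = partial(str.upper, name)  # argument bound at definition time
print(makers["gpt2"]())  # prints "GPT2"

With the partials registered, a caller presumably builds a model via transformer_model_fns["memsave_gpt2"](), which loads the pretrained weights and passes the model through convert_to_memory_saving.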
7 changes: 5 additions & 2 deletions memsave_torch/nn/Linear.py
@@ -3,16 +3,19 @@
This is done by not saving the inputs/weights if the weights/inputs don't require grad.
"""

import torch.nn as nn
import sys

import torch.nn as nn

from memsave_torch.nn.functional import linearMemSave

transformers_imported = False
if 'transformers' in sys.modules:
if "transformers" in sys.modules:
    import transformers

    transformers_imported = True


class MemSaveLinear(nn.Linear):
"""MemSaveLinear."""

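The sys.modules check makes transformers an optional dependency: the import only runs when the calling code has already imported transformers (in which case it is a cheap name binding, not a fresh import), so memsave_torch works without it installed. A sketch of how such a guard is typically consumed; the Conv1D check is an assumption about the HF-specific layer this module would special-case:

import sys

transformers_imported = False
if "transformers" in sys.modules:  # caller already imported it
    import transformers  # no re-import cost: module is already loaded

    transformers_imported = True


def is_hf_conv1d(layer) -> bool:
    """True only when transformers is available and `layer` is its GPT-2 style Conv1D."""
    return transformers_imported and isinstance(
        layer, transformers.pytorch_utils.Conv1D
    )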
7 changes: 5 additions & 2 deletions memsave_torch/nn/__init__.py
@@ -6,9 +6,10 @@
- BatchNorm2d
"""

import torch.nn as nn
import sys

import torch.nn as nn

from memsave_torch.nn import functional # noqa: F401
from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
from memsave_torch.nn.Conv1d import MemSaveConv1d
@@ -20,10 +21,12 @@
from memsave_torch.nn.ReLU import MemSaveReLU

transformers_imported = False
if 'transformers' in sys.modules:
if "transformers" in sys.modules:
    import transformers

    transformers_imported = True


def convert_to_memory_saving(
    model: nn.Module,
    linear=True,
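The same optional-import guard is repeated here so convert_to_memory_saving can recognize transformers layers when they are present. A minimal usage sketch of the converter, assuming the per-layer toggles visible in the signature (linear=True, ...) default to on; the remaining parameters are truncated by the diff:

import torch
import torch.nn as nn

from memsave_torch.nn import convert_to_memory_saving

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
model = convert_to_memory_saving(model)  # swaps supported layers for MemSave variants

x = torch.randn(8, 16)
y = model(x)  # same forward semantics, smaller backward storage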
