diff --git a/experiments/util/models.py b/experiments/util/models.py
index 4f1acbf..8acd553 100644
--- a/experiments/util/models.py
+++ b/experiments/util/models.py
@@ -5,8 +5,8 @@ from typing import List, Tuple

 import torchvision.models as tvm

-from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU, Sequential
-from transformers import AutoConfig, AutoModelForCausalLM
+from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU, Sequential, Transformer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForPreTraining

 from memsave_torch.nn import (
     MemSaveConv2d,
@@ -176,7 +176,7 @@ def _convrelupool_model1(num_blocks=5) -> Module:
 segmentation_models = prefix_in_pairs("memsave_", segmentation_models)
 models_without_norm = ["deepmodel", "vgg16"]
 models_without_norm = prefix_in_pairs("memsave_", models_without_norm)
-transformers_models = ["gpt2"]
+transformers_models = ["gpt2", "vit", "transformer"]
 transformers_models = prefix_in_pairs("memsave_", transformers_models)

 conv_model_fns = {
@@ -234,6 +234,12 @@ def _convrelupool_model1(num_blocks=5) -> Module:
     "memsave_gpt2": lambda: convert_to_memory_saving(
         AutoModelForCausalLM.from_pretrained("gpt2")
     ),
+    "vit": lambda: AutoModelForPreTraining.from_pretrained("facebook/vit-mae-base"),
+    "memsave_vit": lambda: convert_to_memory_saving(
+        AutoModelForPreTraining.from_pretrained("facebook/vit-mae-base")
+    ),
+    "transformer": Transformer,
+    "memsave_transformer": lambda: convert_to_memory_saving(Transformer()),
     # For paper
     "memsave_resnet101_conv": lambda: convert_to_memory_saving_defaultsoff(
         tvm.resnet101(), conv2d=True