add vit, torch vanilla transformer to experiment models
plutonium-239 committed Apr 30, 2024
1 parent ad09f00 · commit 269ab70
Showing 1 changed file with 9 additions and 3 deletions.
experiments/util/models.py (12 changes: 9 additions & 3 deletions)
@@ -5,8 +5,8 @@
 from typing import List, Tuple
 
 import torchvision.models as tvm
-from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU, Sequential
-from transformers import AutoConfig, AutoModelForCausalLM
+from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU, Sequential, Transformer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForPreTraining
 
 from memsave_torch.nn import (
     MemSaveConv2d,
@@ -176,7 +176,7 @@ def _convrelupool_model1(num_blocks=5) -> Module:
 segmentation_models = prefix_in_pairs("memsave_", segmentation_models)
 models_without_norm = ["deepmodel", "vgg16"]
 models_without_norm = prefix_in_pairs("memsave_", models_without_norm)
-transformers_models = ["gpt2"]
+transformers_models = ["gpt2", "vit", "transformer"]
 transformers_models = prefix_in_pairs("memsave_", transformers_models)
 
 conv_model_fns = {
@@ -234,6 +234,12 @@ def _convrelupool_model1(num_blocks=5) -> Module:
"memsave_gpt2": lambda: convert_to_memory_saving(
AutoModelForCausalLM.from_pretrained("gpt2")
),
"vit": lambda: AutoModelForPreTraining.from_pretrained('facebook/vit-mae-base'),
"memsave_vit": lambda: convert_to_memory_saving(
AutoModelForPreTraining.from_pretrained('facebook/vit-mae-base')
),
"transformer": Transformer,
"memsave_transformer": lambda: convert_to_memory_saving(Transformer()),
# For paper
"memsave_resnet101_conv": lambda: convert_to_memory_saving_defaultsoff(
tvm.resnet101(), conv2d=True
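For context, a minimal usage sketch (not part of the commit) of how the newly registered entries could be instantiated. It assumes only what the diff shows: each conv_model_fns entry is a zero-argument callable that returns a model, and the import path mirrors the file's location in the repository.

# Hypothetical sketch: instantiate the baseline/memsave pairs registered
# above and compare parameter counts. Assumes conv_model_fns maps names to
# zero-argument factories, as in the diff.
from experiments.util.models import conv_model_fns

for name in ["vit", "memsave_vit", "transformer", "memsave_transformer"]:
    model = conv_model_fns[name]()  # "transformer" resolves to torch.nn.Transformer()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")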
