From 53b41c29e25dd264cab3a5892a2b192430e7014a Mon Sep 17 00:00:00 2001
From: Plutonium-239
Date: Thu, 2 May 2024 03:12:30 +0530
Subject: [PATCH] add many hf models

minor
---
 experiments/util/models.py | 33 +++++++++++++++++++++++----------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/experiments/util/models.py b/experiments/util/models.py
index de47240..e629ccd 100644
--- a/experiments/util/models.py
+++ b/experiments/util/models.py
@@ -296,23 +296,36 @@ def forward(self, loss_dict):
 
 # TRANSFORMER
 transformer_input_shape: Tuple[int, int] = (1, 1)  # (vocab_dim, embed_dim)
-hf_transformers_models = ["gpt2", "vit"]
+hf_transformers_models_map = {
+    "gpt2": "gpt2",
+    "vit": "facebook/vit-mae-base",
+    "bert": "google-bert/bert-base-uncased",
+    "bart": "facebook/bart-base",
+    "roberta": "FacebookAI/roberta-base",
+    "t5": "google-t5/t5-base",
+    "flan-t5": "google/flan-t5-base",
+    "xlm-roberta": "FacebookAI/xlm-roberta-base",
+    "mistral-7b": "mistralai/Mistral-7B-v0.1",
+    "llama3-8b": "meta-llama/Meta-Llama-3-8B",
+}
+hf_transformers_models = list(hf_transformers_models_map.keys())
 hf_transformers_models = prefix_in_pairs("memsave_", hf_transformers_models)
-hf_transformers_models_map = {"gpt2": "gpt2", "vit": "facebook/vit-mae-base"}
 
 transformer_model_fns = {
-    "gpt2": lambda: AutoModelForCausalLM.from_pretrained("gpt2"),
-    "memsave_gpt2": lambda: convert_to_memory_saving(
-        AutoModelForCausalLM.from_pretrained("gpt2")
-    ),
-    "vit": lambda: AutoModelForPreTraining.from_pretrained("facebook/vit-mae-base"),
-    "memsave_vit": lambda: convert_to_memory_saving(
-        AutoModelForPreTraining.from_pretrained("facebook/vit-mae-base")
-    ),
     "transformer": lambda: TorchTransformer(),
     "memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()),
 }
 
+for m in hf_transformers_models:  # NOTE: lambdas bind hf_name as a default to avoid late binding
+    if not m.startswith('memsave_'):
+        hf_name = hf_transformers_models_map[m]
+        transformer_model_fns[m] = lambda hf_name=hf_name: AutoModelForPreTraining.from_pretrained(hf_name)
+    else:
+        hf_name = hf_transformers_models_map[m.split('memsave_', 1)[1]]
+        transformer_model_fns[m] = lambda hf_name=hf_name: convert_to_memory_saving(
+            AutoModelForPreTraining.from_pretrained(hf_name)
+        )
+
 
 class TorchTransformer(Module):
     """Small model to wrap `torch.nn.Transformer`"""