Skip to content

Commit

Permalink
add many hf models
Browse files Browse the repository at this point in the history
minor
  • Loading branch information
plutonium-239 committed May 1, 2024
1 parent 14e532e commit 53b41c2
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions experiments/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,23 +296,36 @@ def forward(self, loss_dict):
# TRANSFORMER
transformer_input_shape: Tuple[int, int] = (1, 1)  # (vocab_dim, embed_dim)

# Short model name -> identifier handed to `from_pretrained` when the model
# factories are registered further down in this file.
# This is the single source of truth: it must not be rebound after the name
# list below has been derived from it, otherwise looking up any of the newer
# names (e.g. "bert") raises KeyError.
hf_transformers_models_map = {
    "gpt2": "gpt2",
    "vit": "facebook/vit-mae-base",
    "bert": "google-bert/bert-base-uncased",
    "bart": "facebook/bart-base",
    "roberta": "FacebookAI/roberta-base",
    "t5": "google-t5/t5-base",
    "flan-t5": "google/flan-t5-base",
    "xlm-roberta": "FacebookAI/xlm-roberta-base",
    "mistral-7b": "mistralai/Mistral-7B-v0.1",
    "llama3-8b": "meta-llama/Meta-Llama-3-8B",
}

# Derive the name list from the map so the two can never drift apart.
hf_transformers_models = list(hf_transformers_models_map)
# `prefix_in_pairs` is defined elsewhere in this file; the registration loop
# that consumes this list handles both plain and "memsave_"-prefixed names, so
# presumably each name is paired with its "memsave_" variant -- TODO confirm.
hf_transformers_models = prefix_in_pairs("memsave_", hf_transformers_models)

# Zero-argument factories keyed by model name. Nothing is constructed or
# downloaded here; the work happens only when a factory is called.
# NOTE(review): the loop that follows assigns into this dict for every name in
# `hf_transformers_models`, so entries sharing a key with that list are
# presumably replaced there -- confirm which definition is meant to win.
transformer_model_fns = {
    # GPT-2 loaded through the causal-LM auto class.
    "gpt2": lambda: AutoModelForCausalLM.from_pretrained("gpt2"),
    "memsave_gpt2": lambda: convert_to_memory_saving(
        AutoModelForCausalLM.from_pretrained("gpt2")
    ),
    # ViT-MAE loaded through the pre-training auto class.
    "vit": lambda: AutoModelForPreTraining.from_pretrained(
        "facebook/vit-mae-base"
    ),
    "memsave_vit": lambda: convert_to_memory_saving(
        AutoModelForPreTraining.from_pretrained("facebook/vit-mae-base")
    ),
    # Wrapper around `torch.nn.Transformer` defined later in this file.
    "transformer": lambda: TorchTransformer(),
    "memsave_transformer": lambda: convert_to_memory_saving(
        TorchTransformer()
    ),
}

# Register one zero-argument factory per Hugging Face model name, covering both
# the plain name and its "memsave_"-prefixed variant.
#
# `hf_name` is bound as a default argument on each lambda. A bare
# `lambda: ...from_pretrained(hf_name)` would look `hf_name` up only when it is
# called, i.e. after the loop has finished, so every factory would load the
# checkpoint of the *last* name in the list.
for m in hf_transformers_models:
    if not m.startswith("memsave_"):
        hf_name = hf_transformers_models_map[m]
        transformer_model_fns[m] = (
            lambda hf_name=hf_name: AutoModelForPreTraining.from_pretrained(hf_name)
        )
    else:
        # Strip the prefix to find the checkpoint shared with the plain name.
        hf_name = hf_transformers_models_map[m.split("memsave_", 1)[1]]
        transformer_model_fns[m] = lambda hf_name=hf_name: convert_to_memory_saving(
            AutoModelForPreTraining.from_pretrained(hf_name)
        )


class TorchTransformer(Module):
"""Small model to wrap `torch.nn.Transformer`"""
Expand Down

0 comments on commit 53b41c2

Please sign in to comment.