Skip to content

Commit

Permalink
demo add transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed May 1, 2024
1 parent 53b41c2 commit 14613c5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
43 changes: 34 additions & 9 deletions experiments/paper_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# repeat the experiment multiple times (generates multiple files to be aggregated by `get_best_results`)
n_repeat = 5

# CONV
# ============== CONV CONFIG ==============
# Valid choices for models are in models.conv_model_fns
models = [
"deepmodel",
Expand All @@ -46,17 +46,36 @@

# models = ["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"]
# models = ["resnet101", "memsave_resnet101_conv_full"]
models = ["gpt2"]
# models = prefix_in_pairs("memsave_", models)
# batch_size = 64
# input_channels = 3
# input_HW = 224
# num_classes = 1000
# device = "cuda"
# architecture = "conv"

# ============== TRANSFORMER CONFIG ==============
# Valid choices for models are in models.transformer_model_fns
models = [
"gpt2",
"bert",
"bart",
"roberta",
"t5",
"flan-t5",
"xlm-roberta",
"mistral-7b",
"llama3-8b",
]
models = prefix_in_pairs("memsave_", models)
# models = ["memsave_resnet101"]
batch_size = 64
input_channels = 3
input_HW = 224
num_classes = 1000
batch_size = 8
input_channels = 2048
input_HW = 256
num_classes = 5000
device = "cuda"
architecture = "conv"
architecture = "transformer"

# LINEAR
# ============== LINEAR CONFIG ==============
# Valid choices for models are in models.linear_model_fns
# models = ['deeplinearmodel']
# models += [f"memsave_{m}" for m in models] # add memsave versions for each model
Expand Down Expand Up @@ -90,6 +109,12 @@
"no_grad_linear_weights",
"no_grad_linear_bias",
],
[ # LINEAR
"no_grad_conv_weights",
"no_grad_conv_bias",
"no_grad_norm_weights",
"no_grad_norm_bias",
],
]


Expand Down
1 change: 1 addition & 0 deletions experiments/util/collect_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"grad_input + no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Input",
"no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Conv",
"no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias": "Norm",
"no_grad_conv_weights + no_grad_conv_bias + no_grad_norm_weights + no_grad_norm_bias": "Linear",
}


Expand Down

0 comments on commit 14613c5

Please sign in to comment.