Skip to content

Commit

Permalink
add bn eval option in demo instead of hardcoding
Browse files Browse the repository at this point in the history
(cherry picked from commit 164732d)
  • Loading branch information
plutonium-239 committed Aug 22, 2024
1 parent f3741de commit 1d85817
Showing 1 changed file with 65 additions and 56 deletions.
121 changes: 65 additions & 56 deletions experiments/paper_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from experiments.util.models import prefix_in_pairs

estimators = ["time", "memory"]
estimators = ["memory"]
# estimators = ["memory"]
# estimators = ["time"]

# improvements can be either speedups or savings based on context
Expand All @@ -27,70 +27,68 @@

# repeat the experiment multiple times (generates multiple files to be aggregated by `get_best_results`)
n_repeat = 5
batchnorm_eval = True # BatchNorm in eval mode

# CONV
# ============== CONV CONFIG ==============
# Valid choices for models are in models.conv_model_fns
# models = [
# "deepmodel",
# "resnet101",
# "resnet18",
# "vgg16", # "convnext_base",
# "fasterrcnn_resnet50_fpn_v2",
# "ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2",
# "deeplabv3_resnet101",
# "fcn_resnet101",
# "efficientnet_v2_l",
# "mobilenet_v3_large",
# "resnext101_64x4d",
# ]
# models = prefix_in_pairs("memsave_", models)
# # models = ["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"]
# batch_size = 64
# input_channels = 3
# input_HW = 224
# num_classes = 1000
# device = "cuda"
# architecture = "conv"
# cases = collect_results.select_cases(['All', 'Input', 'Conv', 'Norm', 'SurgicalFirst', 'SurgicalLast'])

# ============== TRANSFORMER CONFIG ==============
# Valid choices for models are in models.transformer_model_fns
models = [
"deepmodel",
"resnet101",
"resnet18",
"vgg16", # "convnext_base",
"fasterrcnn_resnet50_fpn_v2",
"ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2",
"deeplabv3_resnet101",
"fcn_resnet101",
"efficientnet_v2_l",
"mobilenet_v3_large",
"resnext101_64x4d",
"transformer",
"gpt2",
"bert",
"bart",
"roberta",
"t5",
"flan-t5",
# "xlm-roberta",
"mistral-7b",
"llama3-8b",
"phi3-4b",
]

# models = ["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"]
# models = ["resnet101", "memsave_resnet101_conv_full"]

models = prefix_in_pairs("memsave_", models)
# models = ["memsave_resnet101"]
batch_size = 64
input_channels = 3
input_HW = 224
num_classes = 1000
input_channels = 2048
input_HW = 256
num_classes = 5000
device = "cuda"
architecture = "conv"
architecture = "transformer"
cases = collect_results.select_cases(["All", "Input", "Norm"])

# LINEAR
# ============== LINEAR CONFIG ==============
# Valid choices for models are in models.linear_model_fns
# models = ['deeplinearmodel']
# models = ["deeplinearmodel"]
# models += [f"memsave_{m}" for m in models] # add memsave versions for each model
# batch_size = 32768
# input_channels = 3
# input_HW = 64
# batch_size = 1024
# input_channels = 3 # ignored
# input_HW = 64 # square of this is passed in estimate.py
# num_classes = 1000
# device = 'cuda'
# architecture = 'linear' # use high batch size

cases = [
None, # ALL
[ # INPUT
"grad_input",
"no_grad_conv_weights",
"no_grad_conv_bias",
"no_grad_linear_weights",
"no_grad_linear_bias",
"no_grad_norm_weights",
"no_grad_norm_bias",
],
[ # CONV
"no_grad_linear_weights",
"no_grad_linear_bias",
"no_grad_norm_weights",
"no_grad_norm_bias",
],
[ # NORM
"no_grad_conv_weights",
"no_grad_conv_bias",
"no_grad_linear_weights",
"no_grad_linear_bias",
],
]
# device = "cuda"
# architecture = "linear" # use high batch size
# cases = collect_results.select_cases(["All", "Input", "Linear"])


if __name__ == "__main__":
Expand All @@ -108,19 +106,30 @@
cases,
"results",
)
bn_eval_str = "--bn_eval" if batchnorm_eval else ""

for model in models:
B = batch_size
if model in prefix_in_pairs("memsave_", ["flan-t5"]):
B = 56
if model in prefix_in_pairs("memsave_", ["mistral-7b", "phi3-4b"]):
B = 16
if model in prefix_in_pairs("memsave_", ["llama3-8b"]):
B = 8
for estimate in estimators:
outputs = []

collector.clear_file(estimate)
for case in cases:
pbar.update()
pbar.set_description(f"{model} {estimate} case {case}")
case_display = collect_results.case_inv_mapping[
collect_results.make_case_str(case)
]
case_str = f"--case {' '.join(case)}" if case is not None else ""
pbar.set_description(f"{model} {estimate} case {case_display}")
cmd = (
f"python experiments/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} "
+ f"--device {device} -B {batch_size} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}"
+ f"--device {device} -B {B} -C_in {input_channels} -HW {input_HW} -n_class {num_classes} {bn_eval_str}"
)
proc = subprocess.run(shlex.split(cmd), capture_output=True)
assert proc.stderr in [
Expand Down

0 comments on commit 1d85817

Please sign in to comment.