merge transformers into main #8

Merged
merged 39 commits into main from transformers on Aug 22, 2024

Commits (39)
3893c6c
add llm trial code
plutonium-239 Apr 16, 2024
f970dfe
`transformers.Conv1D`
plutonium-239 Apr 16, 2024
9a9801d
fix inputs for transformers
plutonium-239 Apr 16, 2024
5a34b84
minor
plutonium-239 Apr 16, 2024
e3e5476
directly pass input embeddings
plutonium-239 Apr 16, 2024
281866f
format
plutonium-239 Apr 20, 2024
f0aa31f
replace `transformers.Conv1D` with `MemSaveLinear`
plutonium-239 Apr 22, 2024
63a4f7c
replace `transformers.Conv1D` with `MemSaveLinear`
plutonium-239 Apr 22, 2024
80d4270
add vit, torch vanilla transformer to experiment models
plutonium-239 Apr 30, 2024
e5f728d
change estimate - add `transformer` as architecture
plutonium-239 May 1, 2024
c6f8c8e
improve transformers config handling
plutonium-239 May 1, 2024
c471abc
minor
plutonium-239 May 1, 2024
b470f16
various fixes + formatting
plutonium-239 May 1, 2024
0641fe9
add many hf models
plutonium-239 May 1, 2024
d55c3db
demo add transformers
plutonium-239 May 1, 2024
f30f8e2
change from `lambda`s inside loop to `functool.partial`s
plutonium-239 May 2, 2024
b92ee08
fix for encoder-decoder models + avoid loading transformers in `memsa…
plutonium-239 May 2, 2024
30f8ed6
format + minor
plutonium-239 May 5, 2024
420c027
fix "Do not use mutable data structures for argument defaults"
plutonium-239 May 5, 2024
d519e34
fix
plutonium-239 May 5, 2024
f690d13
minor
plutonium-239 Aug 22, 2024
d7622d1
lm_head is non trainable
plutonium-239 May 20, 2024
02ed7fa
misc fixes, code for making latex table
plutonium-239 Aug 22, 2024
13897da
qol
plutonium-239 May 20, 2024
9b8e7f5
fix for different model dtypes
plutonium-239 Aug 22, 2024
b9d31cc
adjust batch size for mistral-7b
plutonium-239 May 21, 2024
8c281f6
minor: correct cache_kws
plutonium-239 May 21, 2024
69d2bd1
minor printing changes
plutonium-239 May 21, 2024
d12a2d9
update optional deps
plutonium-239 May 21, 2024
f976b61
minor
plutonium-239 May 22, 2024
b057db2
new implementation of dropout
plutonium-239 Aug 22, 2024
f256f93
adding llama3 and phi3
plutonium-239 Aug 22, 2024
7cde54b
refactor cases to refer to a central location
plutonium-239 Aug 22, 2024
51537c6
format + prep for merge
plutonium-239 Aug 22, 2024
9ae8507
minor fixes
plutonium-239 Aug 22, 2024
b36baa6
add VLM architecture in estimate, refactor cases to refer to a centra…
plutonium-239 Jun 1, 2024
3d5ec71
format + prep for merge
plutonium-239 Jun 1, 2024
b78d341
Merge branch 'transformers' into vlms
plutonium-239 Aug 22, 2024
4d489fa
Merge pull request #7 from plutonium-239/vlms
plutonium-239 Aug 22, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,8 @@
 *.txt
 *.csv
 !requirements.txt
+torchviz-output/
+torchview-output/
 
 # generated docs
 docs_src/_build/
18 changes: 13 additions & 5 deletions experiments/get_best_results.py
@@ -7,7 +7,7 @@
 
 import pandas as pd
 
-from experiments.util.collect_results import case_mapping
+from experiments.util.collect_results import case_inv_mapping
 
 
 def main(base_dir: str):
@@ -16,17 +16,25 @@ def main(base_dir: str):
     Args:
         base_dir (str): The base results dir
     """
-    for device, arch in product(["cuda", "cpu"], ["linear", "conv"]):
+    # Don't recognize None as NaN
+    custom_na_values = pd._libs.parsers.STR_NA_VALUES - {"None"}
+    for device, arch in product(["cuda", "cpu"], ["linear", "conv", "transformer"]):
         # usage stats
         df = None
         idx_col = ["model", "case"]
         for fname in glob(os.path.join(base_dir, f"usage_stats-{arch}-{device}-*.csv")):
             with open(fname) as f:
-                f.readline()
-                temp_df = pd.read_csv(f, index_col=idx_col)
+                # f.readline()
+                temp_df = pd.read_csv(
+                    f,
+                    index_col=idx_col,
+                    header=1,
+                    na_values=custom_na_values,
+                    keep_default_na=False,
+                )
                 df = temp_df if df is None else pd.concat([df, temp_df])
         if df is not None:
-            df = df.rename(index=case_mapping, level=1)
+            df = df.rename(index=case_inv_mapping, level=1)
             df["Memory Usage (GB)"] = df["Memory Usage (MB)"] / 1024
             df = df.drop(columns=["Memory Usage (MB)"])
             best_results = df.groupby(idx_col).min()
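A note on the `read_csv` changes above: pandas treats the literal string "None" as NaN by default, which destroys the name of the all-parameters case in the index. A minimal sketch of the behavior the hunk works around, using the same private `STR_NA_VALUES` set the patch uses:

```python
import io

import pandas as pd

raw = "model,case,Memory Usage (MB)\nresnet101,None,1024\n"

# With the default NA handling, the case name "None" is parsed as NaN.
df_default = pd.read_csv(io.StringIO(raw))
print(df_default["case"].isna().tolist())  # [True]

# Rebuilding the default NA set without "None" and disabling the
# built-in defaults keeps "None" as an ordinary string.
custom_na_values = pd._libs.parsers.STR_NA_VALUES - {"None"}
df_fixed = pd.read_csv(
    io.StringIO(raw), na_values=custom_na_values, keep_default_na=False
)
print(df_fixed["case"].tolist())  # ['None']
```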
106 changes: 56 additions & 50 deletions experiments/paper_demo.py
@@ -14,7 +14,7 @@
 from experiments.util.models import prefix_in_pairs
 
 estimators = ["time", "memory"]
-estimators = ["memory"]
+# estimators = ["memory"]
 # estimators = ["time"]
 
 # improvements can be either speedups or savings based on context
@@ -28,35 +28,55 @@
 # repeat the experiment multiple times (generates multiple files to be aggregated by `get_best_results`)
 n_repeat = 5
 
-# CONV
+# ============== CONV CONFIG ==============
 # Valid choices for models are in models.conv_model_fns
+# models = [
+#     "deepmodel",
+#     "resnet101",
+#     "resnet18",
+#     "vgg16", # "convnext_base",
+#     "fasterrcnn_resnet50_fpn_v2",
+#     "ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2",
+#     "deeplabv3_resnet101",
+#     "fcn_resnet101",
+#     "efficientnet_v2_l",
+#     "mobilenet_v3_large",
+#     "resnext101_64x4d",
+# ]
+# models = prefix_in_pairs("memsave_", models)
+# batch_size = 64
+# input_channels = 3
+# input_HW = 224
+# num_classes = 1000
+# device = "cuda"
+# architecture = "conv"
+# cases = collect_results.select_cases(['All', 'Input', 'Conv', 'Norm'])
+
+# ============== TRANSFORMER CONFIG ==============
+# Valid choices for models are in models.transformer_model_fns
 models = [
-    "deepmodel",
-    "resnet101",
-    "resnet18",
-    "vgg16", # "convnext_base",
-    "fasterrcnn_resnet50_fpn_v2",
-    "ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2",
-    "deeplabv3_resnet101",
-    "fcn_resnet101",
-    "efficientnet_v2_l",
-    "mobilenet_v3_large",
-    "resnext101_64x4d",
+    "transformer",
+    "gpt2",
+    "bert",
+    "bart",
+    "roberta",
+    "t5",
+    "flan-t5",
+    # "xlm-roberta",
+    "mistral-7b",
+    "llama3-8b",
+    "phi3-4b",
 ]
 
-# models = ["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"]
-# models = ["resnet101", "memsave_resnet101_conv_full"]
 
 models = prefix_in_pairs("memsave_", models)
-# models = ["memsave_resnet101"]
 batch_size = 64
-input_channels = 3
-input_HW = 224
-num_classes = 1000
+input_channels = 2048
+input_HW = 256
+num_classes = 5000
 device = "cuda"
-architecture = "conv"
+architecture = "transformer"
+cases = collect_results.select_cases(["All", "Input", "Norm"])
 
-# LINEAR
+# ============== LINEAR CONFIG ==============
 # Valid choices for models are in models.linear_model_fns
 # models = ['deeplinearmodel']
 # models += [f"memsave_{m}" for m in models] # add memsave versions for each model
@@ -66,31 +86,7 @@
 # num_classes = 1000
 # device = 'cuda'
 # architecture = 'linear' # use high batch size
 
-cases = [
-    None, # ALL
-    [ # INPUT
-        "grad_input",
-        "no_grad_conv_weights",
-        "no_grad_conv_bias",
-        "no_grad_linear_weights",
-        "no_grad_linear_bias",
-        "no_grad_norm_weights",
-        "no_grad_norm_bias",
-    ],
-    [ # CONV
-        "no_grad_linear_weights",
-        "no_grad_linear_bias",
-        "no_grad_norm_weights",
-        "no_grad_norm_bias",
-    ],
-    [ # NORM
-        "no_grad_conv_weights",
-        "no_grad_conv_bias",
-        "no_grad_linear_weights",
-        "no_grad_linear_bias",
-    ],
-]
+# cases = collect_results.select_cases(['All', 'Input', 'Linear'])
 
 
 if __name__ == "__main__":
@@ -110,17 +106,27 @@
     )
 
     for model in models:
+        B = batch_size
+        if model in prefix_in_pairs("memsave_", ["flan-t5"]):
+            B = 56
+        if model in prefix_in_pairs("memsave_", ["mistral-7b", "phi3-4b"]):
+            B = 16
+        if model in prefix_in_pairs("memsave_", ["llama3-8b"]):
+            B = 8
         for estimate in estimators:
             outputs = []
 
             collector.clear_file(estimate)
             for case in cases:
                 pbar.update()
-                pbar.set_description(f"{model} {estimate} case {case}")
+                case_display = collect_results.case_inv_mapping[
+                    collect_results.make_case_str(case)
+                ]
                 case_str = f"--case {' '.join(case)}" if case is not None else ""
+                pbar.set_description(f"{model} {estimate} case {case_display}")
                 cmd = (
                     f"python experiments/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} "
-                    + f"--device {device} -B {batch_size} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}"
+                    + f"--device {device} -B {B} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}"
                 )
                 proc = subprocess.run(shlex.split(cmd), capture_output=True)
                 assert proc.stderr in [
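The per-model batch-size overrides above test membership via `prefix_in_pairs("memsave_", [...])`, which is imported from `experiments.util.models` and not shown in this diff. Judging by its usage, it is assumed to pair each model name with its `memsave_`-prefixed variant; a hypothetical sketch:

```python
from typing import List


def prefix_in_pairs(prefix: str, models: List[str]) -> List[str]:
    """Hypothetical reimplementation: interleave each model with its
    prefixed variant, e.g. ["gpt2"] -> ["gpt2", "memsave_gpt2"]."""
    return [m for model in models for m in (model, prefix + model)]


# Both the plain and the memsave variant then get the reduced batch size:
assert "flan-t5" in prefix_in_pairs("memsave_", ["flan-t5"])
assert "memsave_flan-t5" in prefix_in_pairs("memsave_", ["flan-t5"])
```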
56 changes: 49 additions & 7 deletions experiments/util/collect_results.py
@@ -27,14 +27,50 @@
     ],
 }
 
-case_mapping = {
-    "None": "All",
-    "grad_input + no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Input",
-    "no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Conv",
-    "no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias": "Norm",
+cases = {
+    "All": None, # ALL
+    "Input": [ # INPUT
+        "grad_input",
+        "no_grad_conv_weights",
+        "no_grad_conv_bias",
+        "no_grad_linear_weights",
+        "no_grad_linear_bias",
+        "no_grad_norm_weights",
+        "no_grad_norm_bias",
+    ],
+    "Conv": [ # CONV
+        "no_grad_linear_weights",
+        "no_grad_linear_bias",
+        "no_grad_norm_weights",
+        "no_grad_norm_bias",
+    ],
+    "Linear": [ # LINEAR
+        "no_grad_conv_weights",
+        "no_grad_conv_bias",
+        "no_grad_norm_weights",
+        "no_grad_norm_bias",
+    ],
+    "Norm": [ # NORM
+        "no_grad_conv_weights",
+        "no_grad_conv_bias",
+        "no_grad_linear_weights",
+        "no_grad_linear_bias",
+    ],
 }
 
 
+def select_cases(selected: List[str]) -> List[Union[List[str], None]]:
+    """Helper function to return cases selected by their names
+
+    Args:
+        selected (List[str]): Which cases to select, strings can be keys of the cases table
+
+    Returns:
+        List[Union[List[str], None]]: Selected cases
+    """
+    return [cases[s] for s in selected]
+
+
 def make_case_str(case: Union[None, List[str]]) -> str:
     """Format case into a string
 
@@ -47,6 +83,9 @@ def make_case_str(case: Union[None, List[str]]) -> str:
     return "None" if case is None else " + ".join(case)
 
 
+case_inv_mapping = {make_case_str(v): k for k, v in cases.items()}
+
+
 def hyperparam_str(args: SimpleNamespace) -> str:
     """Format hyperparams into a string
 
@@ -172,12 +211,15 @@ def _display_run(
     """
     # print(f"{model} input ({input_channels},{input_HW},{input_HW}) {device}")
     # print('='*78)
-    s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}"
+    if self.architecture == "conv":
+        s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}"
+    elif self.architecture == "transformer":
+        s = f"{model} input ({self.batch_size},{self.input_HW},{self.input_channels}(or model hidden size)) {self.device}"
     print(s.center(78, "="))
 
     for out, case in zip(outputs, self.cases):
         print(
-            f"{strings[estimate][1]} ({case_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}"
+            f"{strings[estimate][1]} ({case_inv_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}"
         )
 
     # CODE ONLY APPLIES WITH OLD RUNDEMO.PY
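The `cases` table, `select_cases`, `make_case_str`, and the derived `case_inv_mapping` round-trip between display names and gradient-flag lists. A small usage sketch based only on the definitions in this hunk:

```python
from experiments.util.collect_results import (
    case_inv_mapping,
    cases,
    make_case_str,
    select_cases,
)

# Display names -> flag lists; "All" maps to None (all parameters get gradients).
flags = select_cases(["All", "Norm"])
assert flags[0] is None
assert flags[1] == cases["Norm"]

# make_case_str serializes a flag list to "a + b + ..." (or "None"), and
# case_inv_mapping inverts that back to the display name, as _display_run does.
assert make_case_str(flags[0]) == "None"
assert case_inv_mapping[make_case_str(flags[1])] == "Norm"
assert case_inv_mapping["None"] == "All"
```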
51 changes: 48 additions & 3 deletions experiments/util/estimate.py
@@ -38,6 +38,8 @@
     "no_grad_norm_bias",
     "grad_input",
     "no_grad_input",
+    "grad_embed_weights",
+    "no_grad_embed_weights",
 ]
 
 
@@ -62,7 +64,10 @@ def parse_case(case: Optional[List[str]]) -> Dict[str, bool]:
 
 
 def skip_case_check(args: argparse.Namespace) -> bool:
-    """Decide whether to skip the case (when case has grad_norm_* but model does not have any normalization layers)
+    """Decide whether to skip the case:
+
+    1. when case has grad_norm_* but model does not have any normalization layers
+    2. when case has no_grad_embed_weights but no grad_input: there is a backward error (no input requires_grad)
 
     Args:
         args (argparse.Namespace): args
@@ -73,12 +78,16 @@ def skip_case_check(args: argparse.Namespace) -> bool:
     invalid = False
     if args.case is None:
         return invalid
+    # 1.
     for c in ["grad_norm_bias", "grad_norm_weights"]:
         if c in args.case and args.model in models.models_without_norm:
             invalid = True
     for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
         if c not in args.case and args.model in models.models_without_norm:
             invalid = True
+    # 2.
+    if "no_grad_embed_weights" in args.case and "grad_input" not in args.case:
+        invalid = True
     if invalid:
         if args.print:
             print("-1")
@@ -226,7 +235,7 @@ def estimate_mem_savings(
         type=str,
         required=True,
         help="Which architecture to run",
-        choices=["conv", "linear"],
+        choices=["conv", "linear", "transformer", "VLM"],
     )
     parser.add_argument(
         "--estimate",
@@ -269,23 +278,59 @@
         input_shape = (args.input_channels, args.input_hw, args.input_hw)
         models.conv_input_shape = input_shape
         model_fn = models.conv_model_fns.get(args.model)
+        y_args = {"size": (batch_size,), "low": 0, "high": num_classes}
         assert (
             model_fn is not None
         ), f"Conv model name {args.model} not found, must be one of {list(models.conv_model_fns.keys())}"
     elif args.architecture == "linear":
         input_shape = [args.input_hw**2]
         models.linear_input_shape = input_shape[0]
         model_fn = models.linear_model_fns.get(args.model)
+        y_args = {"size": (batch_size,), "low": 0, "high": num_classes}
         assert (
             model_fn is not None
         ), f"Linear model name {args.model} not found, must be one of {list(models.linear_model_fns.keys())}"
+    elif args.architecture == "transformer":
+        vocab_dim = args.num_classes
+        embed_dim = args.input_channels
+        seq_len = args.input_hw
+        model_fn = models.transformer_model_fns.get(args.model)
+        if args.model in models.hf_transformers_models:
+            model_fn_orig = model_fn
+            model_fn = lambda: models.TransformersModelWrapper( # noqa: E731
+                model_fn_orig, args.model
+            )
+            config = models.get_transformers_config(args.model)
+            # as per transformers.PretrainedConfig these 2 should be present in all models:
+            vocab_dim = config.vocab_size
+            embed_dim = config.hidden_size
+        models.transformer_input_shape = (vocab_dim, embed_dim)
+        input_shape = [seq_len, embed_dim]
+        y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim}
+        assert (
+            model_fn is not None
+        ), f"Transformer model name {args.model} not found, must be one of {list(models.transformer_model_fns.keys())}"
+    elif args.architecture == "VLM":
+        # model format: `vlm!<vis_model>!<vis_model_arch>!<llm>`
+        # eg: `vlm!vit!transformer!memsave_gpt2`
+        is_vlm, vis_model, vis_model_arch, llm = args.model.split("!")
+        assert is_vlm == "vlm"
+        assert vis_model_arch in ["transformer", "conv"]
+        model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm) # noqa: E731
+        config = models.get_transformers_config(llm)
+        vocab_dim = config.vocab_size
+        embed_dim = config.hidden_size
+        seq_len = (args.input_hw // 16) ** 2
+        y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim}
+        input_shape = (args.input_channels, args.input_hw, args.input_hw)
+        models.conv_input_shape = input_shape
 
     loss_fn = CrossEntropyLoss
 
     manual_seed(0)  # make deterministic
 
     x = rand(batch_size, *input_shape, device=dev)
-    y = randint(size=(batch_size,), low=0, high=num_classes, device=dev)
+    y = randint(**y_args, device=dev)
     targets = None
     if args.model in models.detection_models:
         # pred is a dictionary of losses
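The `model_fn_orig` rebinding before the `lambda` in the transformer branch, like commit f30f8e2's switch from loop `lambda`s to `functools.partial`, guards against Python's late-binding closures. A minimal sketch of the pitfall:

```python
from functools import partial

model_names = ["gpt2", "bert"]

# Pitfall: each lambda closes over the loop variable itself, so after the
# loop every one of them sees its final value.
fns = [lambda: name for name in model_names]
print([f() for f in fns])  # ['bert', 'bert']

# functools.partial binds the argument at definition time instead.
fns = [partial(str.upper, name) for name in model_names]
print([f() for f in fns])  # ['GPT2', 'BERT']
```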