-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from plutonium-239/torchviz+latex
merge torchviz+latex into main
- Loading branch information
Showing
4 changed files
with
379 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
"""Simple script to make a latex table from best results""" | ||
|
||
import argparse | ||
|
||
import pandas as pd | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--input", | ||
type=str, | ||
required=True, | ||
help="Path to best_results<params>.csv generated by get_best_results", | ||
) | ||
|
||
|
||
args = parser.parse_args() | ||
|
||
df = pd.read_csv(args.input) | ||
|
||
df = df.set_index("model") | ||
df = df[df["case"] != "Conv"] | ||
df = df[df["case"] != "SurgicalLast"] | ||
|
||
df["memsave"] = df.index.str.startswith("memsave_") | ||
badi = df.index.map( | ||
lambda x: x.split("memsave_", 1)[1] if x.startswith("memsave") else x | ||
) | ||
badi.name = "model_clean" | ||
df2 = df.reset_index().set_index(badi).sort_index() | ||
divs = df2[(df2["case"] == "All") & (~df2["memsave"])] | ||
df2["Scaled M"] = df2["Memory Usage (GB)"] / divs["Memory Usage (GB)"] | ||
df2["Scaled T"] = df2["Time Taken (s)"] / divs["Time Taken (s)"] | ||
|
||
df2["Memory [GiB]"] = df2.apply( | ||
lambda x: f"{x['Memory Usage (GB)']:.2f} ({x['Scaled M']:.2f})", axis=1 | ||
) | ||
df2["Time [s]"] = df2.apply( | ||
lambda x: f"{x['Time Taken (s)']:.2f} ({x['Scaled T']:.2f})", axis=1 | ||
) | ||
|
||
|
||
def _highlight(group, col_sort, col_bold): | ||
for c_s, c_b in zip(col_sort, col_bold): | ||
min_idx = group[c_s].argmin() | ||
group[c_b] = [ | ||
f"\\textbf{{{group.iloc[i][c_b]}}}" if i == min_idx else group.iloc[i][c_b] | ||
for i in range(len(group.index)) | ||
] | ||
return group | ||
|
||
|
||
# Bold the lowest memory usage within each (model, case) group.
df2 = df2.groupby(["model_clean", "case"]).apply(
    _highlight, ["Memory Usage (GB)"], ["Memory [GiB]"]
)
# .apply(_highlight, ['Memory Usage (GB)', 'Time Taken (s)'], ['Memory [GiB]', 'Time [s]'])

# Display names (with BibTeX citations) for models shown in the table; models
# missing from this mapping are filtered out below.
names = {
    "bert": "BERT",
    "bart": "BART",
    "roberta": "RoBERTa",
    "gpt2": "GPT-2",
    "t5": "T5 \\cite{JMLR_t5}",
    "flan-t5": "FLAN-T5",
    "mistral-7b": "Mistral-7B \\cite{jiang2023mistral}",
    "transformer": "Transformer \\cite{NIPS2017_3f5ee243_vaswaniattention}",
    "llama3-8b": "LLaMa3-8B \\cite{touvron2023llama}",
    "phi3-4b": "Phi3-4B \\cite{gunasekar2023textbooksPhi}",
    # Conv
    "deeplabv3_resnet101": "DeepLabv3 (RN101) \\cite{deeplabv3_chen2017rethinking}",
    "efficientnet_v2_l": "EfficientNetv2-L \\cite{efficientnet_TanL19,efficientnetv2_TanL21}",
    "fcn_resnet101": "FCN (RN101) \\cite{fcn}",
    "mobilenet_v3_large": "MobileNetv3-L \\cite{mobilenetv3}",
    "resnext101_64x4d": "ResNeXt101-64x4d \\cite{resnext_cvpr_XieGDTH17}",
    "fasterrcnn_resnet50_fpn_v2": "Faster-RCNN (RN101) \\cite{faster_rcnn_RenHGS15}",
    "ssdlite320_mobilenet_v3_large": "SSDLite (MobileNetv3-L) \\cite{mobilenetv2_Sandler_2018_CVPR}",
    "vgg16": "VGG-16 \\cite{vgg_SimonyanZ14a}",
}

# import ipdb; ipdb.set_trace()
# Keep only models with a display name (level 0 of the grouped index is
# "model_clean" after the groupby-apply above).
df2 = df2[df2.index.isin(names.keys(), level=0)]
|
||
|
||
def _format_name(n):
    """Return the display label for the raw model id ``n``.

    MemSave variants ("memsave_<model>") get the base model's display name
    from the module-level ``names`` mapping plus a " + MemSave" suffix.
    """
    prefix = "memsave_"
    if not n.startswith(prefix):
        return names[n]
    base = n.split(prefix, 1)[1]
    return f"{names[base]} + MemSave"
|
||
|
||
ni = df2["model"].apply(_format_name) | ||
df2 = df2.set_index(ni).sort_index().drop( | ||
columns=[ | ||
"model", | ||
"memsave", | ||
"Memory Usage (GB)", | ||
"Time Taken (s)", | ||
"Scaled M", | ||
"Scaled T", | ||
] | ||
) # fmt: skip | ||
|
||
df2_p = df2.pivot_table( | ||
index="model", columns="case", values=df2.columns[1:], aggfunc=lambda x: x | ||
) | ||
|
||
short_index = df2_p.index.map(lambda t: "+ MemSave" if "+ MemSave" in t else t) | ||
df2_p = df2_p.set_index(short_index) | ||
|
||
latex_str = df2_p.to_latex(na_rep="-", multicolumn_format="c") | ||
final_str = "" | ||
for line in latex_str.split("\n"): | ||
add_line = line + "\n" | ||
if line.startswith("+ MemSave"): | ||
add_line += "\\midrule\n" | ||
final_str += add_line | ||
print(final_str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Simple script to make a latex table from resnet results""" | ||
|
||
import pandas as pd | ||
|
||
df = pd.read_csv("results/resnet101_only/best_results-conv-cpu-usage_stats.csv") | ||
df = df.set_index("model") | ||
df = df.drop(columns=["Scaled M", "Scaled T"]) | ||
df = df.drop("memsave_resnet101_conv+relu+bn") | ||
df = df[df["case"] != "SurgicalLast"] | ||
df = df[df["case"] != "Conv"] | ||
|
||
mem_div = df[df["case"] == "All"].loc["resnet101", "Memory Usage (GB)"] | ||
time_div = df[df["case"] == "All"].loc["resnet101", "Time Taken (s)"] | ||
df["Scaled M"] = df["Memory Usage (GB)"] / mem_div | ||
df["Scaled T"] = df["Time Taken (s)"] / time_div | ||
|
||
df["Memory [GiB]"] = df.apply( | ||
lambda x: f"{x['Memory Usage (GB)']:.2f} ({x['Scaled M']:.2f})", axis=1 | ||
) | ||
df["Time [s]"] = df.apply( | ||
lambda x: f"{x['Time Taken (s)']:.2f} ({x['Scaled T']:.2f})", axis=1 | ||
) | ||
|
||
df = df.drop(columns=["Scaled M", "Scaled T", "Memory Usage (GB)", "Time Taken (s)"]) | ||
df_p = df.pivot_table( | ||
index="model", columns="case", values=df.columns[1:], aggfunc=lambda x: x | ||
) | ||
|
||
labels = { | ||
"resnet101": "Default ResNet-101", | ||
"memsave_resnet101_conv": "+ swap Convolution", | ||
"memsave_resnet101_conv_full": "+ swap BatchNorm, ReLU", | ||
} | ||
|
||
df_p = df_p.rename(index=labels) | ||
df_p = df_p.sort_index(ascending=False) | ||
|
||
print(df_p["Memory [GiB]"].to_latex()) | ||
print(df_p["Time [s]"].to_latex()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# ruff: noqa | ||
import argparse | ||
|
||
import torch | ||
from torchview import draw_graph | ||
from torchviz import make_dot | ||
|
||
from experiments.util import models | ||
import torchvision.models as tvm | ||
import transformers.models as tfm | ||
from transformers import AutoConfig | ||
|
||
# Graph name -> [model key in experiments.util.models, extractor returning the
# single transformer block to visualize from the full model].
to_test = {
    "bert_encoder": ["bert", lambda model_full: model_full.bert.encoder.layer[0]],
    "memsave_bert_encoder": [
        "memsave_bert",
        lambda model_full: model_full.bert.encoder.layer[0],
    ],
    # NOTE(review): despite the "encoder" key, these pick a *decoder* layer — confirm intent.
    "bart_encoder": ["bart", lambda model_full: model_full.model.decoder.layers[0]],
    "memsave_bart_encoder": [
        "memsave_bart",
        lambda model_full: model_full.model.decoder.layers[0],
    ],
    "gpt2_layer": ["gpt2", lambda model_full: model_full.transformer.h[0]],
    "memsave_gpt2_layer": [
        "memsave_gpt2",
        lambda model_full: model_full.transformer.h[0],
    ],
    "t5_decoder": ["t5", lambda model_full: model_full.decoder.block[1]],
    "memsave_t5_decoder": [
        "memsave_t5",
        lambda model_full: model_full.decoder.block[1],
    ],
}
|
||
|
||
def run_single(model, name, x):
    """Run one forward pass of ``model`` on ``x`` and render its autograd graph.

    The torchviz graph (with attributes and saved tensors shown) is written
    under ``torchviz-output/<name>``.
    """
    out = model(x)
    graph = make_dot(
        out[0].mean(),
        params=dict(model.named_parameters()),
        show_attrs=True,
        show_saved=True,
    )
    graph.render(filename=name, directory="torchviz-output")
|
||
|
||
if __name__ == "__main__": | ||
# import argparse | ||
|
||
# parser = argparse.ArgumentParser() | ||
# parser.add_argument( | ||
# "--model", type=str, default="deeprelumodel", help="Which model to use" | ||
# ) | ||
|
||
# args = parser.parse_args() | ||
|
||
# models.conv_input_shape = (3, 64, 64) | ||
models.transformer_input_shape = (5000, 1024) | ||
|
||
for name in to_test: | ||
model_name, block_fn = to_test[name] | ||
config = models.get_transformers_config(model_name) | ||
|
||
models.transformer_input_shape = (config.vocab_size, config.hidden_size) | ||
x = torch.rand(7, 128, config.hidden_size) | ||
|
||
model = models.transformer_model_fns.get(model_name) | ||
run_single(block_fn(model()), name, x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# ruff: noqa | ||
import argparse | ||
|
||
import torch | ||
from torchview import draw_graph | ||
from torchviz import make_dot | ||
|
||
from tqdm import tqdm | ||
import memsave_torch | ||
from experiments.util.collect_results import select_cases | ||
from experiments.util.estimate import parse_case | ||
from experiments.util.measurements import separate_grad_arguments | ||
|
||
|
||
def eval_bn(num_features):
    """Return a ``BatchNorm2d`` layer switched to eval mode.

    In eval mode the layer normalizes with running statistics rather than
    batch statistics, changing what autograd needs to save.
    """
    # Module.eval() returns the module itself, so build-and-switch is one expression.
    return torch.nn.BatchNorm2d(num_features).eval()
|
||
|
||
# Elementary layers to visualize: each entry names the layer, provides
# factories for the layer and a matching random input, and lists the gradient
# cases to render for it.
to_test = [
    {
        "name": "Linear2dims",
        "layer_fn": lambda: torch.nn.Linear(3, 5),
        "data_fn": lambda: torch.rand(7, 12, 3),  # weight sharing
        "grads": ["All", "Input", "Linear", "Everything", "Nothing"],
    },
    {
        "name": "Conv2d",
        "layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Conv", "Everything", "Nothing"],
    },
    {
        "name": "BatchNorm2d",
        "layer_fn": lambda: torch.nn.BatchNorm2d(3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Norm", "Everything", "Nothing"],
    },
    {
        # Eval-mode variant: uses running stats, so the saved graph differs.
        "name": "BatchNorm2d_Eval",
        "layer_fn": lambda: eval_bn(3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Norm", "Everything", "Nothing"],
    },
    {
        "name": "LayerNorm",
        "layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Norm", "Everything", "Nothing"],
    },
    {
        "name": "Dropout",
        "layer_fn": lambda: torch.nn.Dropout(),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
    {
        "name": "MaxPool2d",
        "layer_fn": lambda: torch.nn.MaxPool2d(3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
    {
        "name": "ReLU",
        "layer_fn": lambda: torch.nn.ReLU(),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
    {
        "name": "SiLU",
        "layer_fn": lambda: torch.nn.SiLU(),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
]
|
||
|
||
def run_single(model, x, name, dirname):
    """Forward ``x`` through ``model`` and render the autograd graph.

    The torchviz graph (with attributes and saved tensors shown) is written
    to ``<dirname>/<name>``.
    """
    out = model(x)
    graph = make_dot(
        out.sum(),
        params=dict(model.named_parameters()),
        show_attrs=True,
        show_saved=True,
    )
    graph.render(filename=name, directory=dirname)
|
||
|
||
def separate_grad_arguments_wrapper(
    model,
    grad_linear_weights: bool = True,
    grad_linear_bias: bool = True,
    grad_conv_weights: bool = True,
    grad_conv_bias: bool = True,
    grad_norm_weights: bool = True,
    grad_norm_bias: bool = True,
    grad_embed_weights: bool = False,
    **kwargs,
):
    """Call ``separate_grad_arguments`` with only the options it accepts.

    Extra keys produced by ``parse_case`` (e.g. ``grad_input``) land in
    ``**kwargs`` and are deliberately dropped.
    """
    accepted = dict(
        grad_linear_weights=grad_linear_weights,
        grad_linear_bias=grad_linear_bias,
        grad_conv_weights=grad_conv_weights,
        grad_conv_bias=grad_conv_bias,
        grad_norm_weights=grad_norm_weights,
        grad_norm_bias=grad_norm_bias,
        grad_embed_weights=grad_embed_weights,
    )
    return separate_grad_arguments(model, **accepted)
|
||
|
||
if __name__ == "__main__": | ||
for layer_to_test in (pbar := tqdm(to_test)): | ||
pbar.set_description(layer_to_test["name"]) | ||
all_grad_cases = select_cases(layer_to_test["grads"]) | ||
for c_name, c in zip(layer_to_test["grads"], all_grad_cases): | ||
grad_opts = parse_case(c) | ||
x = layer_to_test["data_fn"]() | ||
layer = layer_to_test["layer_fn"]() | ||
memsave_layer = memsave_torch.nn.convert_to_memory_saving( | ||
layer, clone_params=True | ||
) | ||
leafs, no_leafs = separate_grad_arguments_wrapper( | ||
layer, **grad_opts | ||
) # no weights differentiable | ||
|
||
grad_input = False | ||
x2 = x.clone() | ||
grad_input = "grad_input" in grad_opts and grad_opts["grad_input"] | ||
|
||
leafs = ([x] if grad_input else []) + leafs | ||
no_leafs = ([] if grad_input else [x]) + no_leafs | ||
|
||
for leaf in leafs: | ||
leaf.requires_grad_(True) | ||
for no_leaf in no_leafs: | ||
no_leaf.requires_grad_(False) | ||
|
||
# TODO: add grad weights case | ||
leafs, no_leafs = separate_grad_arguments_wrapper( | ||
memsave_layer, **grad_opts | ||
) # no weights differentiable | ||
|
||
leafs = ([x2] if grad_input else []) + leafs | ||
no_leafs = ([] if grad_input else [x2]) + no_leafs | ||
|
||
for leaf in leafs: | ||
leaf.requires_grad_(True) | ||
for no_leaf in no_leafs: | ||
no_leaf.requires_grad_(False) | ||
|
||
dirname = f"torchviz-output/elementary/{layer_to_test['name']}" | ||
run_single(layer, x, c_name, dirname) | ||
run_single(memsave_layer, x2, c_name + "_MemSave", dirname) |