Skip to content

Commit

Permalink
Merge pull request #12 from plutonium-239/torchviz+latex
Browse files Browse the repository at this point in the history
merge torchviz+latex into main
  • Loading branch information
plutonium-239 authored Aug 22, 2024
2 parents 1eecbfa + 22957d9 commit 9377cb6
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 0 deletions.
116 changes: 116 additions & 0 deletions experiments/best_results_to_latex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Simple script to make a latex table from best results"""

import argparse

import pandas as pd

# CLI: a single required flag pointing at the best-results CSV
# produced by the get_best_results script.
parser = argparse.ArgumentParser()
parser.add_argument(
"--input",
type=str,
required=True,
help="Path to best_results<params>.csv generated by get_best_results",
)


args = parser.parse_args()

# Load the best-results table and drop cases that do not appear in the
# final LaTeX table.
df = pd.read_csv(args.input)

df = df.set_index("model")
df = df[df["case"] != "Conv"]
df = df[df["case"] != "SurgicalLast"]

# Tag MemSave variants and build a "clean" index with the "memsave_"
# prefix stripped, so each variant aligns with its base model.
df["memsave"] = df.index.str.startswith("memsave_")
# FIX: check the full "memsave_" prefix (was startswith("memsave")), which
# keeps this test consistent with the flag above and avoids an IndexError
# for a hypothetical model name like "memsavenet".
badi = df.index.map(
    lambda x: x.split("memsave_", 1)[1] if x.startswith("memsave_") else x
)
badi.name = "model_clean"
df2 = df.reset_index().set_index(badi).sort_index()
# Baseline per clean model: the non-memsave run with all gradients ("All");
# memory/time of every row is scaled relative to it.
divs = df2[(df2["case"] == "All") & (~df2["memsave"])]
df2["Scaled M"] = df2["Memory Usage (GB)"] / divs["Memory Usage (GB)"]
df2["Scaled T"] = df2["Time Taken (s)"] / divs["Time Taken (s)"]

# Human-readable cells: "absolute (relative-to-baseline)".
df2["Memory [GiB]"] = df2.apply(
    lambda x: f"{x['Memory Usage (GB)']:.2f} ({x['Scaled M']:.2f})", axis=1
)
df2["Time [s]"] = df2.apply(
    lambda x: f"{x['Time Taken (s)']:.2f} ({x['Scaled T']:.2f})", axis=1
)

def _highlight(group, col_sort, col_bold):
for c_s, c_b in zip(col_sort, col_bold):
min_idx = group[c_s].argmin()
group[c_b] = [
f"\\textbf{{{group.iloc[i][c_b]}}}" if i == min_idx else group.iloc[i][c_b]
for i in range(len(group.index))
]
return group


# Bold the lowest-memory entry within each (model, case) group.
df2 = df2.groupby(["model_clean", "case"]).apply(
    _highlight, ["Memory Usage (GB)"], ["Memory [GiB]"]
)
# .apply(_highlight, ['Memory Usage (GB)', 'Time Taken (s)'], ['Memory [GiB]', 'Time [s]'])

# Display names (with BibTeX citations) keyed by raw model identifier.
names = {
    "bert": "BERT",
    "bart": "BART",
    "roberta": "RoBERTa",
    "gpt2": "GPT-2",
    "t5": "T5 \\cite{JMLR_t5}",
    "flan-t5": "FLAN-T5",
    "mistral-7b": "Mistral-7B \\cite{jiang2023mistral}",
    "transformer": "Transformer \\cite{NIPS2017_3f5ee243_vaswaniattention}",
    "llama3-8b": "LLaMa3-8B \\cite{touvron2023llama}",
    "phi3-4b": "Phi3-4B \\cite{gunasekar2023textbooksPhi}",
    # Conv
    "deeplabv3_resnet101": "DeepLabv3 (RN101) \\cite{deeplabv3_chen2017rethinking}",
    "efficientnet_v2_l": "EfficientNetv2-L \\cite{efficientnet_TanL19,efficientnetv2_TanL21}",
    "fcn_resnet101": "FCN (RN101) \\cite{fcn}",
    "mobilenet_v3_large": "MobileNetv3-L \\cite{mobilenetv3}",
    "resnext101_64x4d": "ResNeXt101-64x4d \\cite{resnext_cvpr_XieGDTH17}",
    "fasterrcnn_resnet50_fpn_v2": "Faster-RCNN (RN101) \\cite{faster_rcnn_RenHGS15}",
    "ssdlite320_mobilenet_v3_large": "SSDLite (MobileNetv3-L) \\cite{mobilenetv2_Sandler_2018_CVPR}",
    "vgg16": "VGG-16 \\cite{vgg_SimonyanZ14a}",
}

# import ipdb; ipdb.set_trace()
# Keep only rows whose cleaned model name (index level 0) has a display name.
df2 = df2[df2.index.isin(names.keys(), level=0)]


def _format_name(n):
    """Return the display name for model identifier ``n``.

    MemSave variants ("memsave_<base>") render as the base model's display
    name with a " + MemSave" suffix; any other identifier is looked up
    directly in the module-level ``names`` table.
    """
    prefix = "memsave_"
    if n.startswith(prefix):
        return f"{names[n[len(prefix):]]} + MemSave"
    return names[n]


# Re-index on formatted display names and drop the raw numeric columns;
# only the pretty "Memory [GiB]" / "Time [s]" strings remain.
ni = df2["model"].apply(_format_name)
df2 = df2.set_index(ni).sort_index().drop(
    columns=[
        "model",
        "memsave",
        "Memory Usage (GB)",
        "Time Taken (s)",
        "Scaled M",
        "Scaled T",
    ]
) # fmt: skip

# One row per model, one column per case. The index inherits the name
# "model" from the source column, so index="model" resolves to it; the
# identity aggfunc keeps the preformatted strings as-is.
df2_p = df2.pivot_table(
    index="model", columns="case", values=df2.columns[1:], aggfunc=lambda x: x
)

# Collapse the long "<Name> + MemSave" labels to a short "+ MemSave" row.
short_index = df2_p.index.map(lambda t: "+ MemSave" if "+ MemSave" in t else t)
df2_p = df2_p.set_index(short_index)

# Render to LaTeX and insert a \midrule after every "+ MemSave" row to
# visually separate consecutive model pairs.
latex_str = df2_p.to_latex(na_rep="-", multicolumn_format="c")
final_str = ""
for line in latex_str.split("\n"):
    add_line = line + "\n"
    if line.startswith("+ MemSave"):
        add_line += "\\midrule\n"
    final_str += add_line
print(final_str)
39 changes: 39 additions & 0 deletions experiments/resnet_best_results_to_latex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Simple script to make a latex table from resnet results"""

import pandas as pd

# Load the precomputed ResNet-101 results and drop rows/cases that should
# not appear in the table.
df = pd.read_csv("results/resnet101_only/best_results-conv-cpu-usage_stats.csv")
df = df.set_index("model")
df = df.drop(columns=["Scaled M", "Scaled T"])  # recomputed below against the baseline
df = df.drop("memsave_resnet101_conv+relu+bn")
df = df[df["case"] != "SurgicalLast"]
df = df[df["case"] != "Conv"]

# Baseline: unmodified resnet101 with all gradients ("All" case); every
# row's memory/time is expressed relative to it.
mem_div = df[df["case"] == "All"].loc["resnet101", "Memory Usage (GB)"]
time_div = df[df["case"] == "All"].loc["resnet101", "Time Taken (s)"]
df["Scaled M"] = df["Memory Usage (GB)"] / mem_div
df["Scaled T"] = df["Time Taken (s)"] / time_div

# Human-readable cells: "absolute (relative-to-baseline)".
df["Memory [GiB]"] = df.apply(
    lambda x: f"{x['Memory Usage (GB)']:.2f} ({x['Scaled M']:.2f})", axis=1
)
df["Time [s]"] = df.apply(
    lambda x: f"{x['Time Taken (s)']:.2f} ({x['Scaled T']:.2f})", axis=1
)

df = df.drop(columns=["Scaled M", "Scaled T", "Memory Usage (GB)", "Time Taken (s)"])
# One row per model, one column per case; the identity aggfunc keeps the
# preformatted strings untouched.
df_p = df.pivot_table(
    index="model", columns="case", values=df.columns[1:], aggfunc=lambda x: x
)

# Pretty row labels for the table.
labels = {
    "resnet101": "Default ResNet-101",
    "memsave_resnet101_conv": "+ swap Convolution",
    "memsave_resnet101_conv_full": "+ swap BatchNorm, ReLU",
}

df_p = df_p.rename(index=labels)
# Descending sort puts "Default ResNet-101" above the "+ swap ..." rows.
df_p = df_p.sort_index(ascending=False)

print(df_p["Memory [GiB]"].to_latex())
print(df_p["Time [s]"].to_latex())
69 changes: 69 additions & 0 deletions experiments/util/visualize_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# ruff: noqa
import argparse

import torch
from torchview import draw_graph
from torchviz import make_dot

from experiments.util import models
import torchvision.models as tvm
import transformers.models as tfm
from transformers import AutoConfig

# Transformer sub-blocks to visualize. Maps an output graph name to
# [model key (looked up in models.transformer_model_fns / config), and a
# function extracting the sub-module to render from the full model].
to_test = {
    "bert_encoder": ["bert", lambda model_full: model_full.bert.encoder.layer[0]],
    "memsave_bert_encoder": [
        "memsave_bert",
        lambda model_full: model_full.bert.encoder.layer[0],
    ],
    # NOTE(review): key says "encoder" but the extractor takes a *decoder*
    # layer — presumably intentional; confirm with the authors.
    "bart_encoder": ["bart", lambda model_full: model_full.model.decoder.layers[0]],
    "memsave_bart_encoder": [
        "memsave_bart",
        lambda model_full: model_full.model.decoder.layers[0],
    ],
    "gpt2_layer": ["gpt2", lambda model_full: model_full.transformer.h[0]],
    "memsave_gpt2_layer": [
        "memsave_gpt2",
        lambda model_full: model_full.transformer.h[0],
    ],
    "t5_decoder": ["t5", lambda model_full: model_full.decoder.block[1]],
    "memsave_t5_decoder": [
        "memsave_t5",
        lambda model_full: model_full.decoder.block[1],
    ],
}


def run_single(model, name, x):
    """Run ``model`` on ``x`` and render its autograd graph with torchviz.

    The graph of ``y[0].mean()`` — including saved tensors and node
    attributes — is written under ``torchviz-output/<name>``.
    """
    out = model(x)
    graph = make_dot(
        out[0].mean(),
        params=dict(model.named_parameters()),
        show_attrs=True,
        show_saved=True,
    )
    graph.render(filename=name, directory="torchviz-output")


if __name__ == "__main__":
    # FIX: removed the dead commented-out argparse scaffolding.
    # Default shape; each model overrides it below from its own config.
    models.transformer_input_shape = (5000, 1024)

    for name, (model_name, block_fn) in to_test.items():
        config = models.get_transformers_config(model_name)

        # Size the dummy input from the model's own vocab/hidden dims.
        models.transformer_input_shape = (config.vocab_size, config.hidden_size)
        x = torch.rand(7, 128, config.hidden_size)

        model = models.transformer_model_fns.get(model_name)
        run_single(block_fn(model()), name, x)
155 changes: 155 additions & 0 deletions experiments/util/visualize_graph_elementary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# ruff: noqa
import argparse

import torch
from torchview import draw_graph
from torchviz import make_dot

from tqdm import tqdm
import memsave_torch
from experiments.util.collect_results import select_cases
from experiments.util.estimate import parse_case
from experiments.util.measurements import separate_grad_arguments


def eval_bn(num_features):
    """Return a ``BatchNorm2d`` layer switched to evaluation mode."""
    bn = torch.nn.BatchNorm2d(num_features)
    return bn.eval()


# Elementary layers to visualize. Each entry provides a fresh layer
# factory, a matching random-input factory, and the named grad cases
# (resolved via select_cases/parse_case) to render for that layer.
to_test = [
    {
        "name": "Linear2dims",
        "layer_fn": lambda: torch.nn.Linear(3, 5),
        "data_fn": lambda: torch.rand(7, 12, 3),  # weight sharing
        "grads": ["All", "Input", "Linear", "Everything", "Nothing"],
    },
    {
        "name": "Conv2d",
        "layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Conv", "Everything", "Nothing"],
    },
    {
        "name": "BatchNorm2d",
        "layer_fn": lambda: torch.nn.BatchNorm2d(3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Norm", "Everything", "Nothing"],
    },
    {
        # Same layer as above but in inference mode (running stats frozen).
        "name": "BatchNorm2d_Eval",
        "layer_fn": lambda: eval_bn(3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Norm", "Everything", "Nothing"],
    },
    {
        "name": "LayerNorm",
        "layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Norm", "Everything", "Nothing"],
    },
    {
        "name": "Dropout",
        "layer_fn": lambda: torch.nn.Dropout(),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
    {
        "name": "MaxPool2d",
        "layer_fn": lambda: torch.nn.MaxPool2d(3),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
    {
        "name": "ReLU",
        "layer_fn": lambda: torch.nn.ReLU(),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
    {
        "name": "SiLU",
        "layer_fn": lambda: torch.nn.SiLU(),
        "data_fn": lambda: torch.rand(7, 3, 12, 12),
        "grads": ["All", "Input", "Everything", "Nothing"],
    },
]


def run_single(model, x, name, dirname):
    """Render the autograd graph of ``model(x).sum()`` into ``dirname/name``."""
    result = model(x)
    graph = make_dot(
        result.sum(),
        params=dict(model.named_parameters()),
        show_attrs=True,
        show_saved=True,
    )
    graph.render(filename=name, directory=dirname)


def separate_grad_arguments_wrapper(
    model,
    grad_linear_weights: bool = True,
    grad_linear_bias: bool = True,
    grad_conv_weights: bool = True,
    grad_conv_bias: bool = True,
    grad_norm_weights: bool = True,
    grad_norm_bias: bool = True,
    grad_embed_weights: bool = False,
    **kwargs,
):
    """Forward only the recognized grad flags to ``separate_grad_arguments``.

    Extra keyword arguments (e.g. ``grad_input`` coming out of
    ``parse_case``) are accepted and silently dropped so a full case dict
    can be splatted straight into this wrapper.
    """
    flags = {
        "grad_linear_weights": grad_linear_weights,
        "grad_linear_bias": grad_linear_bias,
        "grad_conv_weights": grad_conv_weights,
        "grad_conv_bias": grad_conv_bias,
        "grad_norm_weights": grad_norm_weights,
        "grad_norm_bias": grad_norm_bias,
        "grad_embed_weights": grad_embed_weights,
    }
    return separate_grad_arguments(model, **flags)


if __name__ == "__main__":
    # For every elementary layer and every requested grad case, render the
    # autograd graph of the vanilla layer and its MemSave counterpart.
    for layer_to_test in (pbar := tqdm(to_test)):
        pbar.set_description(layer_to_test["name"])
        all_grad_cases = select_cases(layer_to_test["grads"])
        for c_name, c in zip(layer_to_test["grads"], all_grad_cases):
            grad_opts = parse_case(c)
            x = layer_to_test["data_fn"]()
            layer = layer_to_test["layer_fn"]()
            memsave_layer = memsave_torch.nn.convert_to_memory_saving(
                layer, clone_params=True
            )
            # Split parameters into differentiable/non-differentiable sets
            # for the vanilla layer according to the case options.
            leafs, no_leafs = separate_grad_arguments_wrapper(
                layer, **grad_opts
            )  # no weights differentiable

            # Clone the input before toggling requires_grad so both layers
            # see an identical tensor.
            x2 = x.clone()
            # FIX: removed a dead `grad_input = False` assignment that was
            # immediately overwritten; dict.get expresses the same logic.
            grad_input = grad_opts.get("grad_input", False)

            leafs = ([x] if grad_input else []) + leafs
            no_leafs = ([] if grad_input else [x]) + no_leafs

            for leaf in leafs:
                leaf.requires_grad_(True)
            for no_leaf in no_leafs:
                no_leaf.requires_grad_(False)

            # TODO: add grad weights case
            leafs, no_leafs = separate_grad_arguments_wrapper(
                memsave_layer, **grad_opts
            )  # no weights differentiable

            leafs = ([x2] if grad_input else []) + leafs
            no_leafs = ([] if grad_input else [x2]) + no_leafs

            for leaf in leafs:
                leaf.requires_grad_(True)
            for no_leaf in no_leafs:
                no_leaf.requires_grad_(False)

            dirname = f"torchviz-output/elementary/{layer_to_test['name']}"
            run_single(layer, x, c_name, dirname)
            run_single(memsave_layer, x2, c_name + "_MemSave", dirname)

0 comments on commit 9377cb6

Please sign in to comment.