Skip to content

Commit

Permalink
add plotting script for poster bar graphs + formatting
Browse files Browse the repository at this point in the history
- add `poster.csv` (which is just a filtered version of best results conv)
- add plot pdfs as wellg
  • Loading branch information
plutonium-239 committed Jul 31, 2024
1 parent 6e88811 commit 7854aac
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 12 deletions.
2 changes: 1 addition & 1 deletion experiments/paper_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# repeat the experiment multiple times (generates multiple files to be aggregated by `get_best_results`)
n_repeat = 5
batchnorm_eval = True # BatchNorm in eval mode
batchnorm_eval = True # BatchNorm in eval mode

# ============== CONV CONFIG ==============
# Valid choices for models are in models.conv_model_fns
Expand Down
112 changes: 112 additions & 0 deletions experiments/poster_barplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Script to plot bar graphs of savings as seen in the poster"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tueplots import bundles

df = pd.read_csv("results/paper/poster.csv")
df["Scaled M str"] = df["Scaled M"].apply(lambda x: f"{x:.2f}")
df["M str"] = df["Memory Usage (GB)"].apply(lambda x: f"{x:.2f}")

memsave_map = {False: "PyTorch", True: "+ MemSave"}
df["colors"] = df["memsave"].apply(lambda x: memsave_map[x])
color_map = {memsave_map[False]: "#F05F42", memsave_map[True]: "#00E1D2"}

# fig = px.bar(df, x='case', y='Scaled M', color='colors', text='M str',
# category_orders={'case': ['All', 'Input', 'Norm', 'SurgicalFirst']},
# barmode='group', facet_col='model_clean', facet_col_wrap=3,
# color_discrete_map={memsave_map[False]: '#F05F42', memsave_map[True]: '#00E1D2'},
# )

# fig.update_traces(width=0.6)
# fig.show()

width = 0.4
df["color_val"] = df["colors"].apply(lambda x: color_map[x])

names = {
"bert": "BERT",
"bart": "BART",
"roberta": "RoBERTa",
"gpt2": "GPT-2",
"t5": "T5",
"flan-t5": "FLAN-T5",
"mistral-7b": "Mistral-7B",
"transformer": "Transformer",
"llama3-8b": "LLaMa3-8B",
"phi3-4b": "Phi3-4B",
# Conv
"resnet101": "ResNet-101",
"deeplabv3_resnet101": "DeepLabv3 (RN101)",
"efficientnet_v2_l": "EfficientNetv2-L",
"fcn_resnet101": "FCN (RN101)",
"mobilenet_v3_large": "MobileNetv3-L",
"resnext101_64x4d": "ResNeXt101-64x4d",
"fasterrcnn_resnet50_fpn_v2": "Faster-RCNN (RN101)",
"ssdlite320_mobilenet_v3_large": "SSDLite (MobileNetv3-L)",
"vgg16": "VGG-16",
}

for chosen_model in ["resnet101", "efficientnet_v2_l", "mistral-7b", "t5"]:
df_model = df[df["model_clean"] == chosen_model]
with plt.rc_context(bundles.icml2024(column="full")):
fig, ax = plt.subplots()
# ax.set_xlabel("Case", size='large')
ax.set_ylabel("Peak memory [GiB]", size="large")
cases = []
for i, (case, group) in enumerate(df_model.groupby("case")):
cases.append(case)
for j, (memsave, mg) in enumerate(group.groupby("memsave")): # noqa: B007
r = ax.bar(
i + j * width,
mg["Memory Usage (GB)"],
width,
label=mg["colors"].item(),
color=mg["color_val"],
)
ax.bar_label(r, mg["Scaled M str"], padding=-20, size="x-large")
yoff = mg["Memory Usage (GB)"].item() * 0.05
if r[0].get_height() < 5:
ax.text(
i + j * width,
r[0].get_height() + yoff,
mg["colors"].item(),
ha="center",
va="bottom",
rotation="vertical",
size="x-large",
)
else:
ax.text(
i + j * width,
yoff,
mg["colors"].item(),
ha="center",
va="bottom",
rotation="vertical",
size="x-large",
)

# ax.bar(i + width, group['Scaled M'], width, label=group['M str'])
# ax.bar_label(rects, padding=3)

# for memsave, sub_group in group.groupby('memsave'):
# ax.plot(sub_group['case'], sub_group['Memory Usage (GB)'], marker='o', linestyle=linestyle, color=color, label=f'{model_clean} - {"memsave" if memsave else "no memsave"}')
# for j, txt in enumerate(sub_group['Scaled M str']):
# ax.annotate(txt, (sub_group['case'].iloc[j], sub_group['Memory Usage (GB)'].iloc[j]))

ax.set_xticks(np.arange(len(cases)) + width / 2, cases)
ax.tick_params(labelsize="x-large")
ax.set_title(names[chosen_model], fontsize="xx-large", fontweight=1000)
# handles, labels = ax.get_legend_handles_labels()
# unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
# ax.legend(*zip(*unique))

# ax.legend()
# fig.show()
# fig.waitforbuttonpress()
plt.savefig(
f"results/paper/poster_plot_{chosen_model}.pdf",
bbox_inches="tight",
)
22 changes: 11 additions & 11 deletions experiments/visual_abstract/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,6 @@
from os import makedirs, path

from memory_profiler import memory_usage
from memsave_torch.nn import (
MemSaveBatchNorm2d,
MemSaveConv1d,
MemSaveConv2d,
MemSaveConv3d,
MemSaveConvTranspose1d,
MemSaveConvTranspose2d,
MemSaveConvTranspose3d,
MemSaveLinear,
)
from memsave_torch.nn.ConvTranspose1d import MemSaveConvTranspose1d
from torch import allclose, compile, manual_seed, rand, rand_like
from torch.autograd import grad
from torch.nn import (
Expand All @@ -31,6 +20,17 @@
Sequential,
)

from memsave_torch.nn import (
MemSaveBatchNorm2d,
MemSaveConv1d,
MemSaveConv2d,
MemSaveConv3d,
MemSaveConvTranspose1d,
MemSaveConvTranspose2d,
MemSaveConvTranspose3d,
MemSaveLinear,
)

HEREDIR = path.dirname(path.abspath(__file__))
DATADIR = path.join(HEREDIR, "raw")
makedirs(DATADIR, exist_ok=True)
Expand Down
45 changes: 45 additions & 0 deletions results/paper/poster.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
model_clean,model,case,Time Taken (s),Memory Usage (GB),Scaled T,Scaled M,memsave
efficientnet_v2_l,efficientnet_v2_l,All,0.7670447733253241,26.808164596557617,1.0,1.0,False
efficientnet_v2_l,efficientnet_v2_l,Input (BN Eval),0.5853275060653687,26.80623340606689,0.7880858167473107,1.0,False
efficientnet_v2_l,efficientnet_v2_l,Input,0.6206146785989404,26.808164596557617,0.8090983736300373,1.0,False
efficientnet_v2_l,efficientnet_v2_l,Norm,0.617624219506979,26.808164596557617,0.8051996975736228,1.0,False
efficientnet_v2_l,efficientnet_v2_l,Surgical,0.6779810208827257,26.808164596557617,0.8838871529539463,1.0,False
efficientnet_v2_l,memsave_efficientnet_v2_l,All,0.7840075613930821,26.808164596557617,1.022114469269141,1.0,True
efficientnet_v2_l,memsave_efficientnet_v2_l,Input (BN Eval),0.5787053339881822,10.448431015014648,0.7791697145036758,0.3897761709651427,True
efficientnet_v2_l,memsave_efficientnet_v2_l,Input,0.6277032848447561,18.642004013061523,0.8183398240542218,0.6953853161381778,True
efficientnet_v2_l,memsave_efficientnet_v2_l,Norm,0.625503615476191,18.642004013061523,0.8154721044046517,0.6953853161381778,True
efficientnet_v2_l,memsave_efficientnet_v2_l,Surgical,0.6915900399908423,22.05434799194336,0.9016292973259341,0.8226728059844617,True
mistral-7b,memsave_mistral-7b,All,2.548980531282723,37.66866302490234,1.2224257200714634,0.9039834132928382,True
mistral-7b,memsave_mistral-7b,Input,1.4442284815013409,32.168663024902344,0.6926149571485977,0.7719928308337864,True
mistral-7b,memsave_mistral-7b,Norm,1.4595470689237118,32.168663024902344,0.699961358986687,0.7719928308337864,True
mistral-7b,memsave_mistral-7b,Surgical,1.7122022127732637,33.543663024902344,0.8211282892004004,0.8049904764485495,True
mistral-7b,mistral-7b,All,2.085182346403599,41.66963958740234,1.0,1.0,False
mistral-7b,mistral-7b,Input,1.5489048743620517,34.200904846191406,0.7428150718010407,0.8207631547773478,False
mistral-7b,mistral-7b,Norm,1.5577463591471314,36.16963958740234,0.7470552212538351,0.8680094175409482,False
mistral-7b,mistral-7b,Surgical,1.6676594512537122,36.013389587402344,0.7997667226226014,0.8642596850847253,False
resnet101,memsave_resnet101,All,0.4432011162862181,8.557734489440918,0.9699415836748781,1.0933840979482345,True
resnet101,memsave_resnet101,Input (BN Eval),0.3338263579644263,1.4200773239135742,0.7620312946412207,0.18144620911864504,True
resnet101,memsave_resnet101,Input,0.3611809015274048,5.242091178894043,0.7904410949054845,0.6697589346887894,True
resnet101,memsave_resnet101,Norm,0.35775987803936,5.242091178894043,0.7829542163353497,0.6697589346887894,True
resnet101,memsave_resnet101,Surgical,0.3910391088575125,6.862116813659668,0.8557855081735115,0.876742485008887,True
resnet101,resnet101,All,0.4569358853623271,7.826832771301269,1.0,1.0,False
resnet101,resnet101,Input (BN Eval),0.3445483860559761,7.826436996459961,0.7865066566156454,1.0,False
resnet101,resnet101,Input,0.3788085076957941,7.826832771301269,0.8290189495522289,1.0,False
resnet101,resnet101,Norm,0.3715896429494023,7.826832771301269,0.8132205301729595,1.0,False
resnet101,resnet101,Surgical,0.4043164802715182,7.826832771301269,0.8848429139044652,1.0,False
resnext101_64x4d,memsave_resnext101_64x4d,All,0.6408866010606289,16.746541023254395,0.9845938795249098,1.105144109266399,True
resnext101_64x4d,memsave_resnext101_64x4d,Input,0.520104899071157,9.87433385848999,0.79903698952186,0.6516307983533247,True
resnext101_64x4d,memsave_resnext101_64x4d,Norm,0.5187254101037979,9.87433385848999,0.7969176810640359,0.6516307983533247,True
resnext101_64x4d,memsave_resnext101_64x4d,Surgical,0.5633453754708171,13.32355260848999,0.8654673195371679,0.8792529549431838,True
resnext101_64x4d,resnext101_64x4d,All,0.6509146708995104,15.15326452255249,1.0,1.0,False
resnext101_64x4d,resnext101_64x4d,Input,0.5331473043188453,15.15326452255249,0.8190740325180867,1.0,False
resnext101_64x4d,resnext101_64x4d,Norm,0.5326845943927765,15.15326452255249,0.8183631714071682,1.0,False
resnext101_64x4d,resnext101_64x4d,Surgical,0.5835767788812518,15.15326452255249,0.8965488181036031,1.0,False
t5,memsave_t5,All,1.9474648162722588,31.8445405960083,1.144910917764398,0.9535176092130268,True
t5,memsave_t5,Input,1.5153652485460043,22.7976655960083,0.8908803912473512,0.6826280168560118,True
t5,memsave_t5,Norm,1.5391185907647014,22.7976655960083,0.9048449366462439,0.6826280168560118,True
t5,memsave_t5,Surgical,1.5937594240531323,25.2830171585083,0.9369681801908518,0.7570466278824155,True
t5,t5,All,1.700974971987307,33.3969087600708,1.0,1.0,False
t5,t5,Input,1.372872439213097,25.9439058303833,0.8071091355383809,0.7768355453724425,False
t5,t5,Norm,1.3900730907917025,28.8500337600708,0.8172213663835587,0.8638534173127956,False
t5,t5,Surgical,1.469195489771664,27.7728853225708,0.8637372765427305,0.8316004790172659,False
Binary file added results/paper/poster_plot_efficientnet_v2_l.pdf
Binary file not shown.
Binary file added results/paper/poster_plot_mistral-7b.pdf
Binary file not shown.
Binary file added results/paper/poster_plot_resnet101.pdf
Binary file not shown.
Binary file added results/paper/poster_plot_t5.pdf
Binary file not shown.

0 comments on commit 7854aac

Please sign in to comment.