Skip to content

Commit

Permalink
visual abstract - extend to other layers
Browse files Browse the repository at this point in the history
also minor improvements to latex table script and batch norm in eval
  • Loading branch information
plutonium-239 committed Jul 4, 2024
1 parent a7ce769 commit 8e68a88
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 67 deletions.
33 changes: 27 additions & 6 deletions experiments/best_results_to_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
help="Path to best_results<params>.csv generated by get_best_results",
)


args = parser.parse_args()

df = pd.read_csv(args.input)

df = df.set_index("model")
df = df[df["case"] != "Conv"]
df = df[df["case"] != "SurgicalLast"]

df["memsave"] = df.index.str.startswith("memsave_")
badi = df.index.map(
Expand Down Expand Up @@ -57,14 +59,26 @@ def _highlight(group, col_sort, col_bold):
"bart": "BART",
"roberta": "RoBERTa",
"gpt2": "GPT-2",
"t5": "T5",
"t5": "T5 \\cite{JMLR_t5}",
"flan-t5": "FLAN-T5",
"mistral-7b": "Mistral-7B",
"transformer": "Transformer",
"llama3-8b": "LLaMa3-8B",
"phi3-4b": "Phi3-4B",
"mistral-7b": "Mistral-7B \\cite{jiang2023mistral}",
"transformer": "Transformer \\cite{NIPS2017_3f5ee243_vaswaniattention}",
"llama3-8b": "LLaMa3-8B \\cite{touvron2023llama}",
"phi3-4b": "Phi3-4B \\cite{gunasekar2023textbooksPhi}",
# Conv
"deeplabv3_resnet101": "DeepLabv3 (RN101) \\cite{deeplabv3_chen2017rethinking}",
"efficientnet_v2_l": "EfficientNetv2-L \\cite{efficientnet_TanL19,efficientnetv2_TanL21}",
"fcn_resnet101": "FCN (RN101) \\cite{fcn}",
"mobilenet_v3_large": "MobileNetv3-L \\cite{mobilenetv3}",
"resnext101_64x4d": "ResNeXt101-64x4d \\cite{resnext_cvpr_XieGDTH17}",
"fasterrcnn_resnet50_fpn_v2": "Faster-RCNN (RN101) \\cite{faster_rcnn_RenHGS15}",
"ssdlite320_mobilenet_v3_large": "SSDLite (MobileNetv3-L) \\cite{mobilenetv2_Sandler_2018_CVPR}",
"vgg16": "VGG-16 \\cite{vgg_SimonyanZ14a}",
}

# import ipdb; ipdb.set_trace()
df2 = df2[df2.index.isin(names.keys(), level=0)]


def _format_name(n):
if n.startswith("memsave_"):
Expand Down Expand Up @@ -92,4 +106,11 @@ def _format_name(n):
short_index = df2_p.index.map(lambda t: "+ MemSave" if "+ MemSave" in t else t)
df2_p = df2_p.set_index(short_index)

print(df2_p.to_latex(na_rep="-", multicolumn_format="c"))
latex_str = df2_p.to_latex(na_rep="-", multicolumn_format="c")
final_str = ""
for line in latex_str.split("\n"):
add_line = line + "\n"
if line.startswith("+ MemSave"):
add_line += "\\midrule\n"
final_str += add_line
print(final_str)
13 changes: 13 additions & 0 deletions experiments/util/collect_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@
"no_grad_norm_weights",
"no_grad_norm_bias",
],
"Everything": [ # INPUT
"grad_input"
],
"Nothing": [
"no_grad_conv_weights",
"no_grad_conv_bias",
"no_grad_linear_weights",
"no_grad_linear_bias",
"no_grad_norm_weights",
"no_grad_norm_bias",
],
"Conv": [ # CONV
"no_grad_linear_weights",
"no_grad_linear_bias",
Expand Down Expand Up @@ -217,6 +228,8 @@ def _display_run(
s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}"
elif self.architecture == "transformer":
s = f"{model} input ({self.batch_size},{self.input_HW},{self.input_channels}(or model hidden size)) {self.device}"
elif self.architecture == "linear":
s = f"{model} input ({self.batch_size},{self.input_HW**2}) {self.device}"
print(s.center(78, "="))

for out, case in zip(outputs, self.cases):
Expand Down
29 changes: 21 additions & 8 deletions experiments/util/visualize_graph_elementary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,54 +11,67 @@
from experiments.util.estimate import parse_case
from experiments.util.measurements import separate_grad_arguments


def eval_bn(num_features):
m = torch.nn.BatchNorm2d(num_features)
m.eval()
return m


to_test = [
{
"name": "Linear2dims",
"layer_fn": lambda: torch.nn.Linear(3, 5),
"data_fn": lambda: torch.rand(7, 12, 3), # weight sharing
"grads": ["All", "Input", "Linear"],
"grads": ["All", "Input", "Linear", "Everything", "Nothing"],
},
{
"name": "Conv2d",
"layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input", "Conv"],
"grads": ["All", "Input", "Conv", "Everything", "Nothing"],
},
{
"name": "BatchNorm2d",
"layer_fn": lambda: torch.nn.BatchNorm2d(3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input", "Norm"],
"grads": ["All", "Input", "Norm", "Everything", "Nothing"],
},
{
"name": "BatchNorm2d_Eval",
"layer_fn": lambda: eval_bn(3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input", "Norm", "Everything", "Nothing"],
},
{
"name": "LayerNorm",
"layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input", "Norm"],
"grads": ["All", "Input", "Norm", "Everything", "Nothing"],
},
{
"name": "Dropout",
"layer_fn": lambda: torch.nn.Dropout(),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input"],
"grads": ["All", "Input", "Everything", "Nothing"],
},
{
"name": "MaxPool2d",
"layer_fn": lambda: torch.nn.MaxPool2d(3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input"],
"grads": ["All", "Input", "Everything", "Nothing"],
},
{
"name": "ReLU",
"layer_fn": lambda: torch.nn.ReLU(),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input"],
"grads": ["All", "Input", "Everything", "Nothing"],
},
{
"name": "SiLU",
"layer_fn": lambda: torch.nn.SiLU(),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
"grads": ["All", "Input"],
"grads": ["All", "Input", "Everything", "Nothing"],
},
]

Expand Down
11 changes: 8 additions & 3 deletions experiments/visual_abstract/gather_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
max_num_layers = 10
requires_grads = ["all", "none", "4", "4+"]
implementations = ["torch", "ours"]
architectures = ["linear", "conv", "norm_eval"]
architectures = ["norm_eval"]
architectures = ["linear"]

if __name__ == "__main__":
for implementation, requires_grad in product(implementations, requires_grads):
for implementation, requires_grad, arch in product(
implementations, requires_grads, architectures
):
if implementation == "ours" and requires_grad != "4":
continue

Expand All @@ -27,7 +32,7 @@
with open(
path.join(
RAWDATADIR,
f"peakmem_implementation_{implementation}_num_layers_{num_layers}_requires_grad_{requires_grad}.txt",
f"peakmem_implementation_{arch}_{implementation}_num_layers_{num_layers}_requires_grad_{requires_grad}.txt",
),
"r",
) as f:
Expand All @@ -36,6 +41,6 @@
df = DataFrame({"num_layers": layers, "peakmem": peakmems})
savepath = path.join(
DATADIR,
f"peakmem_implementation_{implementation}_requires_grad_{requires_grad}.csv",
f"peakmem_implementation_{arch}_{implementation}_requires_grad_{requires_grad}.csv",
)
df.to_csv(savepath, index=False)
10 changes: 9 additions & 1 deletion experiments/visual_abstract/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

max_num_layers = 10
requires_grads = ["all", "none", "4", "4+"]
# requires_grads = ["4+"]
implementations = ["torch", "ours"]
# implementations = ["ours"]
architectures = ["linear", "conv", "norm_eval"]
architectures = ["norm_eval"]
architectures = ["linear"]


def _run(cmd: List[str]):
Expand All @@ -36,7 +41,9 @@ def _run(cmd: List[str]):


if __name__ == "__main__":
for implementation, requires_grad in product(implementations, requires_grads):
for implementation, requires_grad, arch in product(
implementations, requires_grads, architectures
):
if implementation == "ours" and requires_grad != "4":
continue
for num_layers in range(1, max_num_layers + 1):
Expand All @@ -45,6 +52,7 @@ def _run(cmd: List[str]):
"python",
SCRIPT,
f"--implementation={implementation}",
f"--architecture={arch}",
f"--num_layers={num_layers}",
f"--requires_grad={requires_grad}",
]
Expand Down
69 changes: 38 additions & 31 deletions experiments/visual_abstract/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,50 +33,57 @@
"4": "dashdot",
"4 (ours)": "dotted",
}
architectures = ["linear", "conv", "norm_eval"]
architectures = ["norm_eval"]
architectures = ["conv"]

with plt.rc_context(bundles.cvpr2024()):
fig, ax = plt.subplots()
ax.set_xlabel("Number of layers")
ax.set_ylabel("Peak memory [MiB]")
for arch in architectures:
with plt.rc_context(bundles.icml2024()):
plt.rcParams.update({"figure.figsize": (3.25, 2.5)})
fig, ax = plt.subplots()
ax.set_xlabel("Number of layers")
ax.set_ylabel("Peak memory [MiB]")

markerstyle = {"markersize": 3.5, "fillstyle": "none"}
markerstyle = {"markersize": 3.5, "fillstyle": "none"}

# visualize PyTorch's behavior
implementation = "torch"
# visualize PyTorch's behavior
implementation = "torch"

for requires_grad in requires_grads:
for requires_grad in requires_grads:
df = read_csv(
path.join(
DATADIR,
f"peakmem_implementation_{arch}_{implementation}_requires_grad_{requires_grad}.csv",
)
)
ax.plot(
df["num_layers"],
df["peakmem"],
label=legend_entries[requires_grad],
marker=markers[requires_grad],
linestyle=linestyles[requires_grad],
**markerstyle,
)

# visualize our layer's behavior
implementation, requires_grad = "ours", "4"
key = f"{requires_grad} ({implementation})"
df = read_csv(
path.join(
DATADIR,
f"peakmem_implementation_{implementation}_requires_grad_{requires_grad}.csv",
f"peakmem_implementation_{arch}_{implementation}_requires_grad_{requires_grad}.csv",
)
)
ax.plot(
df["num_layers"],
df["peakmem"],
label=legend_entries[requires_grad],
marker=markers[requires_grad],
linestyle=linestyles[requires_grad],
label=legend_entries[key],
marker=markers[key],
linestyle=linestyles[key],
**markerstyle,
)

# visualize our layer's behavior
implementation, requires_grad = "ours", "4"
key = f"{requires_grad} ({implementation})"
df = read_csv(
path.join(
DATADIR,
f"peakmem_implementation_{implementation}_requires_grad_{requires_grad}.csv",
plt.legend()
plt.savefig(
path.join(HEREDIR, f"visual_abstract_{arch}.pdf"), bbox_inches="tight"
)
)
ax.plot(
df["num_layers"],
df["peakmem"],
label=legend_entries[key],
marker=markers[key],
linestyle=linestyles[key],
**markerstyle,
)

plt.legend()
plt.savefig(path.join(HEREDIR, "visual_abstract.pdf"), bbox_inches="tight")
Loading

0 comments on commit 8e68a88

Please sign in to comment.