diff --git a/memsave_torch/util/visual_abstract/__init__.py b/memsave_torch/util/visual_abstract/__init__.py
new file mode 100644
index 0000000..086201d
--- /dev/null
+++ b/memsave_torch/util/visual_abstract/__init__.py
@@ -0,0 +1 @@
+"""This experiment aims to understand when inputs are stored by PyTorch's autodiff."""
diff --git a/memsave_torch/util/visual_abstract/experiments.md b/memsave_torch/util/visual_abstract/experiments.md
new file mode 100644
index 0000000..9812c05
--- /dev/null
+++ b/memsave_torch/util/visual_abstract/experiments.md
@@ -0,0 +1,75 @@
+Use an input of shape `(256, 8, 256, 256)` and size-preserving convolutions with `kernel_size=3` and `padding=1`.
+
+---
+
+Peak memory used by the forward pass:
+
+- 1 layer: 1725.78515625
+- 2 layers: 2238.59375
+- 3 layers: 2750.390625
+- 4 layers: 3261.08984375
+- 5 layers: 3774.68359375
+
+Roughly a 500 MiB increase per added layer, consistent with the 512 MiB required to store one intermediate (see the back-of-the-envelope check at the end of this file).
+
+---
+
+Let's turn off `requires_grad` for the first layer:
+
+- 1 layer: 1724.75390625
+- 2 layers: 2237.5703125
+- 3 layers: 2749.796875
+- 4 layers: 3262.453125
+- 5 layers: 3773.8203125
+
+Basically no change at all!
+
+---
+
+Let's turn off all `requires_grad`:
+
+- 1 layer: 1724.5390625
+- 2 layers: 2238.08203125
+- 3 layers: 2238.49609375
+- 4 layers: 2237.92578125
+- 5 layers: 2238.30078125
+
+Now we can see that only the original input and two intermediates are stored at a time.
+
+---
+
+Let's turn off all `requires_grad` except for the first layer:
+
+- 1 layer: 1725.52734375
+- 2 layers: 2236.26953125
+- 3 layers: 2749.359375
+- 4 layers: 3262.171875
+- 5 layers: 3773.9921875
+
+Although we only want gradients for the first layer, we get the same memory consumption as if we wanted to compute gradients for all layers.
+
+---
+
+Let's turn off all `requires_grad` except for the second layer:
+
+- 1 layer: 1725.0078125
+- 2 layers: 2238.3515625
+- 3 layers: 2750.6484375
+- 4 layers: 3262.36328125
+- 5 layers: 3774.34765625
+
+Same behavior, because the input and output of a convolution are stored at the same time (see the autograd-inspection sketch at the end of this file).
+
+---
+
+Let's turn off all `requires_grad` except for the third layer:
+
+- 1 layer: 1725.171875
+- 2 layers: 2237.85546875
+- 3 layers: 2238.42578125
+- 4 layers: 2749.625
+- 5 layers: 3261.44921875
+
+Notice the zero increase between 2 and 3 layers.
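+
+---
+
+For reference, a back-of-the-envelope check of the size of one intermediate (added for illustration; it assumes float32 storage, i.e. 4 bytes per entry, which is PyTorch's default):
+
+```python
+# One intermediate of shape (256, 8, 256, 256) stored in float32.
+numel = 256 * 8 * 256 * 256
+print(numel * 4 / 2**20)  # 512.0 MiB
+```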
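+
+---
+
+Another illustration-only sketch (it assumes a PyTorch version that provides `torch.autograd.graph.saved_tensors_hooks`): record which tensors autograd packs for the backward pass while the forward pass runs.
+
+```python
+from torch import nn, rand
+from torch.autograd.graph import saved_tensors_hooks
+
+saved_shapes = []
+
+
+def pack(tensor):
+    # Record the shape of every tensor autograd saves, then store it unchanged.
+    saved_shapes.append(tuple(tensor.shape))
+    return tensor
+
+
+# Hypothetical toy sizes; as in the runs above, only the parameters require gradients.
+net = nn.Sequential(
+    nn.Conv2d(8, 8, 3, padding=1, bias=False),
+    nn.Conv2d(8, 8, 3, padding=1, bias=False),
+)
+x = rand(2, 8, 32, 32)
+
+with saved_tensors_hooks(pack, lambda tensor: tensor):
+    net(x)
+
+# Expect the input of each convolution to show up (the weights may appear as
+# well), even though x itself never needs a gradient.
+print(saved_shapes)
+```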
+
+---
diff --git a/memsave_torch/util/visual_abstract/gather_data.py b/memsave_torch/util/visual_abstract/gather_data.py
new file mode 100644
index 0000000..9aa69d7
--- /dev/null
+++ b/memsave_torch/util/visual_abstract/gather_data.py
@@ -0,0 +1,41 @@
+"""Combine data from individual runs into data frames."""
+
+from itertools import product
+from os import makedirs, path
+
+from pandas import DataFrame
+
+HERE = path.abspath(__file__)
+HEREDIR = path.dirname(HERE)
+RAWDATADIR = path.join(HEREDIR, "raw")
+DATADIR = path.join(HEREDIR, "gathered")
+makedirs(RAWDATADIR, exist_ok=True)
+makedirs(DATADIR, exist_ok=True)
+
+max_num_layers = 10
+requires_grads = ["all", "none", "4", "4+"]
+implementations = ["torch", "ours"]
+
+if __name__ == "__main__":
+    for implementation, requires_grad in product(implementations, requires_grads):
+        if implementation == "ours" and requires_grad != "4":
+            continue
+
+        layers = list(range(1, max_num_layers + 1))
+        peakmems = []
+        for num_layers in layers:
+            with open(
+                path.join(
+                    RAWDATADIR,
+                    f"peakmem_implementation_{implementation}_num_layers_{num_layers}_requires_grad_{requires_grad}.txt",
+                ),
+                "r",
+            ) as f:
+                peakmems.append(float(f.read()))
+
+        df = DataFrame({"num_layers": layers, "peakmem": peakmems})
+        savepath = path.join(
+            DATADIR,
+            f"peakmem_implementation_{implementation}_requires_grad_{requires_grad}.csv",
+        )
+        df.to_csv(savepath, index=False)
diff --git a/memsave_torch/util/visual_abstract/generate_data.py b/memsave_torch/util/visual_abstract/generate_data.py
new file mode 100644
index 0000000..91dfecc
--- /dev/null
+++ b/memsave_torch/util/visual_abstract/generate_data.py
@@ -0,0 +1,51 @@
+"""Launch all configurations of the memory benchmark."""
+
+from itertools import product
+from os import path
+from subprocess import CalledProcessError, run
+from typing import List
+
+HERE = path.abspath(__file__)
+HEREDIR = path.dirname(HERE)
+SCRIPT = path.join(HEREDIR, "run.py")
+
+
+max_num_layers = 10
+requires_grads = ["all", "none", "4", "4+"]
+implementations = ["torch", "ours"]
+
+
+def _run(cmd: List[str]):
+    """Run the command and print the output/stderr if it fails.
+
+    Args:
+        cmd: The command to run.
+
+    Raises:
+        CalledProcessError: If the command fails.
+ """ + try: + print(f"Running command: {' '.join(cmd)}") + job = run(cmd, capture_output=True, text=True, check=True) + print(f"STDOUT:\n{job.stdout}") + print(f"STDERR:\n{job.stderr}") + except CalledProcessError as e: + print(f"STDOUT:\n{e.stdout}") + print(f"STDERR:\n{e.stderr}") + raise e + + +if __name__ == "__main__": + for implementation, requires_grad in product(implementations, requires_grads): + if implementation == "ours" and requires_grad != "4": + continue + for num_layers in range(1, max_num_layers + 1): + _run( + [ + "python", + SCRIPT, + f"--implementation={implementation}", + f"--num_layers={num_layers}", + f"--requires_grad={requires_grad}", + ] + ) diff --git a/memsave_torch/util/visual_abstract/plot_data.py b/memsave_torch/util/visual_abstract/plot_data.py new file mode 100644 index 0000000..c0fefa3 --- /dev/null +++ b/memsave_torch/util/visual_abstract/plot_data.py @@ -0,0 +1,82 @@ +"""Visualize memory consumpion.""" + +from os import path + +from matplotlib import pyplot as plt +from pandas import read_csv +from tueplots import bundles + +HERE = path.abspath(__file__) +HEREDIR = path.dirname(HERE) + +DATADIR = path.join(HEREDIR, "gathered") + +requires_grads = ["all", "none", "4+", "4"] +legend_entries = { + "all": "Fully differentiable", + "none": "Fully non-differentiable", + "4+": "Layers 4+ differentiable", + "4": "Layer 4 differentiable", + "4 (ours)": "Layer 4 differentiable (ours)", +} +markers = { + "all": "o", + "none": "x", + "4+": "<", + "4": ">", + "4 (ours)": "p", +} +linestyles = { + "all": "-", + "none": "-", + "4+": "dashed", + "4": "dashdot", + "4 (ours)": "dotted", +} + +with plt.rc_context(bundles.cvpr2024()): + fig, ax = plt.subplots() + ax.set_xlabel("Number of layers") + ax.set_ylabel("Peak memory [MiB]") + + markerstyle = {"markersize": 3.5, "fillstyle": "none"} + + # visualize PyTorch's behavior + implementation = "torch" + + for requires_grad in requires_grads: + df = read_csv( + path.join( + DATADIR, + f"peakmem_implementation_{implementation}_requires_grad_{requires_grad}.csv", + ) + ) + ax.plot( + df["num_layers"], + df["peakmem"], + label=legend_entries[requires_grad], + marker=markers[requires_grad], + linestyle=linestyles[requires_grad], + **markerstyle, + ) + + # visualize our layer's behavior + implementation, requires_grad = "ours", "4" + key = f"{requires_grad} ({implementation})" + df = read_csv( + path.join( + DATADIR, + f"peakmem_implementation_{implementation}_requires_grad_{requires_grad}.csv", + ) + ) + ax.plot( + df["num_layers"], + df["peakmem"], + label=legend_entries[key], + marker=markers[key], + linestyle=linestyles[key], + **markerstyle, + ) + + plt.legend() + plt.savefig(path.join(HEREDIR, "visual_abstract.pdf"), bbox_inches="tight") diff --git a/memsave_torch/util/visual_abstract/run.py b/memsave_torch/util/visual_abstract/run.py new file mode 100644 index 0000000..525ddd2 --- /dev/null +++ b/memsave_torch/util/visual_abstract/run.py @@ -0,0 +1,121 @@ +"""Measure forward pass peak memory and save to file.""" + +from argparse import ArgumentParser +from collections import OrderedDict +from os import makedirs, path + +from memory_profiler import memory_usage +from torch import manual_seed, rand +from torch.nn import Conv2d, Sequential + +from memsave_torch.nn import MemSaveConv2d + +HERE = path.abspath(__file__) +HEREDIR = path.dirname(HERE) +DATADIR = path.join(HEREDIR, "raw") +makedirs(DATADIR, exist_ok=True) + + +parser = ArgumentParser(description="Parse arguments.") +parser.add_argument("--num_layers", type=int, 
help="Number of layers.") +parser.add_argument( + "--requires_grad", + type=str, + choices=["all", "none", "4", "4+"], + help="Which layers are differentiable.", +) +parser.add_argument( + "--implementation", + type=str, + choices=["torch", "ours"], + help="Which implementation to use.", +) +args = parser.parse_args() + + +def main(): + manual_seed(0) + + # create the input + num_channels = 8 + spatial_size = 256 + batch_size = 256 + X = rand(batch_size, num_channels, spatial_size, spatial_size) + + # create the network + # preserve input size of convolutions + kernel_size = 3 + padding = 1 + + num_layers = args.num_layers + layers = OrderedDict() + for i in range(num_layers): + if args.implementation == "torch": + layers[f"conv{i}"] = Conv2d( + num_channels, num_channels, kernel_size, padding=padding, bias=False + ) + elif args.implementation == "ours": + layers[f"conv{i}"] = MemSaveConv2d( + num_channels, num_channels, kernel_size, padding=padding, bias=False + ) + else: + raise ValueError(f"Invalid implementation: {args.implementation}.") + + net = Sequential(layers) + + # set differentiability + if args.requires_grad == "none": + for param in net.parameters(): + param.requires_grad_(False) + elif args.requires_grad == "all": + for param in net.parameters(): + param.requires_grad_(True) + elif args.requires_grad == "4": + for name, param in net.named_parameters(): + param.requires_grad_("conv3" in name) + elif args.requires_grad == "4+": + for name, param in net.named_parameters(): + number = int(name.replace("conv", "").replace(".weight", "")) + param.requires_grad_(number >= 3) + else: + raise ValueError(f"Invalid requires_grad: {args.requires_grad}.") + + # turn off gradients for the first layer + # net.conv0.weight.requires_grad_(False) + + # turn of gradients for all layers + # for param in net.parameters(): + # param.requires_grad_(False) + + # turn off all gradients except for the first layer + # for name, param in net.named_parameters(): + # param.requires_grad_("conv0" in name) + + # turn off all gradients except for the second layer + # for name, param in net.named_parameters(): + # param.requires_grad_("conv1" in name) + + # turn off all gradients except for the third layer + # for name, param in net.named_parameters(): + # param.requires_grad_("conv2" in name) + + for name, param in net.named_parameters(): + print(f"{name} requires_grad = {param.requires_grad}") + + # forward pass + output = net(X) + assert output.shape == X.shape + + return output + + +if __name__ == "__main__": + max_usage = memory_usage(main, interval=1e-3, max_usage=True) + print(f"Peak mem: {max_usage}.") + filename = path.join( + DATADIR, + f"peakmem_implementation_{args.implementation}_num_layers_{args.num_layers}_requires_grad_{args.requires_grad}.txt", + ) + + with open(filename, "w") as f: + f.write(f"{max_usage}") diff --git a/memsave_torch/util/visual_abstract/visual_abstract.pdf b/memsave_torch/util/visual_abstract/visual_abstract.pdf new file mode 100644 index 0000000..aab07c4 Binary files /dev/null and b/memsave_torch/util/visual_abstract/visual_abstract.pdf differ