diff --git a/experiments/visual_abstract/gather_data.py b/experiments/visual_abstract/gather_data.py index bf5e597..2c42172 100644 --- a/experiments/visual_abstract/gather_data.py +++ b/experiments/visual_abstract/gather_data.py @@ -5,8 +5,7 @@ from pandas import DataFrame -HERE = path.abspath(__file__) -HEREDIR = path.dirname(HERE) +HEREDIR = path.dirname(path.abspath(__file__)) RAWDATADIR = path.join(HEREDIR, "raw") DATADIR = path.join(HEREDIR, "gathered") makedirs(RAWDATADIR, exist_ok=True) @@ -15,32 +14,33 @@ max_num_layers = 10 requires_grads = ["all", "none", "4", "4+"] implementations = ["torch", "ours"] -architectures = ["linear", "conv", "norm_eval"] -architectures = ["norm_eval"] -architectures = ["linear"] +architectures = ["linear", "conv", "bn"] +modes = ["eval", "train"] if __name__ == "__main__": - for implementation, requires_grad, arch in product( - implementations, requires_grads, architectures + for implementation, requires_grad, architecture, mode in product( + implementations, requires_grads, architectures, modes ): if implementation == "ours" and requires_grad != "4": continue + if mode == "eval" and architecture != "bn": + continue - layers = list(range(1, max_num_layers + 1)) peakmems = [] + layers = list(range(1, max_num_layers + 1)) for num_layers in layers: - with open( - path.join( - RAWDATADIR, - f"peakmem_implementation_{arch}_{implementation}_num_layers_{num_layers}_requires_grad_{requires_grad}.txt", - ), - "r", - ) as f: + readpath = path.join( + RAWDATADIR, + f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" + + f"_num_layers_{num_layers}_requires_grad_{requires_grad}.txt", + ) + with open(readpath, "r") as f: peakmems.append(float(f.read())) df = DataFrame({"num_layers": layers, "peakmem": peakmems}) savepath = path.join( DATADIR, - f"peakmem_implementation_{arch}_{implementation}_requires_grad_{requires_grad}.csv", + f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" + + 
f"_requires_grad_{requires_grad}.csv", ) df.to_csv(savepath, index=False) diff --git a/experiments/visual_abstract/generate_data.py b/experiments/visual_abstract/generate_data.py index 8a03319..f9b4255 100644 --- a/experiments/visual_abstract/generate_data.py +++ b/experiments/visual_abstract/generate_data.py @@ -12,12 +12,10 @@ max_num_layers = 10 requires_grads = ["all", "none", "4", "4+"] -# requires_grads = ["4+"] implementations = ["torch", "ours"] -# implementations = ["ours"] -architectures = ["linear", "conv", "norm_eval"] -architectures = ["norm_eval"] -architectures = ["linear"] +architectures = ["linear", "conv", "bn"] +modes = ["eval", "train"] +skip_existing = True def _run(cmd: List[str]): @@ -41,19 +39,23 @@ def _run(cmd: List[str]): if __name__ == "__main__": - for implementation, requires_grad, arch in product( - implementations, requires_grads, architectures + for implementation, requires_grad, architecture, mode in product( + implementations, requires_grads, architectures, modes ): if implementation == "ours" and requires_grad != "4": continue + if mode == "eval" and architecture != "bn": + continue for num_layers in range(1, max_num_layers + 1): _run( [ "python", SCRIPT, f"--implementation={implementation}", - f"--architecture={arch}", + f"--architecture={architecture}", f"--num_layers={num_layers}", f"--requires_grad={requires_grad}", + f"--mode={mode}", ] + + (["--skip_existing"] if skip_existing else []) ) diff --git a/experiments/visual_abstract/plot_data.py b/experiments/visual_abstract/plot_data.py index 7000486..35ebcd8 100644 --- a/experiments/visual_abstract/plot_data.py +++ b/experiments/visual_abstract/plot_data.py @@ -1,14 +1,13 @@ """Visualize memory consumpion.""" +from itertools import product from os import path from matplotlib import pyplot as plt from pandas import read_csv from tueplots import bundles -HERE = path.abspath(__file__) -HEREDIR = path.dirname(HERE) - +HEREDIR = path.dirname(path.abspath(__file__)) DATADIR = 
path.join(HEREDIR, "gathered") requires_grads = ["all", "none", "4+", "4"] @@ -33,57 +32,61 @@ "4": "dashdot", "4 (ours)": "dotted", } -architectures = ["linear", "conv", "norm_eval"] -architectures = ["norm_eval"] -architectures = ["conv"] +architectures = ["linear", "conv", "bn"] +modes = ["train", "eval"] + +if __name__ == "__main__": + for architecture, mode in product(architectures, modes): + if mode == "eval" and architecture != "bn": + continue -for arch in architectures: - with plt.rc_context(bundles.icml2024()): - plt.rcParams.update({"figure.figsize": (3.25, 2.5)}) - fig, ax = plt.subplots() - ax.set_xlabel("Number of layers") - ax.set_ylabel("Peak memory [MiB]") + with plt.rc_context(bundles.icml2024()): + # plt.rcParams.update({"figure.figsize": (3.25, 2.5)}) + fig, ax = plt.subplots() + ax.set_xlabel("Number of layers") + ax.set_ylabel("Peak memory [MiB]") - markerstyle = {"markersize": 3.5, "fillstyle": "none"} + markerstyle = {"markersize": 3.5, "fillstyle": "none"} - # visualize PyTorch's behavior - implementation = "torch" + # visualize PyTorch's behavior + implementation = "torch" - for requires_grad in requires_grads: - df = read_csv( - path.join( + for requires_grad in requires_grads: + readpath = path.join( DATADIR, - f"peakmem_implementation_{arch}_{implementation}_requires_grad_{requires_grad}.csv", + f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" + + f"_requires_grad_{requires_grad}.csv", ) + df = read_csv(readpath) + ax.plot( + df["num_layers"], + df["peakmem"], + label=legend_entries[requires_grad], + marker=markers[requires_grad], + linestyle=linestyles[requires_grad], + **markerstyle, + ) + + # visualize our layer's behavior + implementation, requires_grad = "ours", "4" + key = f"{requires_grad} ({implementation})" + readpath = path.join( + DATADIR, + f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" + + f"_requires_grad_{requires_grad}.csv", ) + df = read_csv(readpath) ax.plot( 
df["num_layers"], df["peakmem"], - label=legend_entries[requires_grad], - marker=markers[requires_grad], - linestyle=linestyles[requires_grad], + label=legend_entries[key], + marker=markers[key], + linestyle=linestyles[key], **markerstyle, ) - # visualize our layer's behavior - implementation, requires_grad = "ours", "4" - key = f"{requires_grad} ({implementation})" - df = read_csv( - path.join( - DATADIR, - f"peakmem_implementation_{arch}_{implementation}_requires_grad_{requires_grad}.csv", + plt.legend() + plt.savefig( + path.join(HEREDIR, f"visual_abstract_{architecture}_{mode}.pdf"), + bbox_inches="tight", ) - ) - ax.plot( - df["num_layers"], - df["peakmem"], - label=legend_entries[key], - marker=markers[key], - linestyle=linestyles[key], - **markerstyle, - ) - - plt.legend() - plt.savefig( - path.join(HEREDIR, f"visual_abstract_{arch}.pdf"), bbox_inches="tight" - ) diff --git a/experiments/visual_abstract/run.py b/experiments/visual_abstract/run.py index bdb6a0b..9dde6c9 100644 --- a/experiments/visual_abstract/run.py +++ b/experiments/visual_abstract/run.py @@ -2,151 +2,178 @@ from argparse import ArgumentParser from collections import OrderedDict +from functools import partial from os import makedirs, path from memory_profiler import memory_usage -from torch import manual_seed, rand -from torch.nn import BatchNorm2d, Conv2d, Linear, Sequential - from memsave_torch.nn import MemSaveBatchNorm2d, MemSaveConv2d, MemSaveLinear +from torch import allclose, manual_seed, rand, rand_like +from torch.autograd import grad +from torch.nn import BatchNorm2d, Conv2d, Linear, Sequential -HERE = path.abspath(__file__) -HEREDIR = path.dirname(HERE) +HEREDIR = path.dirname(path.abspath(__file__)) DATADIR = path.join(HEREDIR, "raw") makedirs(DATADIR, exist_ok=True) -parser = ArgumentParser(description="Parse arguments.") -parser.add_argument("--num_layers", type=int, help="Number of layers.") -parser.add_argument( - "--requires_grad", - type=str, - choices=["all", "none", "4", 
"4+"], - help="Which layers are differentiable.", -) -parser.add_argument( - "--implementation", - type=str, - choices=["torch", "ours"], - help="Which implementation to use.", -) -parser.add_argument( - "--architecture", - type=str, - choices=["linear", "conv", "norm_eval"], - help="Which architecture to use.", -) -args = parser.parse_args() - - -def main(): # noqa: C901 +def main( + architecture: str, + implementation: str, + mode: str, + num_layers: int, + requires_grad: str, +): # noqa: C901 """Runs exps for generating the data of the visual abstract""" manual_seed(0) # create the input + if architecture == "linear": + X = rand(512, 1024, 256) + elif architecture in {"conv", "bn"}: + X = rand(256, 8, 256, 256) + else: + raise ValueError(f"Invalid argument for architecture: {architecture}.") + assert X.numel() == 2**27 # (requires 512 MiB of storage) # create the network - # preserve input size of convolutions - kernel_size = 3 - padding = 1 - - num_layers = args.num_layers layers = OrderedDict() for i in range(num_layers): - if args.architecture == "linear": - num_channels = 1024 - spatial_size = 224 - batch_size = 256 - X = rand(batch_size, spatial_size, num_channels) - if args.implementation == "ours": - layers[f"linear{i}"] = MemSaveLinear( - num_channels, num_channels, bias=False - ) - else: - layers[f"linear{i}"] = Linear(num_channels, num_channels, bias=False) - elif args.architecture == "conv": - num_channels = 8 - spatial_size = 256 - batch_size = 256 - X = rand(batch_size, num_channels, spatial_size, spatial_size) - if args.implementation == "ours": - layers[f"conv{i}"] = MemSaveConv2d( - num_channels, num_channels, kernel_size, padding=padding, bias=False - ) - else: - layers[f"conv{i}"] = Conv2d( - num_channels, num_channels, kernel_size, padding=padding, bias=False - ) - elif args.architecture == "norm_eval": - num_channels = 8 - spatial_size = 256 - batch_size = 256 - X = rand(batch_size, num_channels, spatial_size, spatial_size) - if 
args.implementation == "ours": - layers[f"norm_eval{i}"] = MemSaveBatchNorm2d( - num_channels, - ) - else: - layers[f"norm_eval{i}"] = BatchNorm2d(num_channels) - layers[f"norm_eval{i}"].eval() + if architecture == "linear": + layer_cls = {"ours": MemSaveLinear, "torch": Linear}[implementation] + layers[f"{architecture}{i}"] = layer_cls(256, 256, bias=False) + elif architecture == "conv": + layer_cls = {"ours": MemSaveConv2d, "torch": Conv2d}[implementation] + layers[f"{architecture}{i}"] = layer_cls(8, 8, 3, padding=1, bias=False) + elif architecture == "bn": + layer_cls = {"ours": MemSaveBatchNorm2d, "torch": BatchNorm2d}[ + implementation + ] + layers[f"{architecture}{i}"] = layer_cls(8) else: - raise ValueError(f"Invalid args: {args}.") + raise ValueError(f"Invalid argument for architecture: {architecture}.") net = Sequential(layers) + # randomly initialize parameters + for param in net.parameters(): + param.data = rand_like(param) + + # randomly initialize running mean and std of BN + for module in net.modules(): + if isinstance(module, (BatchNorm2d, MemSaveBatchNorm2d)): + module.running_mean = rand_like(module.running_mean) + module.running_var = rand_like(module.running_var) + # set differentiability - if args.requires_grad == "none": + if requires_grad == "none": for param in net.parameters(): param.requires_grad_(False) - elif args.requires_grad == "all": + elif requires_grad == "all": for param in net.parameters(): param.requires_grad_(True) - elif args.requires_grad == "4": + elif requires_grad == "4": for name, param in net.named_parameters(): - param.requires_grad_(f"{args.architecture}3" in name) - elif args.requires_grad == "4+": + param.requires_grad_(f"{architecture}3" in name) + elif requires_grad == "4+": for name, param in net.named_parameters(): - number = int(name.replace(args.architecture, "").split(".")[0]) + number = int(name.replace(architecture, "").split(".")[0]) param.requires_grad_(number >= 3) else: - raise ValueError(f"Invalid 
requires_grad: {args.requires_grad}.") - - # turn off gradients for the first layer - # net.conv0.weight.requires_grad_(False) - - # turn of gradients for all layers - # for param in net.parameters(): - # param.requires_grad_(False) - - # turn off all gradients except for the first layer - # for name, param in net.named_parameters(): - # param.requires_grad_("conv0" in name) - - # turn off all gradients except for the second layer - # for name, param in net.named_parameters(): - # param.requires_grad_("conv1" in name) - - # turn off all gradients except for the third layer - # for name, param in net.named_parameters(): - # param.requires_grad_("conv2" in name) + raise ValueError(f"Invalid requires_grad: {requires_grad}.") for name, param in net.named_parameters(): print(f"{name} requires_grad = {param.requires_grad}") + # set mode + if mode == "eval": + net.eval() + elif mode == "train": + net.train() + else: + raise ValueError(f"Invalid mode: {mode}.") + # forward pass output = net(X) assert output.shape == X.shape - return output + return output, net + + +def check_equality(architecture: str, mode: str, num_layers: int, requires_grad: str): + """Compare forward pass and gradients of PyTorch and Memsave.""" + output_ours, net_ours = main(architecture, "ours", mode, num_layers, requires_grad) + grad_args_ours = [p for p in net_ours.parameters() if p.requires_grad] + grad_ours = grad(output_ours.sum(), grad_args_ours) if grad_args_ours else [] + + output_torch, net_torch = main( + architecture, "torch", mode, num_layers, requires_grad + ) + grad_args_torch = [p for p in net_torch.parameters() if p.requires_grad] + grad_torch = grad(output_torch.sum(), grad_args_torch) if grad_args_torch else [] + + assert allclose(output_ours, output_torch) + assert len(list(net_ours.parameters())) == len(list(net_torch.parameters())) + for p1, p2 in zip(net_ours.parameters(), net_torch.parameters()): + assert allclose(p1, p2) + assert len(grad_ours) == len(grad_torch) + for g1, g2 
in zip(grad_ours, grad_torch): + assert allclose(g1, g2) if __name__ == "__main__": - max_usage = memory_usage(main, interval=1e-3, max_usage=True) - print(f"Peak mem: {max_usage}.") + # arguments + parser = ArgumentParser(description="Parse arguments.") + parser.add_argument("--num_layers", type=int, help="Number of layers.") + parser.add_argument( + "--requires_grad", + type=str, + choices=["all", "none", "4", "4+"], + help="Which layers are differentiable.", + ) + parser.add_argument( + "--implementation", + type=str, + choices=["torch", "ours"], + help="Which implementation to use.", + ) + parser.add_argument( + "--architecture", + type=str, + choices=["linear", "conv", "bn"], + help="Which architecture to use.", + ) + parser.add_argument("--mode", type=str, help="Mode of the network.") + parser.add_argument( + "--skip_existing", action="store_true", help="Skip existing files." + ) + args = parser.parse_args() + filename = path.join( DATADIR, - f"peakmem_implementation_{args.architecture}_{args.implementation}_num_layers_{args.num_layers}_requires_grad_{args.requires_grad}.txt", + f"peakmem_{args.architecture}_mode_{args.mode}_implementation_" + + f"{args.implementation}_num_layers_{args.num_layers}" + + f"_requires_grad_{args.requires_grad}.txt", ) - - with open(filename, "w") as f: - f.write(f"{max_usage}") + if path.exists(filename) and args.skip_existing: + print(f"Skipping existing file: {filename}.") + else: + # measure memory + f = partial( + main, + num_layers=args.num_layers, + requires_grad=args.requires_grad, + implementation=args.implementation, + architecture=args.architecture, + mode=args.mode, + ) + max_usage = memory_usage(f, interval=1e-3, max_usage=True) + print(f"Peak mem: {max_usage}.") + + with open(filename, "w") as f: + f.write(f"{max_usage}") + + print("Performing equality check.") + check_equality( + args.architecture, args.mode, args.num_layers, args.requires_grad + ) + print("Equality check passed.")