diff --git a/experiments/visual_abstract/gather_data.py b/experiments/visual_abstract/gather_data.py index 0947f84..c531a4d 100644 --- a/experiments/visual_abstract/gather_data.py +++ b/experiments/visual_abstract/gather_data.py @@ -25,10 +25,11 @@ "conv_transpose3d", } modes = {"eval", "train"} +use_compiles = {False, True} if __name__ == "__main__": - for implementation, requires_grad, architecture, mode in product( - implementations, requires_grads, architectures, modes + for implementation, requires_grad, architecture, mode, use_compile in product( + implementations, requires_grads, architectures, modes, use_compiles ): if implementation == "ours" and requires_grad != "4": continue @@ -39,7 +40,8 @@ readpath = path.join( RAWDATADIR, f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" - + f"_num_layers_{num_layers}_requires_grad_{requires_grad}.txt", + + f"_num_layers_{num_layers}_requires_grad_{requires_grad}" + f"{'_use_compile' if use_compile else ''}.txt", ) with open(readpath, "r") as f: peakmems.append(float(f.read())) @@ -48,6 +50,7 @@ savepath = path.join( DATADIR, f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" - + f"_requires_grad_{requires_grad}.csv", + + f"_requires_grad_{requires_grad}{'_use_compile' if use_compile else ''}" + + ".csv", ) df.to_csv(savepath, index=False) diff --git a/experiments/visual_abstract/generate_data.py b/experiments/visual_abstract/generate_data.py index ab271a8..e6f28c3 100644 --- a/experiments/visual_abstract/generate_data.py +++ b/experiments/visual_abstract/generate_data.py @@ -25,6 +25,8 @@ "conv_transpose3d", } modes = {"eval", "train"} +use_compiles = {False, True} + skip_existing = True @@ -49,8 +51,10 @@ def _run(cmd: List[str]): if __name__ == "__main__": - configs = list(product(implementations, requires_grads, architectures, modes)) - for implementation, requires_grad, architecture, mode in tqdm(configs): + configs = list( + product(implementations, requires_grads, architectures, modes, use_compiles) + ) + for implementation, requires_grad, architecture, mode, use_compile in tqdm(configs): if implementation == "ours" and requires_grad != "4": continue @@ -66,4 +70,5 @@ def _run(cmd: List[str]): f"--mode={mode}", ] + (["--skip_existing"] if skip_existing else []) + + (["--use_compile"] if use_compile else []), ) diff --git a/experiments/visual_abstract/plot_data.py b/experiments/visual_abstract/plot_data.py index bc3130f..b300c43 100644 --- a/experiments/visual_abstract/plot_data.py +++ b/experiments/visual_abstract/plot_data.py @@ -43,8 +43,10 @@ "conv_transpose3d", } modes = {"train", "eval"} +use_compiles = {False, True} + if __name__ == "__main__": - for architecture, mode in product(architectures, modes): + for architecture, mode, use_compile in product(architectures, modes, use_compiles): with plt.rc_context(bundles.icml2024()): fig, ax = plt.subplots() ax.set_xlabel("Number of layers") @@ -59,7 +61,8 @@ readpath = path.join( DATADIR, f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" - + f"_requires_grad_{requires_grad}.csv", + + f"_requires_grad_{requires_grad}" + + f"{'_use_compile' if use_compile else ''}.csv", ) df = read_csv(readpath) ax.plot( @@ -77,7 +80,8 @@ readpath = path.join( DATADIR, f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}" - + f"_requires_grad_{requires_grad}.csv", + + f"_requires_grad_{requires_grad}" + + f"{'_use_compile' if use_compile else ''}.csv", ) df = read_csv(readpath) ax.plot( diff --git a/experiments/visual_abstract/run.py b/experiments/visual_abstract/run.py index add5462..22454a8 100644 --- a/experiments/visual_abstract/run.py +++ b/experiments/visual_abstract/run.py @@ -17,7 +17,7 @@ MemSaveLinear, ) from memsave_torch.nn.ConvTranspose1d import MemSaveConvTranspose1d -from torch import allclose, manual_seed, rand, rand_like +from torch import allclose, compile, manual_seed, rand, rand_like from torch.autograd import grad from torch.nn import ( BatchNorm2d, @@ -42,6 +42,7 @@ def main( # noqa: C901 mode: str, num_layers: int, requires_grad: str, + use_compile: bool, ): """Runs exps for generating the data of the visual abstract""" manual_seed(0) @@ -137,6 +138,11 @@ def main( # noqa: C901 else: raise ValueError(f"Invalid mode: {mode}.") + # maybe compile + if use_compile: + print("Compiling model") + net = compile(net) + # forward pass output = net(X) assert output.shape == X.shape @@ -144,14 +150,18 @@ def main( # noqa: C901 return output, net -def check_equality(architecture: str, mode: str, num_layers: int, requires_grad: str): +def check_equality( + architecture: str, mode: str, num_layers: int, requires_grad: str, use_compile: bool +): """Compare forward pass and gradients of PyTorch and Memsave.""" - output_ours, net_ours = main(architecture, "ours", mode, num_layers, requires_grad) + output_ours, net_ours = main( + architecture, "ours", mode, num_layers, requires_grad, use_compile + ) grad_args_ours = [p for p in net_ours.parameters() if p.requires_grad] grad_ours = grad(output_ours.sum(), grad_args_ours) if grad_args_ours else [] output_torch, net_torch = main( - architecture, "torch", mode, num_layers, requires_grad + architecture, "torch", mode, num_layers, requires_grad, use_compile ) grad_args_torch = [p for p in net_torch.parameters() if p.requires_grad] grad_torch = grad(output_torch.sum(), grad_args_torch) if grad_args_torch else [] @@ -202,13 +212,19 @@ def check_equality(architecture: str, mode: str, num_layers: int, requires_grad: parser.add_argument( "--skip_existing", action="store_true", help="Skip existing files." ) + parser.add_argument( + "--use_compile", + action="store_true", + help="Compile the model before the forward pass.", + ) args = parser.parse_args() filename = path.join( DATADIR, f"peakmem_{args.architecture}_mode_{args.mode}_implementation_" + f"{args.implementation}_num_layers_{args.num_layers}" - + f"_requires_grad_{args.requires_grad}.txt", + + f"_requires_grad_{args.requires_grad}" + f"{'_use_compile' if args.use_compile else ''}.txt", ) if path.exists(filename) and args.skip_existing: print(f"Skipping existing file: {filename}.") @@ -221,6 +237,8 @@ def check_equality(architecture: str, mode: str, num_layers: int, requires_grad: implementation=args.implementation, architecture=args.architecture, mode=args.mode, + # Memsave does not compile (TODO debug why) + use_compile=False if args.implementation == "ours" else args.use_compile, ) max_usage = memory_usage(f, interval=1e-4, max_usage=True) print(f"Peak mem: {max_usage}.") @@ -228,8 +246,14 @@ def check_equality(architecture: str, mode: str, num_layers: int, requires_grad: with open(filename, "w") as f: f.write(f"{max_usage}") - print("Performing equality check.") - check_equality( - args.architecture, args.mode, args.num_layers, args.requires_grad - ) - print("Equality check passed.") + # Memsave is not compile-able (TODO debug why) + if not args.use_compile: + print("Performing equality check.") + check_equality( + args.architecture, + args.mode, + args.num_layers, + args.requires_grad, + args.use_compile, + ) + print("Equality check passed.")