Skip to content

Commit

Permalink
[ADD] Add option to compile PyTorch implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jul 10, 2024
1 parent 16626a5 commit 6702f5b
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 19 deletions.
11 changes: 7 additions & 4 deletions experiments/visual_abstract/gather_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
"conv_transpose3d",
}
modes = {"eval", "train"}
use_compiles = {False, True}

if __name__ == "__main__":
for implementation, requires_grad, architecture, mode in product(
implementations, requires_grads, architectures, modes
for implementation, requires_grad, architecture, mode, use_compile in product(
implementations, requires_grads, architectures, modes, use_compiles
):
if implementation == "ours" and requires_grad != "4":
continue
Expand All @@ -39,7 +40,8 @@
readpath = path.join(
RAWDATADIR,
f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}"
+ f"_num_layers_{num_layers}_requires_grad_{requires_grad}.txt",
+ f"_num_layers_{num_layers}_requires_grad_{requires_grad}"
f"{'_use_compile' if use_compile else ''}.txt",
)
with open(readpath, "r") as f:
peakmems.append(float(f.read()))
Expand All @@ -48,6 +50,7 @@
savepath = path.join(
DATADIR,
f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}"
+ f"_requires_grad_{requires_grad}.csv",
+ f"_requires_grad_{requires_grad}{'_use_compile' if use_compile else ''}"
+ ".csv",
)
df.to_csv(savepath, index=False)
9 changes: 7 additions & 2 deletions experiments/visual_abstract/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
"conv_transpose3d",
}
modes = {"eval", "train"}
use_compiles = {False, True}

skip_existing = True


Expand All @@ -49,8 +51,10 @@ def _run(cmd: List[str]):


if __name__ == "__main__":
configs = list(product(implementations, requires_grads, architectures, modes))
for implementation, requires_grad, architecture, mode in tqdm(configs):
configs = list(
product(implementations, requires_grads, architectures, modes, use_compiles)
)
for implementation, requires_grad, architecture, mode, use_compile in tqdm(configs):
if implementation == "ours" and requires_grad != "4":
continue

Expand All @@ -66,4 +70,5 @@ def _run(cmd: List[str]):
f"--mode={mode}",
]
+ (["--skip_existing"] if skip_existing else [])
+ (["--use_compile"] if use_compile else []),
)
10 changes: 7 additions & 3 deletions experiments/visual_abstract/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
"conv_transpose3d",
}
modes = {"train", "eval"}
use_compiles = {False, True}

if __name__ == "__main__":
for architecture, mode in product(architectures, modes):
for architecture, mode, use_compile in product(architectures, modes, use_compiles):
with plt.rc_context(bundles.icml2024()):
fig, ax = plt.subplots()
ax.set_xlabel("Number of layers")
Expand All @@ -59,7 +61,8 @@
readpath = path.join(
DATADIR,
f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}"
+ f"_requires_grad_{requires_grad}.csv",
+ f"_requires_grad_{requires_grad}"
+ f"{'_use_compile' if use_compile else ''}.csv",
)
df = read_csv(readpath)
ax.plot(
Expand All @@ -77,7 +80,8 @@
readpath = path.join(
DATADIR,
f"peakmem_{architecture}_mode_{mode}_implementation_{implementation}"
+ f"_requires_grad_{requires_grad}.csv",
+ f"_requires_grad_{requires_grad}"
+ f"{'_use_compile' if use_compile else ''}.csv",
)
df = read_csv(readpath)
ax.plot(
Expand Down
44 changes: 34 additions & 10 deletions experiments/visual_abstract/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
MemSaveLinear,
)
from memsave_torch.nn.ConvTranspose1d import MemSaveConvTranspose1d

Check failure on line 19 in experiments/visual_abstract/run.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F811)

experiments/visual_abstract/run.py:19:46: F811 Redefinition of unused `MemSaveConvTranspose1d` from line 14
from torch import allclose, manual_seed, rand, rand_like
from torch import allclose, compile, manual_seed, rand, rand_like
from torch.autograd import grad
from torch.nn import (
BatchNorm2d,
Expand All @@ -42,6 +42,7 @@ def main( # noqa: C901
mode: str,
num_layers: int,
requires_grad: str,
use_compile: bool,
):
"""Runs exps for generating the data of the visual abstract"""
manual_seed(0)
Expand Down Expand Up @@ -137,21 +138,30 @@ def main( # noqa: C901
else:
raise ValueError(f"Invalid mode: {mode}.")

# maybe compile
if use_compile:
print("Compiling model")
net = compile(net)

# forward pass
output = net(X)
assert output.shape == X.shape

return output, net


def check_equality(architecture: str, mode: str, num_layers: int, requires_grad: str):
def check_equality(
architecture: str, mode: str, num_layers: int, requires_grad: str, use_compile: bool
):
"""Compare forward pass and gradients of PyTorch and Memsave."""
output_ours, net_ours = main(architecture, "ours", mode, num_layers, requires_grad)
output_ours, net_ours = main(
architecture, "ours", mode, num_layers, requires_grad, use_compile
)
grad_args_ours = [p for p in net_ours.parameters() if p.requires_grad]
grad_ours = grad(output_ours.sum(), grad_args_ours) if grad_args_ours else []

output_torch, net_torch = main(
architecture, "torch", mode, num_layers, requires_grad
architecture, "torch", mode, num_layers, requires_grad, use_compile
)
grad_args_torch = [p for p in net_torch.parameters() if p.requires_grad]
grad_torch = grad(output_torch.sum(), grad_args_torch) if grad_args_torch else []
Expand Down Expand Up @@ -202,13 +212,19 @@ def check_equality(architecture: str, mode: str, num_layers: int, requires_grad:
parser.add_argument(
"--skip_existing", action="store_true", help="Skip existing files."
)
parser.add_argument(
"--use_compile",
action="store_true",
help="Compile the model before the forward pass.",
)
args = parser.parse_args()

filename = path.join(
DATADIR,
f"peakmem_{args.architecture}_mode_{args.mode}_implementation_"
+ f"{args.implementation}_num_layers_{args.num_layers}"
+ f"_requires_grad_{args.requires_grad}.txt",
+ f"_requires_grad_{args.requires_grad}"
f"{'_use_compile' if args.use_compile else ''}.txt",
)
if path.exists(filename) and args.skip_existing:
    print(f"Skipping existing file: {filename}.")
Expand All @@ -221,15 +237,23 @@ def check_equality(architecture: str, mode: str, num_layers: int, requires_grad:
implementation=args.implementation,
architecture=args.architecture,
mode=args.mode,
# Memsave does not compile (TODO debug why)
use_compile=False if args.implementation == "ours" else args.use_compile,
)
max_usage = memory_usage(f, interval=1e-4, max_usage=True)
print(f"Peak mem: {max_usage}.")

with open(filename, "w") as f:
f.write(f"{max_usage}")

print("Performing equality check.")
check_equality(
args.architecture, args.mode, args.num_layers, args.requires_grad
)
print("Equality check passed.")
# Memsave is not compile-able (TODO debug why)
if not args.use_compile:
print("Performing equality check.")
check_equality(
args.architecture,
args.mode,
args.num_layers,
args.requires_grad,
args.use_compile,
)
print("Equality check passed.")

0 comments on commit 6702f5b

Please sign in to comment.