[REF] Improve MWE
f-dangel committed Jul 30, 2024
1 parent a3747a4 commit 6e88811
Showing 1 changed file with 30 additions and 20 deletions.

experiments/visual_abstract/mwe.py
@@ -1,27 +1,29 @@
"""PyTorch's convolutions store their input when `weight.requires_grad=False`."""

from collections import OrderedDict
from time import sleep

from memory_profiler import memory_usage
from torch import rand
from torch.nn import Conv2d, Sequential

X = rand(256, 8, 256, 256)
NUM_LAYERS = 3
SHAPE_X = (256, 8, 256, 256) # shape of the input
MEM_X = 512 # requires 512 MiB storage
NUM_LAYERS = 5

# Create a deep linear CNN with size-preserving convolutions.
"""Create a deep linear CNN with size-preserving convolutions."""
layers = OrderedDict()
for i in range(NUM_LAYERS):
layers[f"conv{i}"] = Conv2d(8, 8, 3, padding=1, bias=False)
net = Sequential(layers)

def setup():
"""Create a deep linear CNN with size-preserving convolutions and an input."""
layers = OrderedDict()
for i in range(NUM_LAYERS):
layers[f"conv{i}"] = Conv2d(8, 8, 3, padding=1, bias=False)
return rand(*SHAPE_X), Sequential(layers)


# Consider three different scenarios: 1) no parameters are trainable, 2) all
# layers are trainable, 3) only the first layer is trainable
def non_trainable():
"""Forward pass through the CNN with all layers non-trainable."""
X, net = setup()
for i in range(NUM_LAYERS):
getattr(net, f"conv{i}").weight.requires_grad = False

@@ -33,6 +35,7 @@ def non_trainable():

def all_trainable():
    """Forward pass through the CNN with all layers trainable."""
+    X, net = setup()
    for i in range(NUM_LAYERS):
        getattr(net, f"conv{i}").weight.requires_grad = True

@@ -44,6 +47,7 @@ def all_trainable():

def first_trainable():
    """Forward pass through the CNN with first layer trainable."""
+    X, net = setup()
    for i in range(NUM_LAYERS):
        getattr(net, f"conv{i}").weight.requires_grad = i == 1

@@ -57,14 +61,20 @@ def first_trainable():
kwargs = {"interval": 1e-4, "max_usage": True}  # memory profiler settings

# measure memory and print
-mem_offset = memory_usage(lambda: sleep(0.1), **kwargs)
-print(f"Memory weights+input: {mem_offset:.1f} MiB.")
-
-mem_non = memory_usage(non_trainable, **kwargs) - mem_offset
-print(f"Memory non-trainable: {mem_non:.1f} MiB.")
-
-mem_all = memory_usage(all_trainable, **kwargs) - mem_offset
-print(f"Memory all-trainable: {mem_all:.1f} MiB.")
-
-mem_first = memory_usage(first_trainable, **kwargs) - mem_offset
-print(f"Memory first-trainable: {mem_first:.1f} MiB.")
+mem_setup = memory_usage(setup, **kwargs)
+print(f"Weights+input: {mem_setup:.1f} MiB.")
+
+mem_non = memory_usage(non_trainable, **kwargs) - mem_setup
+print(
+    f"Non-trainable: {mem_non:.1f} MiB (≈{mem_non / MEM_X:.1f} hidden activations)."
+)
+
+mem_all = memory_usage(all_trainable, **kwargs) - mem_setup
+print(
+    f"All-trainable: {mem_all:.1f} MiB (≈{mem_all / MEM_X:.1f} hidden activations)."
+)
+
+mem_first = memory_usage(first_trainable, **kwargs) - mem_setup
+print(
+    f"First-trainable: {mem_first:.1f} MiB (≈{mem_first / MEM_X:.1f} hidden activations)."
+)

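For context on the module docstring's claim, here is a minimal sketch (not part of the commit, assuming a PyTorch version that provides `torch.autograd.graph.saved_tensors_hooks`, i.e. 1.10+) showing that a convolution with a frozen weight still saves its input for backward once that input requires a gradient — the situation of every layer that follows the single trainable layer in `first_trainable`:

from torch import rand
from torch.autograd.graph import saved_tensors_hooks
from torch.nn import Conv2d

conv = Conv2d(8, 8, 3, padding=1, bias=False)
conv.weight.requires_grad = False  # the layer itself is frozen

# mimic an activation produced by an earlier, trainable layer
x = rand(2, 8, 32, 32, requires_grad=True)

saved_shapes = []


def pack(tensor):
    """Record the shape of every tensor autograd saves for backward."""
    saved_shapes.append(tuple(tensor.shape))
    return tensor


with saved_tensors_hooks(pack, lambda tensor: tensor):
    y = conv(x)

# (2, 8, 32, 32) shows up among the saved shapes even though the weight
# does not require a gradient, i.e. the input is stored anyway.
print(saved_shapes)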
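Separately, a quick check of the `MEM_X = 512` constant that the new prints use to convert MiB into a number of hidden activations (also not part of the commit): a float32 tensor of shape `SHAPE_X` holds 256 · 8 · 256 · 256 elements at 4 bytes each.

numel = 256 * 8 * 256 * 256  # elements in a (256, 8, 256, 256) tensor
bytes_total = numel * 4      # float32 -> 4 bytes per element
print(bytes_total / 2**20)   # 512.0 MiB, matching MEM_X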