Skip to content

Commit

Permalink
bn_eval utils, minor formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Jul 4, 2024
1 parent a354d9b commit 0df6e44
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
10 changes: 10 additions & 0 deletions experiments/util/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ def estimate_mem_savings(
help=f"Which case to run, allowed values are {allowed_cases}",
)
parser.add_argument("--device", type=str, default="cpu", help="torch device name")
parser.add_argument(
"--bn_eval",
action="store_true",
default=False,
help="Set all BN layers to eval mode",
)
parser.add_argument(
"--print",
action="store_true",
Expand Down Expand Up @@ -364,6 +370,10 @@ def estimate_mem_savings(
loss_fn_orig = loss_fn
loss_fn = lambda: models.SegmentationLossWrapper(loss_fn_orig) # noqa: E731

if args.bn_eval:
model_fn_orig_bn = model_fn
model_fn = lambda: models.set_BN_to_eval(model_fn_orig_bn()) # noqa: E731

# warm-up
# with redirect_stdout(open(devnull, "w")):
# estimate_speedup(model_fn, loss_fn, x, y, dev, vjp_speedups[:1])
Expand Down
8 changes: 6 additions & 2 deletions experiments/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,20 @@ def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]:
raise ValueError(f"arch={arch} not in allowed architectures")


def set_BN_to_eval(model: Module):
"""Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()) in-place.
def set_BN_to_eval(model: Module) -> Module:
"""Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()).
Args:
model (Module): Input model
Returns:
Module: Model with BN layers in eval mode
"""
known_bn_layers = (BatchNorm2d, MemSaveBatchNorm2d)
for layer in model.modules():
if isinstance(layer, known_bn_layers):
layer.eval()
return model


# CONV
Expand Down
7 changes: 4 additions & 3 deletions experiments/visual_abstract/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@
from os import makedirs, path

from memory_profiler import memory_usage
from memsave_torch.nn import MemSaveBatchNorm2d, MemSaveConv2d, MemSaveLinear
from torch import allclose, manual_seed, rand, rand_like
from torch.autograd import grad
from torch.nn import BatchNorm2d, Conv2d, Linear, Sequential

from memsave_torch.nn import MemSaveBatchNorm2d, MemSaveConv2d, MemSaveLinear

HEREDIR = path.dirname(path.abspath(__file__))
DATADIR = path.join(HEREDIR, "raw")
makedirs(DATADIR, exist_ok=True)


def main(
def main( # noqa: C901
architecture: str,
implementation: str,
mode: str,
num_layers: int,
requires_grad: str,
): # noqa: C901
):
"""Runs exps for generating the data of the visual abstract"""
manual_seed(0)

Expand Down

0 comments on commit 0df6e44

Please sign in to comment.