diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py
index 3529230..3ef58ff 100644
--- a/experiments/util/estimate.py
+++ b/experiments/util/estimate.py
@@ -260,6 +260,12 @@ def estimate_mem_savings(
         help=f"Which case to run, allowed values are {allowed_cases}",
     )
     parser.add_argument("--device", type=str, default="cpu", help="torch device name")
+    parser.add_argument(
+        "--bn_eval",
+        action="store_true",
+        default=False,
+        help="Set all BN layers to eval mode",
+    )
     parser.add_argument(
         "--print",
         action="store_true",
@@ -364,6 +370,10 @@ def estimate_mem_savings(
         loss_fn_orig = loss_fn
         loss_fn = lambda: models.SegmentationLossWrapper(loss_fn_orig)  # noqa: E731
 
+    if args.bn_eval:
+        model_fn_orig_bn = model_fn
+        model_fn = lambda: models.set_BN_to_eval(model_fn_orig_bn())  # noqa: E731
+
     # warm-up
     # with redirect_stdout(open(devnull, "w")):
     #     estimate_speedup(model_fn, loss_fn, x, y, dev, vjp_speedups[:1])
diff --git a/experiments/util/models.py b/experiments/util/models.py
index 3895cc1..8fcfe91 100644
--- a/experiments/util/models.py
+++ b/experiments/util/models.py
@@ -135,16 +135,20 @@ def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]:
     raise ValueError(f"arch={arch} not in allowed architectures")
 
 
-def set_BN_to_eval(model: Module):
-    """Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()) in-place.
+def set_BN_to_eval(model: Module) -> Module:
+    """Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()).
 
     Args:
         model (Module): Input model
+
+    Returns:
+        Module: Model with BN layers in eval mode
     """
     known_bn_layers = (BatchNorm2d, MemSaveBatchNorm2d)
     for layer in model.modules():
         if isinstance(layer, known_bn_layers):
             layer.eval()
+    return model
 
 
 # CONV
diff --git a/experiments/visual_abstract/run.py b/experiments/visual_abstract/run.py
index 9dde6c9..906434c 100644
--- a/experiments/visual_abstract/run.py
+++ b/experiments/visual_abstract/run.py
@@ -6,23 +6,24 @@
 from os import makedirs, path
 
 from memory_profiler import memory_usage
-from memsave_torch.nn import MemSaveBatchNorm2d, MemSaveConv2d, MemSaveLinear
 from torch import allclose, manual_seed, rand, rand_like
 from torch.autograd import grad
 from torch.nn import BatchNorm2d, Conv2d, Linear, Sequential
 
+from memsave_torch.nn import MemSaveBatchNorm2d, MemSaveConv2d, MemSaveLinear
+
 HEREDIR = path.dirname(path.abspath(__file__))
 DATADIR = path.join(HEREDIR, "raw")
 makedirs(DATADIR, exist_ok=True)
 
 
-def main(
+def main(  # noqa: C901
     architecture: str,
     implementation: str,
     mode: str,
     num_layers: int,
     requires_grad: str,
-):  # noqa: C901
+):
     """Runs exps for generating the data of the visual abstract"""
 
     manual_seed(0)
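
For context, a minimal sketch of the pattern these changes introduce; it is not part of the patch. The toy Sequential model and the import path (`experiments.util.models`, mirroring the file touched above) are assumptions for illustration; only `set_BN_to_eval`, its new return value, and the wrap-the-factory idiom behind `--bn_eval` come from the diff itself.

    from torch.nn import BatchNorm2d, Conv2d, ReLU, Sequential

    from experiments.util.models import set_BN_to_eval  # assumed import path

    # Stand-in for the experiment's model factory (illustrative only).
    model_fn = lambda: Sequential(Conv2d(3, 8, 3), BatchNorm2d(8), ReLU())  # noqa: E731

    # Same idiom as the new --bn_eval branch in estimate.py: wrap the factory so
    # every freshly built model comes back with its BatchNorm layers in eval mode.
    model_fn_orig_bn = model_fn
    model_fn = lambda: set_BN_to_eval(model_fn_orig_bn())  # noqa: E731

    model = model_fn()
    assert all(not m.training for m in model.modules() if isinstance(m, BatchNorm2d))

Returning the model from `set_BN_to_eval` (rather than mutating in place and returning None) is what makes this one-line lambda composition possible.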