diff --git a/experiments/util/models.py b/experiments/util/models.py index 8623a98..3895cc1 100644 --- a/experiments/util/models.py +++ b/experiments/util/models.py @@ -9,6 +9,7 @@ import torch import torchvision.models as tvm from torch.nn import ( + BatchNorm2d, Conv2d, Flatten, Linear, @@ -32,6 +33,7 @@ from transformers import utils as tf_utils from memsave_torch.nn import ( + MemSaveBatchNorm2d, MemSaveConv2d, MemSaveLinear, convert_to_memory_saving, @@ -133,6 +135,18 @@ def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]: raise ValueError(f"arch={arch} not in allowed architectures") +def set_BN_to_eval(model: Module): + """Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()) in-place. + + Args: + model (Module): Input model + """ + known_bn_layers = (BatchNorm2d, MemSaveBatchNorm2d) + for layer in model.modules(): + if isinstance(layer, known_bn_layers): + layer.eval() + + # CONV conv_input_shape: Tuple[int, int, int] = (1, 1, 1)