Skip to content

Commit

Permalink
add helper function to set BN layers to eval mode
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Jul 4, 2024
1 parent edd9b03 commit 2c5afa1
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions experiments/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torchvision.models as tvm
from torch.nn import (
BatchNorm2d,
Conv2d,
Flatten,
Linear,
Expand All @@ -32,6 +33,7 @@
from transformers import utils as tf_utils

from memsave_torch.nn import (
MemSaveBatchNorm2d,
MemSaveConv2d,
MemSaveLinear,
convert_to_memory_saving,
Expand Down Expand Up @@ -133,6 +135,18 @@ def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]:
raise ValueError(f"arch={arch} not in allowed architectures")


def set_BN_to_eval(model: Module):
"""Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()) in-place.
Args:
model (Module): Input model
"""
known_bn_layers = (BatchNorm2d, MemSaveBatchNorm2d)
for layer in model.modules():
if isinstance(layer, known_bn_layers):
layer.eval()


# CONV
conv_input_shape: Tuple[int, int, int] = (1, 1, 1)

Expand Down

0 comments on commit 2c5afa1

Please sign in to comment.