Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/llm' into llm
Browse files (browse the repository at this point in the history)
  • Loading branch information
f-dangel committed Jul 4, 2024
2 parents 1cb9d98 + 0df6e44 commit 962ac7e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
10 changes: 10 additions & 0 deletions experiments/util/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ def estimate_mem_savings(
help=f"Which case to run, allowed values are {allowed_cases}",
)
parser.add_argument("--device", type=str, default="cpu", help="torch device name")
parser.add_argument(
"--bn_eval",
action="store_true",
default=False,
help="Set all BN layers to eval mode",
)
parser.add_argument(
"--print",
action="store_true",
Expand Down Expand Up @@ -364,6 +370,10 @@ def estimate_mem_savings(
loss_fn_orig = loss_fn
loss_fn = lambda: models.SegmentationLossWrapper(loss_fn_orig) # noqa: E731

if args.bn_eval:
model_fn_orig_bn = model_fn
model_fn = lambda: models.set_BN_to_eval(model_fn_orig_bn()) # noqa: E731

# warm-up
# with redirect_stdout(open(devnull, "w")):
# estimate_speedup(model_fn, loss_fn, x, y, dev, vjp_speedups[:1])
Expand Down
8 changes: 6 additions & 2 deletions experiments/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,20 @@ def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]:
raise ValueError(f"arch={arch} not in allowed architectures")


- def set_BN_to_eval(model: Module):
- """Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()) in-place.
+ def set_BN_to_eval(model: Module) -> Module:
+ """Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()).
Args:
model (Module): Input model
Returns:
Module: Model with BN layers in eval mode
"""
known_bn_layers = (BatchNorm2d, MemSaveBatchNorm2d)
for layer in model.modules():
if isinstance(layer, known_bn_layers):
layer.eval()
return model


# CONV
Expand Down
4 changes: 2 additions & 2 deletions experiments/visual_abstract/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
makedirs(DATADIR, exist_ok=True)


- def main(
+ def main( # noqa: C901
architecture: str,
implementation: str,
mode: str,
num_layers: int,
requires_grad: str,
- ): # noqa: C901
+ ):
"""Runs exps for generating the data of the visual abstract"""
manual_seed(0)

Expand Down

0 comments on commit 962ac7e

Please sign in to comment.