Skip to content

Commit

Permalink
more formatting, ready for new version
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Aug 22, 2024
1 parent 330ae05 commit 7324da6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion experiments/paper_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# repeat the experiment multiple times (generates multiple files to be aggregated by `get_best_results`)
n_repeat = 5
batchnorm_eval = True # BatchNorm in eval mode
batchnorm_eval = True # BatchNorm in eval mode

# ============== CONV CONFIG ==============
# Valid choices for models are in models.conv_model_fns
Expand Down
8 changes: 6 additions & 2 deletions experiments/util/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,17 @@ def skip_case_check(args: argparse.Namespace) -> bool:
invalid = False
if args.case is None:
return invalid
is_surgical = 'surgical_last' in args.case or 'surgical_first' in args.case
is_surgical = "surgical_last" in args.case or "surgical_first" in args.case
# 1.
for c in ["grad_norm_bias", "grad_norm_weights"]:
if c in args.case and args.model in models.models_without_norm:
invalid = True
for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
if c not in args.case and args.model in models.models_without_norm and not is_surgical:
if (
c not in args.case
and args.model in models.models_without_norm
and not is_surgical
):
invalid = True
# 2.
if "no_grad_embed_weights" in args.case and "grad_input" not in args.case:
Expand Down
1 change: 0 additions & 1 deletion memsave_torch/nn/Dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,3 @@ def from_nn_dropout(cls, dropout: nn.Dropout):
"""
obj = cls(dropout.p)
return obj

0 comments on commit 7324da6

Please sign in to comment.