batchnorm in eval mode
plutonium-239 committed Jul 4, 2024
1 parent 8e68a88 commit edd9b03
Showing 2 changed files with 9 additions and 3 deletions.

memsave_torch/nn/BatchNorm.py (1 change: 1 addition, 0 deletions)
@@ -79,4 +79,5 @@ def from_nn_BatchNorm2d(cls, bn2d: nn.BatchNorm2d):
         obj.bias = bn2d.bias
         obj.running_mean = bn2d.running_mean
         obj.running_var = bn2d.running_var
+        obj.training = bn2d.training
         return obj
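
The added line carries the training/eval flag over when converting an existing layer, alongside the parameters and running statistics copied above it. A minimal usage sketch follows; the class name MemSaveBatchNorm2d and the import path are assumptions for illustration, not taken from this diff:

```python
import torch.nn as nn

# Assumed import path; the diff above only shows the converter classmethod.
from memsave_torch.nn import MemSaveBatchNorm2d

bn = nn.BatchNorm2d(16)
bn.eval()  # eval mode: forward uses running_mean/running_var, not batch statistics

ms_bn = MemSaveBatchNorm2d.from_nn_BatchNorm2d(bn)
# Before this commit the converted layer kept the nn.Module default
# training=True; now the flag is copied from the source layer.
assert ms_bn.training == bn.training
```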

memsave_torch/nn/functional/BatchNorm.py (11 changes: 8 additions, 3 deletions)
@@ -38,8 +38,12 @@ def forward(
         need_grad = []  # save_mean and save_invstd
         if ctx.needs_input_grad[0]:
             need_grad.append(weight)
-        if any(ctx.needs_input_grad):
-            need_grad.append(x)
+        if ctx.training:
+            if any(ctx.needs_input_grad):
+                need_grad.append(x)
+        else:
+            if ctx.needs_input_grad[3]:
+                need_grad.append(x)
         # bias doesnt need anything for calc

         ctx.save_for_backward(*need_grad)
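
The new branch encodes when the input actually has to be stored: in training mode the batch statistics depend on x, so x is saved whenever any gradient is requested, but in eval mode the layer uses its fixed running statistics, the input gradient reduces to grad_output scaled by weight / sqrt(running_var + eps), and x is only needed to form the weight gradient (needs_input_grad[3]). A quick standalone check of that claim, not library code:

```python
import torch
import torch.nn.functional as F

C = 3
x = torch.randn(2, C, 5, 5, requires_grad=True)
weight, bias = torch.randn(C), torch.randn(C)
mean, var, eps = torch.randn(C), torch.rand(C) + 0.5, 1e-5

# Eval-mode batch norm: normalize with the provided running statistics.
out = F.batch_norm(x, mean, var, weight, bias, training=False, eps=eps)
(grad_x,) = torch.autograd.grad(out.sum(), x)

# The input gradient is a constant per-channel scale and never involves x.
scale = (weight / (var + eps).sqrt())[None, :, None, None]
assert torch.allclose(grad_x, scale.expand_as(x))
```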
@@ -54,7 +58,8 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[0]:
             weight = ctx.saved_tensors[current_idx]
             current_idx += 1
-        x = ctx.saved_tensors[current_idx]
+        if ctx.training:
+            x = ctx.saved_tensors[current_idx]
         if ctx.needs_input_grad[3]:
             x = ctx.saved_tensors[current_idx]

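Taken together, the bookkeeping pattern is: decide in forward which tensors backward will actually need (from ctx.training and ctx.needs_input_grad), pass only those to ctx.save_for_backward, and unpack them by a running index in backward. The sketch below mirrors that pattern for the eval-mode case this commit adds; it is a standalone illustration with a made-up class name, not the library's implementation:

```python
import torch


class EvalBatchNormSketch(torch.autograd.Function):
    """Eval-mode batch norm (NCHW) that saves x only if the weight gradient is needed."""

    @staticmethod
    def forward(ctx, x, running_mean, running_var, weight, bias, eps):
        ctx.eps = eps
        # Running statistics are buffers (no gradient), so stashing them on ctx is fine here.
        ctx.running_mean = running_mean
        ctx.running_var = running_var
        need_grad = []
        if ctx.needs_input_grad[0]:
            need_grad.append(weight)  # the input gradient needs only the weight
        if ctx.needs_input_grad[3]:
            need_grad.append(x)       # the weight gradient is the only consumer of x
        ctx.save_for_backward(*need_grad)
        return torch.nn.functional.batch_norm(
            x, running_mean, running_var, weight, bias, training=False, eps=eps
        )

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = grad_weight = grad_bias = None
        current_idx = 0
        inv_std = (ctx.running_var + ctx.eps).rsqrt()[None, :, None, None]
        if ctx.needs_input_grad[0]:
            weight = ctx.saved_tensors[current_idx]
            current_idx += 1
            grad_x = grad_output * weight[None, :, None, None] * inv_std
        if ctx.needs_input_grad[3]:
            x = ctx.saved_tensors[current_idx]
            x_hat = (x - ctx.running_mean[None, :, None, None]) * inv_std
            grad_weight = (grad_output * x_hat).sum(dim=(0, 2, 3))
        if ctx.needs_input_grad[4]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))
        # One gradient (or None) per forward argument.
        return grad_x, None, None, grad_weight, grad_bias, None


# Usage: out = EvalBatchNormSketch.apply(x, running_mean, running_var, weight, bias, 1e-5)
```

When only the input requires a gradient, this stores just the per-channel weight for backward and never keeps the activation x, which is the eval-mode memory saving the commit enables.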
