diff --git a/memsave_torch/nn/BatchNorm.py b/memsave_torch/nn/BatchNorm.py
index a8de50c..79a1507 100644
--- a/memsave_torch/nn/BatchNorm.py
+++ b/memsave_torch/nn/BatchNorm.py
@@ -79,4 +79,5 @@ def from_nn_BatchNorm2d(cls, bn2d: nn.BatchNorm2d):
         obj.bias = bn2d.bias
         obj.running_mean = bn2d.running_mean
         obj.running_var = bn2d.running_var
+        obj.training = bn2d.training
         return obj
diff --git a/memsave_torch/nn/functional/BatchNorm.py b/memsave_torch/nn/functional/BatchNorm.py
index 33efcf0..934b6ef 100644
--- a/memsave_torch/nn/functional/BatchNorm.py
+++ b/memsave_torch/nn/functional/BatchNorm.py
@@ -38,8 +38,12 @@ def forward(
         need_grad = []  # save_mean and save_invstd
         if ctx.needs_input_grad[0]:
             need_grad.append(weight)
-        if any(ctx.needs_input_grad):
-            need_grad.append(x)
+        if ctx.training:
+            if any(ctx.needs_input_grad):
+                need_grad.append(x)
+        else:
+            if ctx.needs_input_grad[3]:
+                need_grad.append(x)
         # bias doesnt need anything for calc
         ctx.save_for_backward(*need_grad)
@@ -54,7 +58,8 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[0]:
             weight = ctx.saved_tensors[current_idx]
             current_idx += 1
-        x = ctx.saved_tensors[current_idx]
+        if ctx.training:
+            x = ctx.saved_tensors[current_idx]
         if ctx.needs_input_grad[3]:
             x = ctx.saved_tensors[current_idx]
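
For context, here is a minimal sketch (not part of the patch) of the code path this change addresses. It assumes the converted layer is exposed as MemSaveBatchNorm2d under memsave_torch.nn and that from_nn_BatchNorm2d is a classmethod on it; the exact class name and import path are assumptions and may need adjusting to the repository layout. It also assumes, as the forward diff suggests, that index 3 of ctx.needs_input_grad corresponds to the weight argument.

    import torch
    import torch.nn as nn

    from memsave_torch.nn import MemSaveBatchNorm2d  # assumed class name / import path

    # An eval-mode BatchNorm layer: it normalizes with running statistics
    # and must not update them during the forward pass.
    bn = nn.BatchNorm2d(3).eval()
    mem_bn = MemSaveBatchNorm2d.from_nn_BatchNorm2d(bn)

    # With the patch, the training flag is copied over, so the converted layer
    # stays in eval mode instead of silently reverting to training behavior.
    assert mem_bn.training == bn.training

    # Freeze the affine parameters: when only the input requires a gradient,
    # the new eval-mode branch no longer stores the input for backward.
    for p in mem_bn.parameters():
        p.requires_grad_(False)

    # Backward through the eval-mode layer exercises the changed save/load logic:
    # the input is saved only when it is actually needed (per the diff, when the
    # weight gradient is requested), while the input gradient itself only needs
    # the weight and the running statistics.
    x = torch.randn(2, 3, 8, 8, requires_grad=True)
    mem_bn(x).sum().backward()
    print(x.grad.shape)  # torch.Size([2, 3, 8, 8])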