Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Apr 16, 2024
1 parent dbb89d1 commit 79a5a68
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions memsave_torch/nn/Dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def forward(self, x):
Returns:
torch.Tensor: Output
"""
return dropoutMemSave(x, self.p, self.train)
return dropoutMemSave(x, self.p, self.training)

@classmethod
def from_nn_dropout(cls, dropout: nn.Dropout):
Expand Down Expand Up @@ -65,15 +65,15 @@ def backward(ctx, grad_output):
return grad_x


def dropoutMemSave(x, p, train):
def dropoutMemSave(x, p, training):
"""Functional form of the memory saving dropout.
Args:
x: Input to the network
p: Probability of elements being zeroed
train: Whether the layer is in training mode (no dropout applied in eval)
training: Whether the layer is in training mode (no dropout applied in eval)
Returns:
torch.Tensor: Output of the network
"""
return _MemSaveDropout.apply(x, p, train)
return _MemSaveDropout.apply(x, p, training)

0 comments on commit 79a5a68

Please sign in to comment.