From 79a5a688e5b9ba15dc8062d9f39c867987b592f8 Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Tue, 16 Apr 2024 17:30:22 +0530 Subject: [PATCH] minor fix --- memsave_torch/nn/Dropout.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/memsave_torch/nn/Dropout.py b/memsave_torch/nn/Dropout.py index b78c7a3..a999361 100644 --- a/memsave_torch/nn/Dropout.py +++ b/memsave_torch/nn/Dropout.py @@ -27,7 +27,7 @@ def forward(self, x): Returns: torch.Tensor: Output """ - return dropoutMemSave(x, self.p, self.train) + return dropoutMemSave(x, self.p, self.training) @classmethod def from_nn_dropout(cls, dropout: nn.Dropout): @@ -65,15 +65,15 @@ def backward(ctx, grad_output): return grad_x -def dropoutMemSave(x, p, train): +def dropoutMemSave(x, p, training): """Functional form of the memory saving dropout. Args: x: Input to the network p: Probability of elements being zeroed - train: Whether the layer is in training mode (no dropout applied in eval) + training: Whether the layer is in training mode (no dropout applied in eval) Returns: torch.Tensor: Output of the network """ - return _MemSaveDropout.apply(x, p, train) + return _MemSaveDropout.apply(x, p, training)