
Commit f64629a

dropout fix backward return + move into functional

plutonium-239 committed May 1, 2024
1 parent 3bfde12 commit f64629a
Showing 2 changed files with 3 additions and 36 deletions.
37 changes: 2 additions & 35 deletions memsave_torch/nn/Dropout.py
@@ -6,6 +6,8 @@
 import torch
 import torch.nn as nn

+from memsave_torch.nn.functional import dropoutMemSave
+

 class MemSaveDropout(nn.Dropout):
     """MemSaveDropout."""
@@ -42,38 +44,3 @@ def from_nn_dropout(cls, dropout: nn.Dropout):
         obj = cls(dropout.p)
         return obj

-
-# TODO: inplace
-class _MemSaveDropout(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, p, train):
-        out, mask = torch.ops.aten.native_dropout(x, p, train)
-        if ctx.needs_input_grad[0]:
-            ctx.p = p
-            ctx.mask = mask
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_x = None
-
-        if ctx.needs_input_grad[0]:
-            grad_x = torch.ops.aten.native_dropout_backward(
-                grad_output, ctx.mask, scale=1 / (1 - ctx.p)
-            )
-
-        return grad_x
-
-
-def dropoutMemSave(x, p, training):
-    """Functional form of the memory saving dropout.
-
-    Args:
-        x: Input to the network
-        p: Probability of elements being zeroed
-        training: Whether the layer is in training mode (no dropout applied in eval)
-
-    Returns:
-        torch.Tensor: Output of the network
-    """
-    return _MemSaveDropout.apply(x, p, training)
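
This is the "move into functional" part of the commit: the autograd function and its functional wrapper removed above now live in memsave_torch/nn/functional, and the layer imports dropoutMemSave from there. The layer's forward method is not part of this diff; purely as an illustration, here is a minimal sketch of how MemSaveDropout might delegate to the imported functional form, under that assumption:

import torch
import torch.nn as nn

from memsave_torch.nn.functional import dropoutMemSave


class MemSaveDropout(nn.Dropout):
    """Sketch of the layer; the real class in memsave_torch may differ."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed delegation: the functional wrapper stashes only the
        # dropout mask and p on the ctx for backward.
        return dropoutMemSave(x, self.p, self.training)

    @classmethod
    def from_nn_dropout(cls, dropout: nn.Dropout):
        """Convert an existing nn.Dropout (as in the diff above)."""
        obj = cls(dropout.p)
        return obj

With that, an existing layer could be swapped via MemSaveDropout.from_nn_dropout(nn.Dropout(p=0.1)).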
2 changes: 1 addition & 1 deletion memsave_torch/nn/functional/Dropout.py
@@ -25,7 +25,7 @@ def backward(ctx, grad_output):
                 grad_output, ctx.mask, scale=1 / (1 - ctx.p)
             )

-        return grad_x
+        return grad_x, None, None


 def dropoutMemSave(x, p, training) -> torch.Tensor:
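
This one-line change is the "fix backward return" part of the commit: torch.autograd.Function.backward must return one value per argument that forward received. Here forward takes (ctx, x, p, train), so backward must return three values; only x can require a gradient, and the slots for p and train are filled with None. Returning a single value instead makes autograd raise an error about an incorrect number of gradients during the backward pass. Pieced together from the code removed from memsave_torch/nn/Dropout.py above plus this corrected return, the functional module plausibly looks like the following sketch (only a few lines of the actual file are visible in this diff):

import torch


class _MemSaveDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, p, train):
        out, mask = torch.ops.aten.native_dropout(x, p, train)
        if ctx.needs_input_grad[0]:
            # Keep only the dropout mask and p for backward, not the input x.
            ctx.p = p
            ctx.mask = mask
        return out

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = None
        if ctx.needs_input_grad[0]:
            grad_x = torch.ops.aten.native_dropout_backward(
                grad_output, ctx.mask, scale=1 / (1 - ctx.p)
            )
        # forward received (x, p, train), so backward returns three values;
        # p and train are non-tensors, hence None for their slots.
        return grad_x, None, None


def dropoutMemSave(x, p, training) -> torch.Tensor:
    """Functional form of the memory saving dropout.

    Args:
        x: Input to the network
        p: Probability of elements being zeroed
        training: Whether the layer is in training mode (no dropout applied in eval)

    Returns:
        torch.Tensor: Output of the network
    """
    return _MemSaveDropout.apply(x, p, training)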
