Skip to content

Commit

Permalink
dropout improvements
Browse files Browse the repository at this point in the history
(cherry picked from commit ea5e6cb)
  • Loading branch information
plutonium-239 committed Aug 3, 2024
1 parent 0b3b8c1 commit 9d293fe
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions memsave_torch/nn/functional/Dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ def backward(ctx, grad_output):
mask = torch.empty_like(grad_output)
mask = mask.bernoulli_(0.5).bool()
torch.set_rng_state(orig_rng)
grad_x = torch.ops.aten.native_dropout_backward(
grad_output, mask, scale=1 / (1 - ctx.p)
)
grad_x = grad_output*mask/(1-ctx.p)
# grad_x = torch.ops.aten.native_dropout_backward(
# grad_output, mask, scale=1 / (1 - ctx.p)
# )

return grad_x, None, None

Expand Down

0 comments on commit 9d293fe

Please sign in to comment.