Merge pull request #9 from plutonium-239/dropout
merge dropout into main
plutonium-239 authored Aug 22, 2024
2 parents 1fa42dc + 8f5c7e3 commit ea97d77
Showing 2 changed files with 48 additions and 1 deletion.
46 changes: 46 additions & 0 deletions memsave_torch/nn/Dropout.py
@@ -0,0 +1,46 @@
"""Implementation of a memory saving Dropout (sort of).
This is done by not saving the whole input/output `float32` tensor and instead just saving the `bool` mask (8bit).
"""

import torch
import torch.nn as nn

from memsave_torch.nn.functional import dropoutMemSave


class MemSaveDropout(nn.Dropout):
    """MemSaveDropout."""

    def __init__(self, p=0.5):
        """Inits a MemSaveDropout layer with the given params.

        Args:
            p: Probability of elements being zeroed
        """
        super().__init__(p)

    def forward(self, x):
        """Forward pass.

        Args:
            x: Input to the network

        Returns:
            torch.Tensor: Output
        """
        # `self.training` is the nn.Module boolean mode flag; `self.train` is a method
        return dropoutMemSave(x, self.p, self.training)

    @classmethod
    def from_nn_dropout(cls, dropout: nn.Dropout):
        """Converts a nn.Dropout layer to MemSaveDropout.

        Args:
            dropout: The nn.Dropout layer

        Returns:
            obj: The MemSaveDropout object
        """
        obj = cls(dropout.p)
        return obj
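
For context, here is a short usage sketch (not part of the commit; the top-level import path `memsave_torch.nn` is an assumption). Since a `float32` activation costs 4 bytes per element while a `bool` mask costs 1, saving only the mask cuts this layer's stored-activation memory by roughly 4x.

import torch
import torch.nn as nn

from memsave_torch.nn import MemSaveDropout  # import path assumed

plain = nn.Dropout(p=0.3)
memsave = MemSaveDropout.from_nn_dropout(plain)  # copies only `p`

x = torch.randn(64, 128, requires_grad=True)
memsave.train()               # dropout is active only in training mode
y = memsave(x)
y.sum().backward()            # gradients flow through the memory-saving op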

3 changes: 2 additions & 1 deletion memsave_torch/nn/functional/Dropout.py
@@ -15,7 +15,8 @@ def forward(ctx, x, p, train):
         out = torch.dropout(x, p, train)
         if ctx.needs_input_grad[0]:
             ctx.p = p
-            ctx.mask = mask
+            ctx.train = train
+            ctx.rng = rng
         return out
 
     @staticmethod
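
The visible hunk suggests the forward no longer stores the `bool` mask at all; it keeps the RNG state (`ctx.rng`) together with `p` and `train`, so the backward can regenerate the identical mask on demand. Below is a minimal self-contained sketch of that idea, not the library's actual code: the class name, the CPU-only RNG handling, and the `rand_like` sampling are all assumptions.

import torch


class DropoutRNGSketch(torch.autograd.Function):
    """Illustrative only: save the RNG state instead of the mask and
    regenerate the mask in backward. Uses the CPU generator for simplicity."""

    @staticmethod
    def forward(ctx, x, p, train):
        ctx.p, ctx.train = p, train
        if not train or p == 0.0:
            return x.clone()
        ctx.rng = torch.get_rng_state()   # snapshot before sampling the mask
        mask = torch.rand_like(x) > p     # Bernoulli keep-mask, never saved
        return x * mask / (1 - p)         # inverted-dropout scaling

    @staticmethod
    def backward(ctx, grad_out):
        if not ctx.train or ctx.p == 0.0:
            return grad_out, None, None
        saved = torch.get_rng_state()
        torch.set_rng_state(ctx.rng)      # replay: same state -> same mask
        mask = torch.rand_like(grad_out) > ctx.p
        torch.set_rng_state(saved)
        return grad_out * mask / (1 - ctx.p), None, None


x = torch.randn(8, 4, requires_grad=True)
y = DropoutRNGSketch.apply(x, 0.5, True)
y.sum().backward()                        # grads use the regenerated mask

The trade-off: nothing tensor-sized is kept between forward and backward; the cost is re-sampling the mask (and briefly swapping the RNG state) during the backward pass.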
