-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from plutonium-239/dropout
merge dropout into main
- Loading branch information
Showing
2 changed files
with
48 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Implementation of a memory saving Dropout (sort of). | ||
This is done by not saving the whole input/output `float32` tensor and instead just saving the `bool` mask (8bit). | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from memsave_torch.nn.functional import dropoutMemSave | ||
|
||
|
||
class MemSaveDropout(nn.Dropout):
    """Drop-in replacement for `nn.Dropout` that saves memory.

    Instead of keeping the full `float32` input/output tensor for the
    backward pass, only the 8-bit `bool` dropout mask is stored
    (see `dropoutMemSave`).
    """

    def __init__(self, p=0.5):
        """Inits a MemSaveDropout layer with the given params.

        Args:
            p: Probability of elements being zeroed
        """
        super().__init__(p)

    def forward(self, x):
        """Forward pass.

        Args:
            x: Input to the network

        Returns:
            torch.Tensor: Output
        """
        # BUG FIX: the original passed `self.train`, which is the bound
        # `nn.Module.train` *method* (always truthy), so dropout was applied
        # even in eval mode. The correct training-mode flag is `self.training`.
        return dropoutMemSave(x, self.p, self.training)

    @classmethod
    def from_nn_dropout(cls, dropout: nn.Dropout):
        """Converts a nn.Dropout layer to MemSaveDropout.

        Args:
            dropout: The nn.Dropout layer

        Returns:
            obj: The MemSaveDropout object
        """
        obj = cls(dropout.p)
        return obj
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters