Skip to content

Commit

Permalink
replace transformers.Conv1D with MemSaveLinear
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Apr 22, 2024
1 parent dd48758 commit 3ba5958
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions memsave_torch/nn/Linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,18 @@ def forward(self, x):

@classmethod
def from_nn_Linear(cls, linear: nn.Linear):
"""Converts a nn.Linear layer to MemSaveLinear.
"""Converts a nn.Linear/transformers.Conv1D layer to MemSaveLinear.
Args:
linear : The nn.Linear layer
linear : The nn.Linear/transformers.Conv1D layer
Returns:
obj: The MemSaveLinear object
"""
if linear.__class__ == 'transformers.pytorch_utils.Conv1D':
# it only saves output features in the model (linear.nf); need to take input features from weight anyway
# weight and bias are still defined
linear.in_features, linear.out_features = linear.weight.shape
obj = cls(
linear.in_features,
linear.out_features,
Expand Down

0 comments on commit 3ba5958

Please sign in to comment.