From 3ba595884220153128ce1f17fd8b347f221b045d Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Tue, 23 Apr 2024 00:37:05 +0530 Subject: [PATCH] replace `transformers.Conv1D` with `MemSaveLinear` --- memsave_torch/nn/Linear.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py index fbace35..cf869ba 100644 --- a/memsave_torch/nn/Linear.py +++ b/memsave_torch/nn/Linear.py @@ -35,14 +35,18 @@ def forward(self, x): @classmethod def from_nn_Linear(cls, linear: nn.Linear): - """Converts a nn.Linear layer to MemSaveLinear. + """Converts a nn.Linear/transformers.Conv1D layer to MemSaveLinear. Args: - linear : The nn.Linear layer + linear : The nn.Linear/transformers.Conv1D layer Returns: obj: The MemSaveLinear object """ + if linear.__class__ == 'transformers.pytorch_utils.Conv1D': + # it only saves output features in the model (linear.nf); need to take input features from weight anyway + # weight and bias are still defined + linear.in_features, linear.out_features = linear.weight.shape obj = cls( linear.in_features, linear.out_features,