Skip to content

Commit

Permalink
replace transformers.Conv1D with MemSaveLinear
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed Apr 22, 2024
1 parent bba6a11 commit dd48758
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion memsave_torch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import torch.nn as nn
import transformers

from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
from memsave_torch.nn.Conv1d import MemSaveConv1d
Expand Down Expand Up @@ -56,7 +57,7 @@ def convert_to_memory_saving(
layers = [
{
"allowed": linear,
"cls": nn.Linear,
"cls": (nn.Linear, transformers.Conv1D),
"convert_fn": MemSaveLinear.from_nn_Linear,
},
{"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU},
Expand Down

0 comments on commit dd48758

Please sign in to comment.