format + add conv3d in converter
plutonium-239 committed Aug 22, 2024
1 parent a76ccd8 commit 330ae05
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 2 additions & 0 deletions experiments/util/measurements.py
@@ -29,6 +29,7 @@
 from torchvision.models.convnext import LayerNorm2d
 from transformers import Conv1D
 
+from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
 from memsave_torch.nn.Conv2d import MemSaveConv2d
 from memsave_torch.nn.Linear import MemSaveLinear
 
@@ -317,6 +318,7 @@ def separate_grad_arguments(
         LayerNorm,
         LayerNorm2d,
+        MemSaveBatchNorm2d,
     )
     embed = Embedding
 
     leafs, no_leafs = [], []
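The hunk above adds `MemSaveBatchNorm2d` to a tuple of normalization-layer classes, so its parameters are grouped exactly like those of the stock norm layers when gradient arguments are separated. A minimal sketch of how such a class tuple is typically consumed via `isinstance` (the `classify_layer` helper is hypothetical, not memsave_torch's actual logic):

import torch.nn as nn

from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d

# Illustrative only: a single tuple lets isinstance() treat memory-saving
# and stock normalization layers identically when sorting parameters.
norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, MemSaveBatchNorm2d)

def classify_layer(layer: nn.Module) -> str:
    """Hypothetical helper: bucket a layer by how its gradients are handled."""
    if isinstance(layer, norm):
        return "norm"
    if isinstance(layer, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return "weight"
    return "other"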
1 change: 0 additions & 1 deletion memsave_torch/nn/Dropout.py
@@ -3,7 +3,6 @@
 This is done by not saving the whole input/output `float32` tensor and instead just saving the `bool` mask (8bit).
 """
 
-import torch
 import torch.nn as nn
 
 from memsave_torch.nn.functional import dropoutMemSave
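The module docstring above describes the saving: keep only the 8-bit `bool` keep-mask for the backward pass instead of a full `float32` activation, a 4x reduction for the stored tensor. A minimal sketch of that idea as a custom autograd function (illustrative only; memsave_torch's real `dropoutMemSave` may differ):

import torch

class DropoutSketch(torch.autograd.Function):
    """Hypothetical mask-only dropout: stores a bool mask, not the input."""

    @staticmethod
    def forward(ctx, x, p):
        # Sample the keep-mask as bool: 1 byte per element vs. 4 for float32.
        mask = torch.rand_like(x) >= p
        ctx.save_for_backward(mask)
        ctx.scale = 1.0 / (1.0 - p)
        return x * mask * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Dropout's Jacobian is diagonal: just re-apply the scaled mask.
        return grad_output * mask * ctx.scale, None

x = torch.rand(4, 3, requires_grad=True)
y = DropoutSketch.apply(x, 0.25)
y.sum().backward()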
15 changes: 11 additions & 4 deletions memsave_torch/nn/__init__.py
@@ -34,7 +34,8 @@ def convert_to_memory_saving(
     model: nn.Module,
     linear=True,
     conv2d=True,
-    conv1d=False,
+    conv1d=True,
+    conv3d=True,
     batchnorm2d=True,
     relu=True,
     maxpool2d=True,
@@ -53,6 +54,7 @@
         linear (bool, optional): Whether to replace `nn.Linear` layers
         conv2d (bool, optional): Whether to replace `nn.Conv2d` layers
         conv1d (bool, optional): Whether to replace `nn.Conv1d` layers
+        conv3d (bool, optional): Whether to replace `nn.Conv3d` layers
         batchnorm2d (bool, optional): Whether to replace `nn.BatchNorm2d` layers
         relu (bool, optional): Whether to replace `nn.ReLU` layers
         maxpool2d (bool, optional): Whether to replace `nn.MaxPool2d` layers
@@ -78,15 +80,20 @@
             "cls": nn.MaxPool2d,
             "convert_fn": MemSaveMaxPool2d.from_nn_MaxPool2d,
         },
+        {
+            "allowed": conv1d,
+            "cls": nn.Conv1d,
+            "convert_fn": MemSaveConv1d.from_nn_Conv1d,
+        },
         {
             "allowed": conv2d,
             "cls": nn.Conv2d,
             "convert_fn": MemSaveConv2d.from_nn_Conv2d,
         },
         {
-            "allowed": conv1d,
-            "cls": nn.Conv1d,
-            "convert_fn": MemSaveConv1d.from_nn_Conv1d,
+            "allowed": conv3d,
+            "cls": nn.Conv3d,
+            "convert_fn": MemSaveConv3d.from_nn_Conv3d,
         },
         {
             "allowed": conv1d,
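With the new converters entry, a model containing `nn.Conv3d` is converted like any other. A minimal usage sketch, assuming `convert_to_memory_saving` returns the converted module (the toy model below is illustrative; the function itself is the one defined in this file):

import torch
import torch.nn as nn

from memsave_torch.nn import convert_to_memory_saving

# Toy 3D CNN; any module tree containing nn.Conv3d layers works.
model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(8, 4, kernel_size=3, padding=1),
)

# conv3d (like conv1d) now defaults to True, so nn.Conv3d layers are swapped
# for their MemSaveConv3d counterparts; pass conv3d=False to opt out.
memsave_model = convert_to_memory_saving(model)

x = torch.rand(2, 1, 16, 16, 16)
print(memsave_model(x).shape)  # torch.Size([2, 4, 16, 16, 16])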
