diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py
index 62747c2..991f0c3 100644
--- a/experiments/util/measurements.py
+++ b/experiments/util/measurements.py
@@ -29,6 +29,7 @@ from torchvision.models.convnext import LayerNorm2d
 from transformers import Conv1D
 
+from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
 from memsave_torch.nn.Conv2d import MemSaveConv2d
 from memsave_torch.nn.Linear import MemSaveLinear
 
@@ -317,6 +318,7 @@ def separate_grad_arguments(
         LayerNorm,
         LayerNorm2d,
         MemSaveBatchNorm2d,
+    )
    embed = Embedding

    leafs, no_leafs = [], []
diff --git a/memsave_torch/nn/Dropout.py b/memsave_torch/nn/Dropout.py
index 5d682e5..fd0317d 100644
--- a/memsave_torch/nn/Dropout.py
+++ b/memsave_torch/nn/Dropout.py
@@ -3,7 +3,6 @@
 This is done by not saving the whole input/output `float32` tensor and instead
 just saving the `bool` mask (8bit).
 """
-import torch
 import torch.nn as nn
 
 from memsave_torch.nn.functional import dropoutMemSave
diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py
index 5619d8d..e1d65ef 100644
--- a/memsave_torch/nn/__init__.py
+++ b/memsave_torch/nn/__init__.py
@@ -34,7 +34,8 @@ def convert_to_memory_saving(
     model: nn.Module,
     linear=True,
     conv2d=True,
-    conv1d=False,
+    conv1d=True,
+    conv3d=True,
     batchnorm2d=True,
     relu=True,
     maxpool2d=True,
@@ -53,6 +54,7 @@ def convert_to_memory_saving(
         linear (bool, optional): Whether to replace `nn.Linear` layers
         conv2d (bool, optional): Whether to replace `nn.Conv2d` layers
         conv1d (bool, optional): Whether to replace `nn.Conv1d` layers
+        conv3d (bool, optional): Whether to replace `nn.Conv3d` layers
         batchnorm2d (bool, optional): Whether to replace `nn.BatchNorm2d` layers
         relu (bool, optional): Whether to replace `nn.ReLU` layers
         maxpool2d (bool, optional): Whether to replace `nn.MaxPool2d` layers
@@ -78,15 +80,20 @@ def convert_to_memory_saving(
             "cls": nn.MaxPool2d,
             "convert_fn": MemSaveMaxPool2d.from_nn_MaxPool2d,
         },
+        {
+            "allowed": conv1d,
+            "cls": nn.Conv1d,
+            "convert_fn": MemSaveConv1d.from_nn_Conv1d,
+        },
         {
             "allowed": conv2d,
             "cls": nn.Conv2d,
             "convert_fn": MemSaveConv2d.from_nn_Conv2d,
         },
         {
-            "allowed": conv1d,
-            "cls": nn.Conv1d,
-            "convert_fn": MemSaveConv1d.from_nn_Conv1d,
+            "allowed": conv3d,
+            "cls": nn.Conv3d,
+            "convert_fn": MemSaveConv3d.from_nn_Conv3d,
         },
         {
             "allowed": conv1d,